def test_never_lowercase(self):
    """Check that ``never_lowercase`` protects special tokens when lowercasing is on."""
    fixture_vocab = str(self.FIXTURES_ROOT / 'bert' / 'vocab.txt')
    vocab = Vocabulary()

    # The default tokenizer does no lowercasing itself.
    # Wordpiece ids expected for the sentence: 2 15 10 11 6
    words = WordTokenizer().tokenize("the laziest fox")
    # "[PAD]" is appended by hand because the tokenizer would split it in three.
    words.append(Token("[PAD]"))

    protecting_indexer = PretrainedBertIndexer(fixture_vocab, do_lowercase=True)
    protected = protecting_indexer.tokens_to_indices(words, vocab, "bert")
    # [PAD] is recognised, left un-lowercased, and indexed as 0.
    assert protected["bert"] == [16, 2, 15, 10, 11, 6, 0, 17]

    # An empty never_lowercase overrides the protection.
    unprotecting_indexer = PretrainedBertIndexer(fixture_vocab,
                                                 do_lowercase=True,
                                                 never_lowercase=())
    unprotected = unprotecting_indexer.tokens_to_indices(words, vocab, "bert")
    # [PAD] is now lowercased and therefore indexed as [UNK] (1).
    assert unprotected["bert"] == [16, 2, 15, 10, 11, 6, 1, 17]
def __init__(self, labels, domain_utils):
    """Set up tokenization helpers and index the lexicon phrases of ``labels``.

    :param labels: passed straight to ``_set_labels_wordpieces``.
    :param domain_utils: stored on the instance as ``_domain_utils``.
    """
    # Both the tokenizer and the indexer should align with the reader's tokenizer.
    self._tokenizer = BertPreTokenizer()
    self._token_indexers = PretrainedBertIndexer(pretrained_model='bert-base-uncased')

    # True when torch reports at least one CUDA device.
    cuda_device_count = torch.cuda.device_count()
    self._is_cude = cuda_device_count > 0

    self._domain_utils = domain_utils
    # Must run last: it is called after every attribute above has been set.
    self._set_labels_wordpieces(labels)
def test_sliding_window(self):
    """Index a sentence longer than ``max_pieces`` without truncation and check the windows."""
    text = "the quickest quick brown [SEP] jumped over the lazy dog"
    words = WordTokenizer(word_splitter=BertBasicWordSplitter()).tokenize(text)
    vocab = Vocabulary()
    fixture_vocab = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
    indexer = PretrainedBertIndexer(str(fixture_vocab),
                                    truncate_long_sequences=False,
                                    use_starting_offsets=False,
                                    max_pieces=10)
    result = indexer.tokens_to_indices(words, vocab, "bert")

    # 16 = [CLS], 17 = [SEP]
    # 1 full window + 1 half window with start/end tokens
    #              [CLS] the quick est quick brown [SEP] jumped over [SEP]
    full_window = [16, 2, 3, 4, 3, 5, 17, 8, 9, 17]
    #              [CLS] brown [SEP] jumped over the lazy dog [SEP]
    half_window = [16, 5, 17, 8, 9, 2, 14, 12, 17]
    assert result["bert"] == full_window + half_window

    assert result["bert-offsets"] == [1, 3, 4, 5, 6, 7, 8, 9, 10, 11]

    # The extra [SEP]s added per window shouldn't pollute the token-type-ids.
    full_window_types = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]
    half_window_types = [0, 0, 0, 1, 1, 1, 1, 1, 1]
    assert result["bert-type-ids"] == full_window_types + half_window_types
def batch_to_ids(stncs, tgt_flag=False):
    """
    Convert a list of tokenized sentences into ids.

    :param stncs: e.g. [['I', 'Like', 'you'], ['Yes']]
    :param tgt_flag: when true the input is treated as target sentences: the
        last word is dropped and one prefix-only copy is produced per position
    :return ids: the indexer output, or the expanded prefix tensor when
        ``tgt_flag`` is set

    NOTE(review): the original docstring described ELMo ids, but the code
    below goes through the BERT word splitter / indexer with an empty vocab
    path. The ``tgt_flag`` branch also slices ``ids`` like a 3-D tensor —
    confirm what ``tokens_to_indices`` returns here before relying on it.
    """
    word_tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())
    tokens = word_tokenizer.tokenize(stncs)
    vocab = Vocabulary()
    vocab_path = ""
    indexer = PretrainedBertIndexer(str(vocab_path))
    ids = indexer.tokens_to_indices(tokens, vocab, "bert")

    if not tgt_flag:
        return ids

    # Neglect the last word.
    ids = ids[:, :-1, :]
    b_size, _len, dim = ids.shape
    expand_ids = torch.zeros(b_size * _len, _len, dim, dtype=torch.long)
    # Block number `step` (rows b_size*(step-1) .. b_size*step) holds the first
    # `step` positions of every sentence; the remaining positions stay zero.
    for step in range(1, _len + 1):
        block_start = b_size * (step - 1)
        block_end = b_size * step
        expand_ids[block_start:block_end, :step, :] = ids[:, :step, :]
    return expand_ids
def test_indexes_empty_sequence(self):
    """Indexing no tokens still yields the start/end markers and empty offsets/mask."""
    fixture_vocab = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
    vocab = Vocabulary()
    indexer = PretrainedBertIndexer(str(fixture_vocab))

    expected = {
        'bert': [16, 17],          # [CLS], [SEP]
        'bert-offsets': [],        # no tokens => no offsets
        'bert-type-ids': [0, 0],   # just 0s for start and end
        'mask': [],                # no tokens => no mask
    }
    actual = indexer.tokens_to_indices([], vocab, "bert")
    assert actual == expected
def __init__(self):
    """Create the tokenizer, wordpiece indexer, synonym tables and s-expression parser."""
    # Tokenizer and indexer should align with the reader's tokenizer.
    self._tokenizer = BertPreTokenizer()
    self._token_indexers = PretrainedBertIndexer(pretrained_model='bert-base-uncased')

    # Both synonym tables start empty; the loading / enrichment steps
    # (_load_argumets, _enrich_synonyms, _enrich_synonyms_by_hand) are
    # currently disabled.
    self._synonyms = {kind: {} for kind in ('arguments', 'predicates')}

    # Nested parenthesised expressions; double-quoted strings are ignored
    # when matching parentheses.
    quoted = pyparsing.dblQuotedString
    self._parser = pyparsing.nestedExpr('(', ')', ignoreExpr=quoted)
def test_indexes_empty_sequence(self):
    """Indexing no tokens still yields the start/end markers and empty offsets/mask."""
    fixture_vocab = self.FIXTURES_ROOT / "bert" / "vocab.txt"
    vocab = Vocabulary()
    indexer = PretrainedBertIndexer(str(fixture_vocab))

    expected = {
        "input_ids": [16, 17],       # [CLS], [SEP]
        "offsets": [],               # no tokens => no offsets
        "token_type_ids": [0, 0],    # just 0s for start and end
        "mask": [],                  # no tokens => no mask
    }
    actual = indexer.tokens_to_indices([], vocab)
    assert actual == expected
def __init__(self,
             max_span_width: int,
             token_indexers: Dict[str, TokenIndexer] = None,
             lazy: bool = False) -> None:
    """Configure the reader with a cased pretrained BERT indexer.

    :param max_span_width: stored as ``_max_span_width``.
    :param token_indexers: accepted for interface compatibility only.
        NOTE(review): this argument is ignored; the reader always uses the
        ``bert-base-cased`` indexer built below — confirm that is intended.
    :param lazy: forwarded to the parent constructor.
    """
    super().__init__(lazy)
    self._max_span_width = max_span_width
    # Build the pretrained indexer once and share it: the original built two
    # identical instances, loading the pretrained vocabulary twice.
    self.token_indexer = PretrainedBertIndexer("bert-base-cased", do_lowercase=False)
    self._token_indexers = {"tokens": self.token_indexer}
def test_read_from_file(self, lazy):
    """Read the MRPC fixture in 'merge' mode and compare the first three instances."""
    space_tokenizer = WordTokenizer(word_splitter=JustSpacesWordSplitter())
    bert_indexer = PretrainedBertIndexer(pretrained_model=self.BERT_VOCAB_PATH)
    reader = MRPCReader(tokenizer=space_tokenizer,
                        token_indexers={"bert": bert_indexer},
                        lazy=lazy,
                        skip_label_indexing=False,
                        mode='merge')
    instances = ensure_list(reader.read(str(self.FIXTURES_ROOT / 'mrpc_dev.tsv')))

    # (first sentence, second sentence, label); the expected token list is the
    # two space-split sentences joined by a "[SEP]" token.
    sentence_pairs = [
        ("He said the foodservice pie business doesn 't fit the company 's long-term growth strategy .",
         "\" The foodservice pie business does not fit our long-term growth strategy .",
         '1'),
        ("Magnarelli said Racicot hated the Iraqi regime and looked forward to using his long years of training in the war .",
         "His wife said he was \" 100 percent behind George Bush \" and looked forward to using his years of training in the war .",
         '0'),
        ("The dollar was at 116.92 yen against the yen , flat on the session , and at 1.2891 against the Swiss franc , also flat .",
         "The dollar was at 116.78 yen JPY = , virtually flat on the session , and at 1.2871 against the Swiss franc CHF = , down 0.1 percent .",
         '0'),
    ]
    expected_instances = [
        {"tokens": first.split() + ["[SEP]"] + second.split(), "label": label}
        for first, second, label in sentence_pairs
    ]

    for instance, expected in zip(instances, expected_instances):
        fields = instance.fields
        token_texts = [token.text for token in fields["tokens"].tokens]
        assert token_texts == expected["tokens"]
        assert fields["label"].label == expected["label"]
def test_sliding_window_with_batch(self):
    """Embed a batch mixing a short and a tripled-length instance with sliding windows."""
    bert_fixtures = self.FIXTURES_ROOT / 'bert'
    text = "the quickest quick brown fox jumped over the lazy dog"
    words = WordTokenizer(word_splitter=BertBasicWordSplitter()).tokenize(text)

    vocab = Vocabulary()
    indexer = PretrainedBertIndexer(str(bert_fixtures / 'vocab.txt'),
                                    truncate_long_sequences=False,
                                    max_pieces=8)

    config = BertConfig(str(bert_fixtures / 'config.json'))
    embedder = BertEmbedder(BertModel(config), max_pieces=8)

    short_instance = Instance({"tokens": TextField(words, {"bert": indexer})})
    long_instance = Instance({"tokens": TextField(words + words + words, {"bert": indexer})})

    batch = Batch([short_instance, long_instance])
    batch.index_instances(vocab)
    tensor_dict = batch.as_tensor_dict(batch.get_padding_lengths())
    bert_inputs = tensor_dict["tokens"]

    vectors = embedder(bert_inputs["bert"], offsets=bert_inputs["bert-offsets"])
    assert vectors is not None
def test_read(self, lazy):
    """Read the SST-2 dev fixture, batch it, and print the first indexed example."""
    bert_indexer = PretrainedBertIndexer(pretrained_model=self.BERT_VOCAB_PATH)
    reader = GLUESST2DatasetReader(
        tokenizer=WordTokenizer(word_splitter=BertBasicWordSplitter()),
        token_indexers={'bert': bert_indexer},
        skip_label_indexing=False,
    )
    instances = ensure_list(reader.read(str(self.FIXTURES_ROOT / 'dev.tsv')))

    first = instances[0]
    token_texts = [token.text for token in first.fields['tokens']]
    print(first.fields['label'].label)
    print(token_texts)

    batch = Batch(instances)
    vocab = Vocabulary.from_instances(instances)
    batch.index_instances(vocab)
    tensor_dict = batch.as_tensor_dict(batch.get_padding_lengths())
    bert_inputs = tensor_dict["tokens"]

    # Everything below inspects only the first row of the batch.
    first_row_ids = bert_inputs["bert"].tolist()[0]
    print(bert_inputs['mask'].tolist()[0])
    print(first_row_ids)
    print([vocab.get_token_from_index(index, "bert") for index in first_row_ids])
    print(len(bert_inputs['bert'][0]))
    print(bert_inputs["bert-offsets"].tolist()[0])
    print(bert_inputs['bert-type-ids'].tolist()[0])
def test_squad_with_unwordpieceable_passage(self):
    """Batch two reading-comprehension instances and run both fields through the embedder."""
    from allennlp.data.dataset_readers.reading_comprehension.util import (
        make_reading_comprehension_instance,
    )

    spacy_tokenizer = SpacyTokenizer()
    bert_indexer = PretrainedBertIndexer("bert-base-uncased")

    first_passage = (
        "There were four major HDTV systems tested by SMPTE in the late 1970s, "
        "and in 1979 an SMPTE study group released A Study of High Definition Television Systems:"
    )
    first_question = "Who released A Study of High Definition Television Systems?"

    second_passage = (
        "Broca, being what today would be called a neurosurgeon, "
        "had taken an interest in the pathology of speech. He wanted "
        "to localize the difference between man and the other animals, "
        "which appeared to reside in speech. He discovered the speech "
        "center of the human brain, today called Broca's area after him. "
        "His interest was mainly in Biological anthropology, but a German "
        "philosopher specializing in psychology, Theodor Waitz, took up the "
        "theme of general and social anthropology in his six-volume work, "
        "entitled Die Anthropologie der Naturvölker, 1859–1864. The title was "
        """soon translated as "The Anthropology of Primitive Peoples". """
        "The last two volumes were published posthumously."
    )
    second_question = "What did Broca discover in the human brain?"

    # One instance per (question, passage) pair, question tokenized first.
    examples = [(first_question, first_passage), (second_question, second_passage)]
    instances = [
        make_reading_comprehension_instance(
            spacy_tokenizer.tokenize(question),
            spacy_tokenizer.tokenize(passage),
            {"bert": bert_indexer},
            passage,
        )
        for question, passage in examples
    ]

    vocab = Vocabulary()
    batch = Batch(instances)
    batch.index_instances(vocab)
    tensor_dict = batch.as_tensor_dict(batch.get_padding_lengths())
    question_inputs = tensor_dict["question"]
    passage_inputs = tensor_dict["passage"]

    # Model config is built from the size of the indexer's vocabulary.
    bert_config = BertConfig(len(bert_indexer.vocab))
    embedder = BertEmbedder(BertModel(bert_config))

    # The test passes as long as neither call raises; results are discarded.
    _ = embedder(passage_inputs["bert"], offsets=passage_inputs["bert-offsets"])
    _ = embedder(question_inputs["bert"], offsets=question_inputs["bert-offsets"])
def multitask_learning():
    """Train a coreference model and a SWAG model jointly on a shared BERT embedder.

    Reads the module-level ``directory`` for the log file, the datasets and
    the serialization directory. Returns nothing; training is run for its
    side effects via ``MultiTaskTrainer.train()``.

    NOTE(review): the log path is built as ``directory + "/log.log"`` while
    the serialization path is ``directory + "saved_models/multitask/"`` —
    these only agree when ``directory`` ends with a slash; confirm.
    NOTE(review): ``SWAGDatasetReader`` receives the bare indexer as
    ``token_indexers`` (not a dict) and is lazy while ``len()`` is taken on
    its datasets below — verify against ``SWAGDatasetReader``/``load_datasets``.
    """
    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

    # Save logging to a local file.
    log.getLogger().addHandler(log.FileHandler(directory+"/log.log"))

    lr = 0.00001
    batch_size = 2
    epochs = 10
    max_span_width = 30
    HIDDEN_DIM = 200

    # Dataset readers share one cased BERT indexer.
    token_indexer = PretrainedBertIndexer("bert-base-cased", do_lowercase=False)
    conll_reader = ConllCorefBertReader(max_span_width=max_span_width,
                                        token_indexers={"tokens": token_indexer})
    swag_reader = SWAGDatasetReader(tokenizer=token_indexer.wordpiece_tokenizer,
                                    lazy=True,
                                    token_indexers=token_indexer)
    conll_datasets, swag_datasets = load_datasets(conll_reader, swag_reader, directory)

    # One (empty) vocabulary and one iterator per task.
    conll_vocab = Vocabulary()
    conll_iterator = BasicIterator(batch_size=batch_size)
    conll_iterator.index_with(conll_vocab)

    swag_vocab = Vocabulary()
    swag_iterator = BasicIterator(batch_size=batch_size)
    swag_iterator.index_with(swag_vocab)

    # Both models share the same trainable BERT embedder.
    bert_embedder = PretrainedBertEmbedder(pretrained_model="bert-base-cased",
                                           top_layer_only=True,
                                           requires_grad=True)
    word_embedding = BasicTextFieldEmbedder({"tokens": bert_embedder},
                                            allow_unmatched_keys=True)
    BERT_DIM = word_embedding.get_output_dim()

    seq2seq = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(BERT_DIM, HIDDEN_DIM, batch_first=True, bidirectional=True))
    seq2vec = PytorchSeq2VecWrapper(
        torch.nn.LSTM(BERT_DIM, HIDDEN_DIM, batch_first=True, bidirectional=True))
    mention_feedforward = FeedForward(input_dim=2336, num_layers=2,
                                      hidden_dims=150, activations=torch.nn.ReLU())
    antecedent_feedforward = FeedForward(input_dim=7776, num_layers=2,
                                         hidden_dims=150, activations=torch.nn.ReLU())

    model1 = CoreferenceResolver(vocab=conll_vocab,
                                 text_field_embedder=word_embedding,
                                 context_layer=seq2seq,
                                 mention_feedforward=mention_feedforward,
                                 antecedent_feedforward=antecedent_feedforward,
                                 feature_size=768,
                                 max_span_width=max_span_width,
                                 spans_per_word=0.4,
                                 max_antecedents=250,
                                 lexical_dropout=0.2)
    model2 = SWAGExampleModel(vocab=swag_vocab,
                              text_field_embedder=word_embedding,
                              phrase_encoder=seq2vec)

    optimizer1 = optim.Adam(model1.parameters(), lr=lr)
    optimizer2 = optim.Adam(model2.parameters(), lr=lr)

    # The trainer gets the iterator objects and datasets directly; the four
    # per-split batch generators the original created here were never used
    # (one of them carried a stray ":q" editor typo) and have been removed.
    task_infos = {
        "swag": {
            "model": model2,
            "optimizer": optimizer2,
            "loss": 0.0,
            "iterator": swag_iterator,
            "train_data": swag_datasets[0],
            "val_data": swag_datasets[1],
            "num_train": len(swag_datasets[0]),
            "num_val": len(swag_datasets[1]),
            "lr": lr,
            "score": {"accuracy": 0.0},
        },
        "conll": {
            "model": model1,
            "iterator": conll_iterator,
            "loss": 0.0,
            "val_data": conll_datasets[1],
            "train_data": conll_datasets[0],
            "optimizer": optimizer1,
            "num_train": len(conll_datasets[0]),
            "num_val": len(conll_datasets[1]),
            "lr": lr,
            "score": {"coref_prediction": 0.0, "coref_recall": 0.0,
                      "coref_f1": 0.0, "mention_recall": 0.0},
        },
    }

    trainer = MultiTaskTrainer(
        task_infos=task_infos,
        num_epochs=epochs,
        serialization_dir=directory + "saved_models/multitask/"
    )
    trainer.train()
class ZeroShotExtractor:
    """Marks spans whose wordpieces exactly match a lexicon phrase of some label."""

    def __init__(self, labels, domain_utils):
        """Set up tokenization helpers and index the lexicon phrases of ``labels``."""
        # Tokenizer and indexer should align with the reader's tokenizer.
        self._tokenizer = BertPreTokenizer()
        self._token_indexers = PretrainedBertIndexer(pretrained_model='bert-base-uncased')
        # True when torch reports at least one CUDA device.
        cuda_device_count = torch.cuda.device_count()
        self._is_cude = cuda_device_count > 0
        self._domain_utils = domain_utils
        self._set_labels_wordpieces(labels)

    def _get_wordpieces(self, text):
        """Return the wordpieces of ``text`` as a flat tuple (should match the reader method)."""
        # Lowercasing is always on here; tokens listed in the tokenizer's
        # ``never_split`` are kept verbatim.
        surface_forms = (
            token.text if token.text in self._tokenizer.never_split else token.text.lower()
            for token in self._tokenizer.tokenize(text)
        )
        per_token_pieces = [
            list(self._token_indexers.wordpiece_tokenizer(form))
            for form in surface_forms
        ]
        flat = []
        for pieces in per_token_pieces:
            flat.extend(pieces)
        return tuple(flat)

    def _set_labels_wordpieces(self, labels):
        """Build the mapping from lexicon-phrase wordpieces to the label indices using them.

        ``labels`` maps index -> label name; 'NO-LABEL' and 'span' are skipped.
        """
        self._num_labels = len(labels)
        self._labels_wordpieces = defaultdict(list)
        ignored = ('NO-LABEL', 'span')
        for index, label in labels.items():
            if label in ignored:
                continue
            for phrase in self._domain_utils.get_lexicon_phrase(label):
                key = self._get_wordpieces(phrase)
                self._labels_wordpieces[key].append(index)

    def get_similarity_features(self, batch_tokens, batch_spans):
        """Return a float32 tensor flagging (span, label) pairs with an exact lexicon match.

        The result has shape ``[batch_spans.shape[0], batch_spans.shape[1],
        num_labels]`` and is 1.0 wherever a slice of at most five consecutive
        sentence pieces equals a stored lexicon phrase. The slice ``[i:j]`` is
        looked up under the span key ``(i + 1, j)``; a matching phrase whose
        span is missing from ``batch_spans`` raises ``KeyError``.

        NOTE(review): each sentence is sliced and used as a dict key, so it
        presumably has to be a tuple of wordpieces — confirm with the caller.
        """
        longest_phrase = 5
        device = 'cuda' if self._is_cude else 'cpu'
        batch_size, spans_per_sentence = batch_spans.shape[0], batch_spans.shape[1]
        similarities = torch.zeros([batch_size, spans_per_sentence, self._num_labels],
                                   dtype=torch.float32,
                                   requires_grad=False,
                                   device=device)

        for batch_index, (sentence, spans) in enumerate(zip(batch_tokens, batch_spans)):
            # Position of every (start, end) span inside this sentence's span list.
            span_to_ind = {tuple(span.tolist()): position
                           for position, span in enumerate(spans)}
            sent_len = len(sentence)
            for begin in range(sent_len):
                furthest_end = min(begin + longest_phrase, sent_len)
                for end in range(begin + 1, furthest_end + 1):
                    matching_labels = self._labels_wordpieces.get(sentence[begin:end])
                    if not matching_labels:
                        continue
                    span_position = span_to_ind[(begin + 1, end)]
                    for label in matching_labels:
                        similarities[batch_index, span_position, label] = 1.0
        return similarities
def test_token_type_ids(self):
    """Type ids switch from 0 to 1 after the [SEP] that separates the two segments."""
    segment = WordTokenizer().tokenize("the laziest fox")

    #  2   15 10 11  6   17    2   15 10 11  6
    # the  laziest  fox [SEP] the  laziest  fox
    # [SEP] is inserted by hand because the tokenizer would split it in three.
    pair = segment + [Token("[SEP]")] + segment

    vocab = Vocabulary()
    fixture_vocab = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
    indexer = PretrainedBertIndexer(str(fixture_vocab))
    result = indexer.tokens_to_indices(pair, vocab, "bert")

    # [CLS] 2, 15, 10, 11, 6, 17 -> 0      2, 15, 10, 11, 6, [SEP] -> 1
    expected_type_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    assert result["bert-type-ids"] == expected_type_ids
def setUp(self):
    """Build the indexer, BERT model and embedder from the ``bert`` fixture directory."""
    super().setUp()
    bert_fixtures = self.FIXTURES_ROOT / "bert"

    self.token_indexer = PretrainedBertIndexer(str(bert_fixtures / "vocab.txt"))

    # Model configuration is read from the fixture JSON file.
    bert_config = BertConfig.from_json_file(str(bert_fixtures / "config.json"))
    self.bert_model = BertModel(bert_config)
    self.token_embedder = BertEmbedder(self.bert_model)
def setUp(self):
    """Build the indexer, BERT model and embedder from the ``bert`` fixture directory."""
    super().setUp()
    bert_fixtures = self.FIXTURES_ROOT / 'bert'

    self.token_indexer = PretrainedBertIndexer(str(bert_fixtures / 'vocab.txt'))

    # The config path is handed directly to the BertConfig constructor.
    bert_config = BertConfig(str(bert_fixtures / 'config.json'))
    self.bert_model = BertModel(bert_config)
    self.token_embedder = BertEmbedder(self.bert_model)
def test_token_type_ids(self):
    """Type ids switch from 0 to 1 after the [SEP] that separates the two segments."""
    segment = SpacyTokenizer().tokenize("the laziest fox")

    #  2   15 10 11  6   17    2   15 10 11  6
    # the  laziest  fox [SEP] the  laziest  fox
    # [SEP] is inserted by hand because the tokenizer would split it in three.
    pair = segment + [Token("[SEP]")] + segment

    vocab = Vocabulary()
    fixture_vocab = self.FIXTURES_ROOT / "bert" / "vocab.txt"
    indexer = PretrainedBertIndexer(str(fixture_vocab))
    result = indexer.tokens_to_indices(pair, vocab)

    # [CLS] 2, 15, 10, 11, 6, 17 -> 0      2, 15, 10, 11, 6, [SEP] -> 1
    expected_type_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    assert result["token_type_ids"] == expected_type_ids
def train_only_swag():
    """Train the SWAG model alone on a BERT embedder and save its weights.

    Reads the module-level ``directory`` for the log file, the dataset and
    the output paths. Returns nothing; the trained ``state_dict`` is written
    to ``directory + "saved_models/current_run_model_state"``.

    NOTE(review): the log path is built as ``directory + "/log.log"`` while
    the model paths omit the slash — these only agree when ``directory`` ends
    with a slash; confirm.
    """
    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

    # Save logging to a local file.
    log.getLogger().addHandler(log.FileHandler(directory+"/log.log"))

    lr = 0.00001
    batch_size = 2
    epochs = 100
    HIDDEN_DIM = 200
    USE_GPU = 1

    token_indexer = PretrainedBertIndexer("bert-base-cased", do_lowercase=False)
    swag_reader = SWAGDatasetReader(tokenizer=token_indexer.wordpiece_tokenizer,
                                    lazy=True,
                                    token_indexers=token_indexer)
    swag_datasets = load_swag(swag_reader, directory)

    swag_vocab = Vocabulary()
    swag_iterator = BasicIterator(batch_size=batch_size)
    swag_iterator.index_with(swag_vocab)

    bert_embedder = PretrainedBertEmbedder(pretrained_model="bert-base-cased",
                                           top_layer_only=True,
                                           requires_grad=True)
    word_embedding = BasicTextFieldEmbedder({"tokens": bert_embedder},
                                            allow_unmatched_keys=True)
    BERT_DIM = word_embedding.get_output_dim()
    seq2vec = PytorchSeq2VecWrapper(
        torch.nn.LSTM(BERT_DIM, HIDDEN_DIM, batch_first=True, bidirectional=True))

    # The coreference feed-forward modules and the extra validation batch
    # generator the original built here were never used by this SWAG-only run
    # and have been removed.
    model = SWAGExampleModel(vocab=swag_vocab,
                             text_field_embedder=word_embedding,
                             phrase_encoder=seq2vec)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        iterator=swag_iterator,
        validation_iterator=swag_iterator,
        train_dataset=swag_datasets[0],
        validation_dataset=swag_datasets[1],
        validation_metric="+accuracy",
        cuda_device=0 if USE_GPU else -1,
        serialization_dir=directory + "saved_models/current_run_model_state_swag",
        num_epochs=epochs,
    )
    trainer.train()

    # Save the model weights.
    with open(directory + "saved_models/current_run_model_state", 'wb') as f:
        torch.save(model.state_dict(), f)
def test_do_lowercase(self):
    """``do_lowercase`` decides whether a capitalised word matches the lowercase vocab."""
    # The default tokenizer doesn't handle lowercasing.
    #       2   1     5     6   8      9    2   15 10 11 14   1
    text = "the Quick brown fox jumped over the laziest lazy elmo"
    words = WordTokenizer().tokenize(text)
    vocab = Vocabulary()
    fixture_vocab = str(self.FIXTURES_ROOT / 'bert' / 'vocab.txt')

    # Without lowercasing, "Quick" is out of vocabulary and gets id 1.
    cased_indexer = PretrainedBertIndexer(fixture_vocab, do_lowercase=False)
    cased = cased_indexer.tokens_to_indices(words, vocab, "bert")
    assert cased["bert"] == [16, 2, 1, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17]

    # Lowercasing is the default, so "Quick" is indexed as 3 (== "quick").
    default_indexer = PretrainedBertIndexer(fixture_vocab)
    lowered = default_indexer.tokens_to_indices(words, vocab, "bert")
    assert lowered["bert"] == [16, 2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17]
def test_never_lowercase(self):
    """Check that ``never_lowercase`` protects special tokens when lowercasing is on."""
    fixture_vocab = str(self.FIXTURES_ROOT / 'bert' / 'vocab.txt')
    vocab = Vocabulary()

    # The default tokenizer does no lowercasing itself.
    # Wordpiece ids expected for the sentence: 2 15 10 11 6
    words = WordTokenizer().tokenize("the laziest fox")
    # "[PAD]" is appended by hand because the tokenizer would split it in three.
    words.append(Token("[PAD]"))

    # (extra constructor kwargs, expected id for the trailing "[PAD]" token)
    #   default never_lowercase -> [PAD] recognised, not lowercased -> 0
    #   empty never_lowercase   -> [PAD] lowercased, becomes [UNK]  -> 1
    scenarios = (({}, 0), ({"never_lowercase": ()}, 1))
    for extra_kwargs, pad_id in scenarios:
        indexer = PretrainedBertIndexer(fixture_vocab, do_lowercase=True, **extra_kwargs)
        result = indexer.tokens_to_indices(words, vocab, "bert")
        assert result["bert"] == [16, 2, 15, 10, 11, 6, pad_id, 17]
def test_starting_ending_offsets(self):
    """Offsets point at the last wordpiece by default, the first with ``use_starting_offsets``."""
    #       2   3     5     6   8      9    2   15 10 11 14   1
    text = "the quick brown fox jumped over the laziest lazy elmo"
    words = WordTokenizer(word_splitter=BertBasicWordSplitter()).tokenize(text)
    vocab = Vocabulary()
    fixture_vocab = str(self.FIXTURES_ROOT / 'bert' / 'vocab.txt')

    # 16 = [CLS], 17 = [SEP]; the ids are the same in both configurations.
    expected_ids = [16, 2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17]
    scenarios = (
        ({}, [1, 2, 3, 4, 5, 6, 7, 10, 11, 12]),
        ({"use_starting_offsets": True}, [1, 2, 3, 4, 5, 6, 7, 8, 11, 12]),
    )
    for extra_kwargs, expected_offsets in scenarios:
        indexer = PretrainedBertIndexer(fixture_vocab, **extra_kwargs)
        result = indexer.tokens_to_indices(words, vocab, "bert")
        assert result["bert"] == expected_ids
        assert result["bert-offsets"] == expected_offsets
def test_truncate_window_fit_two_wordpieces(self):
    """
    Both ``use_starting_offsets`` settings must work when the last word of the
    truncated sentence consists of two wordpieces.
    """
    text = "the quickest quick brown fox jumped over the quickest dog"
    words = BertPreTokenizer().tokenize(text)
    vocab = Vocabulary()
    fixture_vocab = str(self.FIXTURES_ROOT / "bert" / "vocab.txt")

    # 16 = [CLS], 17 = [SEP]; truncation yields the same ids either way.
    expected_ids = [16, 2, 3, 4, 3, 5, 6, 8, 9, 2, 3, 4, 17]
    # (use_starting_offsets, expected offsets, expected type ids or None when unchecked)
    scenarios = (
        (True, [1, 2, 4, 5, 6, 7, 8, 9, 10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        (False, [1, 3, 4, 5, 6, 7, 8, 9, 11], None),
    )
    for starting, expected_offsets, expected_type_ids in scenarios:
        indexer = PretrainedBertIndexer(
            fixture_vocab,
            truncate_long_sequences=True,
            use_starting_offsets=starting,
            max_pieces=13,
        )
        result = indexer.tokens_to_indices(words, vocab)
        assert result["input_ids"] == expected_ids
        assert result["offsets"] == expected_offsets
        if expected_type_ids is not None:
            assert result["token_type_ids"] == expected_type_ids
def test_truncate_window(self):
    """A sentence longer than ``max_pieces`` is truncated to a single window."""
    text = "the quickest quick brown fox jumped over the lazy dog"
    words = WordTokenizer(word_splitter=BertBasicWordSplitter()).tokenize(text)
    vocab = Vocabulary()
    fixture_vocab = str(self.FIXTURES_ROOT / 'bert' / 'vocab.txt')

    # 16 = [CLS], 17 = [SEP]; truncation yields the same ids either way.
    expected_ids = [16, 2, 3, 4, 3, 5, 6, 8, 9, 17]
    # (use_starting_offsets, expected offsets, expected type ids or None when unchecked)
    scenarios = (
        (True, [1, 2, 4, 5, 6, 7, 8], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        (False, [1, 3, 4, 5, 6, 7, 8], None),
    )
    for starting, expected_offsets, expected_type_ids in scenarios:
        indexer = PretrainedBertIndexer(fixture_vocab,
                                        truncate_long_sequences=True,
                                        use_starting_offsets=starting,
                                        max_pieces=10)
        result = indexer.tokens_to_indices(words, vocab, "bert")
        assert result["bert"] == expected_ids
        assert result["bert-offsets"] == expected_offsets
        if expected_type_ids is not None:
            assert result["bert-type-ids"] == expected_type_ids
def test_starting_ending_offsets(self):
    """Offsets point at the last wordpiece by default, the first with ``use_starting_offsets``."""
    #       2   3     5     6   8      9    2   15 10 11 14   1
    text = "the quick brown fox jumped over the laziest lazy elmo"
    words = WordTokenizer(word_splitter=BertBasicWordSplitter()).tokenize(text)
    vocab = Vocabulary()
    fixture_vocab = str(self.FIXTURES_ROOT / 'bert' / 'vocab.txt')

    # This variant expects no start/end markers around the ids.
    expected_ids = [2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1]
    scenarios = (
        ({}, [0, 1, 2, 3, 4, 5, 6, 9, 10, 11]),
        ({"use_starting_offsets": True}, [0, 1, 2, 3, 4, 5, 6, 7, 10, 11]),
    )
    for extra_kwargs, expected_offsets in scenarios:
        indexer = PretrainedBertIndexer(fixture_vocab, **extra_kwargs)
        result = indexer.tokens_to_indices(words, vocab, "bert")
        assert result["bert"] == expected_ids
        assert result["bert-offsets"] == expected_offsets
def test_sliding_window(self):
    """Index and embed one long sentence with sliding windows of eight pieces."""
    bert_fixtures = self.FIXTURES_ROOT / "bert"
    text = "the quickest quick brown fox jumped over the lazy dog"
    words = BertPreTokenizer().tokenize(text)

    vocab = Vocabulary()
    indexer = PretrainedBertIndexer(str(bert_fixtures / "vocab.txt"),
                                    truncate_long_sequences=False,
                                    max_pieces=8)
    config = BertConfig(str(bert_fixtures / "config.json"))
    embedder = BertEmbedder(BertModel(config), max_pieces=8)

    batch = Batch([Instance({"tokens": TextField(words, {"bert": indexer})})])
    batch.index_instances(vocab)
    tensor_dict = batch.as_tensor_dict(batch.get_padding_lengths())
    bert_inputs = tensor_dict["tokens"]

    # 16 = [CLS], 17 = [SEP]; every window carries its own start/end tokens.
    windows = [
        [16, 2, 3, 4, 3, 5, 6, 17],
        [16, 3, 5, 6, 8, 9, 2, 17],
        [16, 8, 9, 2, 14, 12, 17],
    ]
    expected_ids = [piece for window in windows for piece in window]
    assert bert_inputs["bert"].tolist() == [expected_ids]
    assert bert_inputs["bert-offsets"].tolist() == [[1, 3, 4, 5, 6, 7, 8, 9, 10, 11]]

    # No offsets: one vector per wordpiece.
    vectors = embedder(bert_inputs["bert"])
    assert list(vectors.shape) == [1, 13, 12]

    # With offsets but without token_type_ids: one vector per original token.
    vectors = embedder(bert_inputs["bert"], offsets=bert_inputs["bert-offsets"])
    assert list(vectors.shape) == [1, 10, 12]

    # With offsets and token_type_ids.
    vectors = embedder(bert_inputs["bert"],
                       offsets=bert_inputs["bert-offsets"],
                       token_type_ids=bert_inputs["bert-type-ids"])
    assert list(vectors.shape) == [1, 10, 12]
def test_do_lowercase(self):
    """``do_lowercase`` decides whether a capitalised word matches the lowercase vocab."""
    # The default tokenizer doesn't handle lowercasing.
    #       2   1     5     6   8      9    2   15 10 11 14   1
    text = "the Quick brown fox jumped over the laziest lazy elmo"
    words = WordTokenizer().tokenize(text)
    vocab = Vocabulary()
    fixture_vocab = str(self.FIXTURES_ROOT / 'bert' / 'vocab.txt')

    # (extra constructor kwargs, expected id for "Quick")
    #   do_lowercase=False -> "Quick" is OOV              -> 1
    #   default (lowercase) -> "Quick" matches "quick"    -> 3
    scenarios = (({"do_lowercase": False}, 1), ({}, 3))
    for extra_kwargs, quick_id in scenarios:
        indexer = PretrainedBertIndexer(fixture_vocab, **extra_kwargs)
        result = indexer.tokens_to_indices(words, vocab, "bert")
        assert result["bert"] == [16, 2, quick_id, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17]
def test_read(self, lazy):
    """Read the SNLI sample, check the first instances, then print the first indexed row."""
    bert_indexer = PretrainedBertIndexer(pretrained_model=self.BERT_VOCAB_PATH)
    reader = SnliReader(
        tokenizer=WordTokenizer(word_splitter=BertBasicWordSplitter()),
        token_indexers={'bert': bert_indexer},
    )
    instances = ensure_list(reader.read(str(self.FIXTURES_ROOT / 'snli_1.0_sample.jsonl')))

    first_fields = instances[0].fields
    token_texts = [token.text for token in first_fields['tokens'].tokens]
    label = first_fields['label'].label
    weight = first_fields['weight'].weight

    assert label == 'neutral'
    assert weight == 1
    assert instances[1].fields['weight'].weight == 0.5
    assert instances[2].fields['weight'].weight == 1

    # Premise and hypothesis are merged into one sequence around "[SEP]".
    premise = ['a', 'person', 'on', 'a', 'horse', 'jumps', 'over', 'a',
               'broken', 'down', 'airplane', '.']
    hypothesis = ['a', 'person', 'is', 'training', 'his', 'horse', 'for', 'a',
                  'competition', '.']
    assert token_texts == premise + ['[SEP]'] + hypothesis

    batch = Batch(instances)
    vocab = Vocabulary.from_instances(instances)
    batch.index_instances(vocab)
    tensor_dict = batch.as_tensor_dict(batch.get_padding_lengths())
    bert_inputs = tensor_dict["tokens"]

    # Everything below inspects only the first row of the batch.
    first_row_ids = bert_inputs["bert"].tolist()[0]
    print(bert_inputs['mask'].tolist()[0])
    print(first_row_ids)
    print([vocab.get_token_from_index(index, "bert") for index in first_row_ids])
    print(len(bert_inputs['bert'][0]))
    print(bert_inputs["bert-offsets"].tolist()[0])
    print(bert_inputs['bert-type-ids'].tolist()[0])
def test_sliding_window_with_batch(self):
    """Embed a batch mixing a short and a tripled-length instance with sliding windows."""
    bert_fixtures = self.FIXTURES_ROOT / "bert"
    text = "the quickest quick brown fox jumped over the lazy dog"
    words = BertPreTokenizer().tokenize(text)

    vocab = Vocabulary()
    indexer = PretrainedBertIndexer(str(bert_fixtures / "vocab.txt"),
                                    truncate_long_sequences=False,
                                    max_pieces=8)

    config = BertConfig.from_json_file(str(bert_fixtures / "config.json"))
    embedder = BertEmbedder(BertModel(config), max_pieces=8)

    short_instance = Instance({"tokens": TextField(words, {"bert": indexer})})
    long_instance = Instance({"tokens": TextField(words + words + words, {"bert": indexer})})

    batch = Batch([short_instance, long_instance])
    batch.index_instances(vocab)
    tensor_dict = batch.as_tensor_dict(batch.get_padding_lengths())
    bert_inputs = tensor_dict["tokens"]["bert"]

    # Without token_type_ids.
    vectors = embedder(bert_inputs["input_ids"], offsets=bert_inputs["offsets"])
    assert vectors is not None

    # With token_type_ids.
    vectors = embedder(bert_inputs["input_ids"],
                       offsets=bert_inputs["offsets"],
                       token_type_ids=bert_inputs["token_type_ids"])
    assert vectors is not None
def word_embeddings(self):
    """Return the ``bert-base-uncased`` embedder output for ``self.text``.

    The text is first split on runs of non-word characters and re-joined with
    single spaces, then tokenized, indexed as a one-instance batch and passed
    to a freshly loaded ``PretrainedBertEmbedder`` (no offsets are supplied).
    """
    # Normalise the text: anything that is not a word character becomes a space.
    cleaned_text = ' '.join(re.split(r'\W+', self.text))

    word_tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())
    tokens = word_tokenizer.tokenize(cleaned_text)

    vocab = Vocabulary()
    bert_indexer = PretrainedBertIndexer('bert-base-uncased')
    text_field = TextField(tokens, {'bert': bert_indexer})

    batch = Batch([Instance({"tokens": text_field})])
    batch.index_instances(vocab)
    tensor_dict = batch.as_tensor_dict(batch.get_padding_lengths())
    bert_inputs = tensor_dict["tokens"]

    embedder = PretrainedBertEmbedder('bert-base-uncased')
    return embedder(bert_inputs["bert"])
def test_truncate_window_dont_split_wordpieces(self):
    """
    The sentence must not be truncated inside a word made of 2 or more wordpieces.
    """
    text = "the quickest quick brown fox jumped over the quickest dog"
    words = WordTokenizer(word_splitter=BertBasicWordSplitter()).tokenize(text)
    vocab = Vocabulary()
    fixture_vocab = str(self.FIXTURES_ROOT / 'bert' / 'vocab.txt')

    # 16 = [CLS], 17 = [SEP].
    # One more piece would fit, but it is left out so the cut does not fall
    # in the middle of a word.
    expected_ids = [16, 2, 3, 4, 3, 5, 6, 8, 9, 2, 17]
    # (use_starting_offsets, expected offsets, expected type ids or None when unchecked)
    scenarios = (
        (True, [1, 2, 4, 5, 6, 7, 8, 9], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        (False, [1, 3, 4, 5, 6, 7, 8, 9], None),
    )
    for starting, expected_offsets, expected_type_ids in scenarios:
        indexer = PretrainedBertIndexer(fixture_vocab,
                                        truncate_long_sequences=True,
                                        use_starting_offsets=starting,
                                        max_pieces=12)
        result = indexer.tokens_to_indices(words, vocab, "bert")
        assert result["bert"] == expected_ids
        assert result["bert-offsets"] == expected_offsets
        if expected_type_ids is not None:
            assert result["bert-type-ids"] == expected_type_ids
def test_starting_ending_offsets(self):
    """Offsets point at a word's last wordpiece by default, or its first on request."""
    # Wordpiece ids per word:
    #   the=2 quick=3 brown=5 fox=6 jumped=8 over=9 the=2
    #   laziest=15,10,11 lazy=14 elmo=1
    sentence = "the quick brown fox jumped over the laziest lazy elmo"
    tokens = BertPreTokenizer().tokenize(sentence)
    vocab = Vocabulary()
    vocab_path = self.FIXTURES_ROOT / "bert" / "vocab.txt"

    # 16 = [CLS], 17 = [SEP]
    expected_ids = [16, 2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17]
    cases = (
        # default indexer
        ({}, [1, 2, 3, 4, 5, 6, 7, 10, 11, 12]),
        # indexer asked for starting offsets
        ({"use_starting_offsets": True}, [1, 2, 3, 4, 5, 6, 7, 8, 11, 12]),
    )

    for indexer_kwargs, expected_offsets in cases:
        token_indexer = PretrainedBertIndexer(str(vocab_path), **indexer_kwargs)
        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab)
        assert indexed_tokens["input_ids"] == expected_ids
        assert indexed_tokens["offsets"] == expected_offsets
class SpanMapper(object):
    """Map a program onto a tree of spans over the question's wordpieces.

    ``map_prog_to_tree`` tokenizes the question into wordpieces and parses
    the program with a nested-parentheses grammar. The helper methods align
    program constants to spans of those wordpieces, combine the spans into
    ``Tree`` objects, validate the trees and serialise them as span labels.

    ``_parse_action`` raises ``NotImplementedError`` and no parse action is
    installed on the parser here, so the class is meant to be extended.

    NOTE(review): ``self._possible_chains``, ``self._possible_constants``
    and ``self._filter_args`` are read by methods below but never assigned
    in this class; presumably a subclass sets them - verify.

    The ``print('here')`` calls throughout are debugging markers that fire
    when an alignment needed filtering or was ambiguous.
    """

    def __init__(self):
        # Tokenizer and indexer should align with the reader's tokenizer.
        self._tokenizer = BertPreTokenizer()
        self._token_indexers = PretrainedBertIndexer(
            pretrained_model='bert-base-uncased')
        # kind -> constant -> realizations. Each realization is compared
        # against a slice of the wordpiece tuple in ``_align_to_text``.
        # Nothing in this class populates the tables.
        self._synonyms = {'arguments': {}, 'predicates': {}}
        self._parser = pyparsing.nestedExpr(
            '(', ')', ignoreExpr=pyparsing.dblQuotedString)

    def _parse_action(self, string, location, tokens) -> Tree:
        """Parse-action hook; must be provided by a subclass."""
        raise NotImplementedError

    def _get_wordpieces(self, text):
        """Return the wordpieces of ``text`` as a flat tuple.

        Should match the reader's method: every token is lower-cased unless
        it is listed in the tokenizer's ``never_split``, then split by the
        indexer's wordpiece tokenizer.
        """
        never_split = self._tokenizer.never_split
        to_wordpieces = self._token_indexers.wordpiece_tokenizer
        wordpieces = []
        for token in self._tokenizer.tokenize(text):
            word = token.text if token.text in never_split else token.text.lower()
            wordpieces.extend(to_wordpieces(word))
        return tuple(wordpieces)

    def _align_to_text(self, constant, type):
        """Find every occurrence of ``constant``'s realizations in the text.

        Args:
            constant: key into ``self._synonyms[type]``.
            type: synonym table to use (``'arguments'`` or ``'predicates'``).

        Returns:
            A list of height-1 ``Span`` objects, one per match, with
            half-open ``(begin, end)`` wordpiece indices.

        Raises:
            KeyError: if ``type`` or ``constant`` is missing from the table.
        """
        tokens = self._tokens
        spans = []
        for realization in self._synonyms[type][constant]:
            width = len(realization)
            for begin in range(len(tokens) - width + 1):
                end = begin + width
                if tokens[begin:end] == realization:
                    spans.append(Span(height=1,
                                      span=(begin, end),
                                      content=tokens[begin:end],
                                      constant=constant))
        return spans

    def _align_filter_to_text(self, constants, type, hint_span=None):
        """Align a sequence of constants to one contiguous chain of spans.

        ``hint_span`` is accepted for interface compatibility but is not
        used; hint-based disambiguation is disabled.

        When exactly two chains are found, the first is used and the second
        is remembered in ``self._possible_chains``; the next call that finds
        two chains consumes the remembered one instead.

        Returns:
            The subtree built from the single remaining chain.

        Raises:
            AssertionError: if the number of remaining chains is not one.
        """
        # Candidate spans in the text for each constant, in order.
        spans_per_constant = [
            self._align_to_text(constant, type) for constant in constants
        ]
        contiguous_chains = self._find_contiguous_chain(spans_per_constant)

        if len(contiguous_chains) == 2:
            if not self._possible_chains:
                self._possible_chains = contiguous_chains[1]
                contiguous_chains = [contiguous_chains[0]]
            else:
                contiguous_chains = [self._possible_chains]
                self._possible_chains = None

        if len(contiguous_chains) != 1:
            print('here')
        assert len(contiguous_chains) == 1
        return self._contiguous_chain_to_subtree(contiguous_chains[0])

    def _filter_chains_with_hint(self, contiguous_chains, hint_span):
        """Keep the chains whose edges are closest to ``hint_span``.

        The distance of a chain is the smaller gap between its edges and
        the opposite edge of the hint. A chain whose first span lies inside
        the hint gets a large distance (10000). Not called from within this
        class.

        Raises:
            ValueError: if ``contiguous_chains`` is empty.
        """
        delta_from_hint = [
            10000 if self._is_span_contained(chain[0], hint_span) else min(
                abs(chain[0].span[0] - hint_span.span[1]),
                abs(chain[-1].span[1] - hint_span.span[0]))
            for chain in contiguous_chains
        ]
        min_value = min(delta_from_hint)
        return [
            chain for chain, delta in zip(contiguous_chains, delta_from_hint)
            if delta == min_value
        ]

    def _filter_sub_chains(self, contiguous_chains):
        """Filter chains that are actually a part of a larger chain.

        A chain is dropped when the one or two tokens right before it, or
        the one or two tokens right after it, form an entry of
        ``self._filter_args``.
        """
        tokens = self._tokens
        filter_args = self._filter_args
        full_contiguous_chains = []
        for chain in contiguous_chains:
            start = chain[0].span[0]
            end = chain[-1].span[1]
            # There is at least one token before 'start'.
            if start > 0 and tuple(tokens[start - 1:start]) in filter_args:
                continue
            # There are at least two tokens before 'start'.
            if start > 1 and tuple(tokens[start - 2:start]) in filter_args:
                continue
            if tuple(tokens[end:end + 1]) in filter_args:
                continue
            if tuple(tokens[end:end + 2]) in filter_args:
                continue
            full_contiguous_chains.append(chain)
        return full_contiguous_chains

    def _find_contiguous_chain(self, spans_per_constant: List[List[Span]]):
        """Return the combinations of spans that form a contiguous chain.

        One span is picked per constant. A combination qualifies when each
        span ends where the next begins and none of its spans lies inside
        an already decided span. If several chains qualify they are further
        reduced with ``_filter_sub_chains``.
        """
        contiguous_chains = []
        for comb in product(*spans_per_constant):
            is_contiguous = all(
                left.span[1] == right.span[0]
                for left, right in zip(comb, comb[1:]))
            if not is_contiguous:
                continue
            was_decided = any(
                self._is_span_contained(sub_span, span)
                for sub_span in comb
                for span in self._decided_spans)
            if was_decided:
                print('here')
            else:
                contiguous_chains.append(comb)

        if len(contiguous_chains) > 1:
            filtered_chains = self._filter_sub_chains(contiguous_chains)
            if contiguous_chains != filtered_chains:
                print('here')
            contiguous_chains = filtered_chains
        return contiguous_chains

    def _contiguous_chain_to_subtree(self, contiguous_chain: List[Span]):
        """Build a right-branching tree from a contiguous chain of spans.

        Every span but the last becomes the left child of a new parent that
        covers it and everything to its right; the last span becomes the
        right child of the deepest parent (or the root, for a one-span
        chain). All spans of the chain are recorded as decided.
        """
        self._decided_spans += contiguous_chain
        tree = Tree()
        parents = []
        chain_end = contiguous_chain[-1].span[1]
        parent_height = len(contiguous_chain) - 1

        for left_child in contiguous_chain[:-1]:
            # Make the parent covering this span up to the end of the chain.
            start = left_child.span[0]
            parent_span = Span(height=parent_height,
                               span=(start, chain_end),
                               content=self._tokens[start:chain_end],
                               constant=None)
            identifier = '{}-{}'.format(start, chain_end)
            tree.create_node(identifier=identifier,
                             data=parent_span,
                             parent=parents[-1] if parents else None)
            parents.append(identifier)
            # Make the left child.
            identifier_lc = '{}-{}'.format(left_child.span[0],
                                           left_child.span[1])
            tree.create_node(identifier=identifier_lc,
                             data=left_child,
                             parent=identifier)

        # Make the last right child.
        right_child = contiguous_chain[-1]
        identifier_rc = '{}-{}'.format(right_child.span[0],
                                       right_child.span[1])
        tree.create_node(identifier=identifier_rc,
                         data=right_child,
                         parent=parents[-1] if parents else None)
        return tree

    def _join_trees(self, subtree_1: Tree, subtree_2: Tree):
        """Put two subtrees under a new root spanning from the first to the second.

        Root identifiers have the form ``'<start>-<end>'``; the new root
        takes its start from ``subtree_1`` and its end from ``subtree_2``.
        """
        start = int(subtree_1.root.split('-')[0])
        end = int(subtree_2.root.split('-')[1])
        identifier = '{}-{}'.format(start, end)
        span = Span(height=100,
                    span=(start, end),
                    content=self._tokens[start:end],
                    constant=None)
        top_tree = Tree()
        top_tree.create_node(identifier=identifier, data=span)
        top_tree.paste(nid=identifier, new_tree=subtree_1)
        top_tree.paste(nid=identifier, new_tree=subtree_2)
        return top_tree

    def _combine_trees(self, subtree_1: Tree, subtree_2: Tree):
        """Paste ``subtree_1`` under the root of ``subtree_2`` and return the latter."""
        subtree_2.paste(nid=subtree_2.root, new_tree=subtree_1)
        return subtree_2

    def _join_binary_predicate_tree(self,
                                    predicate: Tree,
                                    arg_1: Tree,
                                    arg_2: Tree,
                                    allow_arg_switch: bool = True):
        """Join a predicate with its two arguments, ordered by text position.

        Normally the predicate is joined with ``arg_1`` first and the result
        with ``arg_2``. When ``arg_2`` starts strictly between the predicate
        and ``arg_1`` the arguments are switched, so that the predicate is
        joined with ``arg_2`` first.

        Raises:
            Exception: if a switch is needed but ``allow_arg_switch`` is
                false.
        """
        predicate_start = int(predicate.root.split('-')[0])
        arg_1_start = int(arg_1.root.split('-')[0])
        arg_2_start = int(arg_2.root.split('-')[0])

        arg_2_in_middle = (predicate_start < arg_2_start < arg_1_start
                           or arg_1_start < arg_2_start < predicate_start)

        if not arg_2_in_middle:
            if predicate_start < arg_1_start:  # predicate is left to arg_1
                join_1 = self._join_trees(predicate, arg_1)
            else:
                join_1 = self._join_trees(arg_1, predicate)
            if predicate_start < arg_2_start:  # predicate is left to arg_2
                return self._join_trees(join_1, arg_2)
            return self._join_trees(arg_2, join_1)

        if not allow_arg_switch:
            # Here arg_2 may not share a span with the predicate.
            raise Exception(
                'Argument switch is not allowed={}'.format(predicate))
        if predicate_start < arg_1_start:
            # Predicate is left to arg_1, and arg_2 is in the middle.
            join_1 = self._join_trees(predicate, arg_2)
            return self._join_trees(join_1, arg_1)
        join_1 = self._join_trees(arg_2, predicate)
        return self._join_trees(arg_1, join_1)

    def _join_unary_predicate_tree(self, predicate: Tree, arg: Tree):
        """Join a predicate and its single argument, leftmost first."""
        predicate_start = int(predicate.root.split('-')[0])
        arg_start = int(arg.root.split('-')[0])
        if predicate_start < arg_start:  # predicate is left to arg
            return self._join_trees(predicate, arg)
        return self._join_trees(arg, predicate)

    def _filter_contained_spans(self, spans):
        """Drop spans that lie inside a longer span of the same list.

        Returns the surviving spans ordered from longest to shortest.
        """
        by_length = sorted(spans,
                           key=lambda span: span.span[1] - span.span[0],
                           reverse=True)
        if by_length != spans:
            print('here')
        filtered_spans = []
        for span in by_length:
            if not any(
                    self._is_span_contained(span, broad_span)
                    for broad_span in filtered_spans):
                filtered_spans.append(span)
        return filtered_spans

    def _is_span_contained(self, span_1: Span, span_2: Span):
        """Return True if every index of ``span_1`` is also in ``span_2``.

        Spans are half-open; an empty ``span_1`` counts as contained.
        """
        inner = set(range(span_1.span[0], span_1.span[1]))
        outer = set(range(span_2.span[0], span_2.span[1]))
        return inner.issubset(outer)

    def _is_spans_intersect(self, span_1: Span, span_2: Span):
        """Return True if the two half-open spans share at least one index."""
        first = set(range(span_1.span[0], span_1.span[1]))
        second = set(range(span_2.span[0], span_2.span[1]))
        return len(first.intersection(second)) > 0

    def _choose_span_by_hint(self, spans, hint_span):
        """Choose the span(s) closest to the hint span.

        The hint span is the one the selected span should be close to. A
        span contained in the hint gets a large distance (10000). When
        several spans touch the hint (distance 0), only the right-most of
        *all* candidate spans is returned (risky). Not called from within
        this class.

        Raises:
            ValueError: if ``spans`` is empty.
        """
        delta_from_hint = [
            10000 if self._is_span_contained(span, hint_span) else min(
                abs(span.span[0] - hint_span.span[1]),
                abs(span.span[1] - hint_span.span[0]))
            for span in spans
        ]
        min_value = min(delta_from_hint)
        min_spans = [
            span for span, delta in zip(spans, delta_from_hint)
            if delta == min_value
        ]
        if min_value == 0 and len(min_spans) > 1:
            # Take the last span if all min_values are 0.
            min_spans = [max(spans, key=lambda span: span.span[0])]
        return min_spans

    def _get_tree_from_constant(self,
                                constant,
                                type,
                                hint_span=None,
                                constant_prefix=None):
        """Build a single-node tree for the span ``constant`` aligns to.

        Candidate spans are reduced by dropping spans contained in longer
        candidates and spans contained in already decided spans. When
        exactly two candidates remain, the first is used and the second is
        remembered in ``self._possible_constants`` for the next ambiguous
        lookup of the same constant. ``hint_span`` is accepted but unused.

        Args:
            constant: key into ``self._synonyms[type]``.
            type: synonym table to use.
            constant_prefix: if truthy, the span's constant is rewritten to
                ``'<prefix>#<constant>'``.

        Raises:
            AssertionError: if the number of remaining spans is not one.
        """
        spans = self._align_to_text(constant, type)
        uncontained = self._filter_contained_spans(spans)
        if spans != uncontained:
            print('here')
        spans = uncontained
        if len(spans) != 1:
            print('here')

        undecided = [
            span for span in spans
            if not any(
                self._is_span_contained(span, decided_span)
                for decided_span in self._decided_spans)
        ]
        if undecided != spans:
            print('here')
        spans = undecided

        if len(spans) == 2:
            if constant not in self._possible_constants:
                self._possible_constants[constant] = spans[1]
                spans = [spans[0]]
            else:
                spans = [self._possible_constants.pop(constant)]

        if len(spans) != 1:
            print('spans for {} are {}'.format(constant, spans))
        assert len(spans) == 1

        span = spans[0]
        if constant_prefix:
            span.constant = '{}#{}'.format(constant_prefix, span.constant)
        self._decided_spans.append(span)

        constant_tree = Tree()
        constant_tree.create_node(
            identifier='{}-{}'.format(span.span[0], span.span[1]), data=span)
        return constant_tree

    def is_valid_tree(self, parse_tree: Tree):
        """Return True if no node of the tree violates span nesting."""
        # Every node is checked (no short-circuit) so that each violation
        # is reported by ``_is_violating_node``.
        violations = [
            self._is_violating_node(node, parse_tree)
            for node in parse_tree.expand_tree()
        ]
        has_violation = any(violations)
        if has_violation:
            print('here')
        return not has_violation

    def is_projective_tree(self, parse_tree: Tree):
        """Return True if no node of the tree has more than two children."""
        has_violation = any(
            len(parse_tree.children(node)) > 2
            for node in parse_tree.expand_tree())
        if has_violation:
            print('here')
        return not has_violation

    def _is_violating_node(self, node, parse_tree):
        """Check whether a node is violated.

        A node is violated if a child's span is not contained in the node's
        span, or if two of its children intersect. The first violation found
        is printed.
        """
        node_span = parse_tree.get_node(node).data
        children = parse_tree.children(node)
        for child in children:
            child_span = child.data
            if not self._is_span_contained(child_span, node_span):
                print('node {} is not contained in parent {}'.format(
                    child_span.to_string, node_span.to_string))
                return True
            for child_other in children:
                child_other_span = child_other.data
                if child != child_other and self._is_spans_intersect(
                        child_span, child_other_span):
                    print('node {} intersectes node {}'.format(
                        child_span.to_string, child_other_span.to_string))
                    return True
        return False

    def map_prog_to_tree(self, question, program):
        """Parse ``program`` against the wordpieces of ``question``.

        The program is rewritten from ``name (`` to ``( name`` form and
        stripped of commas so that the nested-parentheses parser accepts
        it. The per-question state (``_program``, ``_tokens``, ``_tree``,
        ``_decided_spans``) is reset.

        Returns:
            The first element of the parser's result.
        """
        program = re.sub(r'(\w+) \(', r'( \1', program)
        self._program = program.replace(',', '')
        self._tokens = self._get_wordpieces(question)
        self._tree = Tree()
        self._decided_spans = []
        return self._parser.parseString(self._program)[0]

    @staticmethod
    def _parse_action_tree(string, location, tokens):
        """Parse action that turns a parsed group into a ``Node(value, children)``.

        Takes no ``self``; declared static so it can be used both through
        the class and through an instance.
        """
        from collections import namedtuple
        Node = namedtuple("Node", ["value", "children"])
        return Node(value=tokens[0][0], children=tokens[0][1:])

    def _get_aligned_span(self, subtree: Tree):
        """Gets hint span for aligning ambiguous constants (e.g., 'left' appears twice)."""
        return subtree.get_node(subtree.root).data

    def _get_first_argument_to_join(self, predicate_tree, arg1_tree,
                                    arg2_tree):
        """Order two argument trees so that the one to join first comes first.

        ``arg2_tree`` goes first only when it starts strictly between the
        predicate and ``arg1_tree``.
        """
        predicate_start = int(predicate_tree.root.split('-')[0])
        arg1_start = int(arg1_tree.root.split('-')[0])
        arg2_start = int(arg2_tree.root.split('-')[0])
        if (predicate_start < arg2_start < arg1_start
                or predicate_start > arg2_start > arg1_start):
            return arg2_tree, arg1_tree
        return arg1_tree, arg2_tree

    def _get_details(self, child, span_labels):
        """Return ``(start, end, is_span)`` for a node, with an inclusive end.

        A node that carries a constant is also appended to ``span_labels``
        with that constant as its type.
        """
        data = child.data
        start = data.span[0]
        end = data.span[1] - 1  # move to inclusive spans
        label = data.constant if data.constant else 'span'
        is_span = label == 'span'
        if not is_span:
            span_labels.append({'span': (start, end), 'type': label})
        return start, end, is_span

    def _adjust_end(self, start, end, adjusted_end, is_span, span_labels):
        """Append a ``'span'`` label that reaches up to ``adjusted_end - 1``.

        If the node stops short of that boundary, a widened span is
        appended. Otherwise the node's own span is appended only when it
        has not been labelled with a constant already.
        """
        if end < adjusted_end - 1:
            span_labels.append({
                'span': (start, adjusted_end - 1),
                'type': 'span'
            })
        elif is_span:
            span_labels.append({'span': (start, end), 'type': 'span'})

    def _inner_write(self, span_labels, children, end, parse_tree):
        """Recursively append the span labels of ``children`` and their subtrees.

        ``children`` is sorted in place by span start. Each child may extend
        up to the token before its next sibling; the last child may extend
        up to ``end`` (inclusive).
        """
        children.sort(key=lambda c: c.data.span[0])

        # Constant labels of all siblings come first, in text order.
        details = []
        for child in children:
            details.append(self._get_details(child, span_labels))

        # Inclusive right boundary available to each child.
        boundaries = [start - 1 for start, _, _ in details[1:]] + [end]

        for (start, child_end, is_span), boundary in zip(details, boundaries):
            self._adjust_end(start, child_end, boundary + 1, is_span,
                             span_labels)

        # Each child's subtree is read from that child's own identifier.
        for child, boundary in zip(children, boundaries):
            grandchildren = parse_tree.children(child.identifier)
            if len(grandchildren) > 0:
                self._inner_write(span_labels, grandchildren, boundary,
                                  parse_tree)

    def write_to_output(self, line, parse_tree, output_file):
        """Attach the tree's span labels to ``line`` and write it as a JSON line.

        Spans are labelled with inclusive ``(start, end)`` wordpiece
        indices and a type that is either a constant or ``'span'``. The
        first label always covers the whole question.

        Args:
            line: dict with a ``'question'`` entry; a ``'gold_spans'`` entry
                is added to it in place.
            parse_tree: tree of spans for the question.
            output_file: object with a ``write`` method.
        """
        len_sent = len(self._get_wordpieces(line['question']))
        root = parse_tree.root
        root_node = parse_tree.get_node(root).data

        # The top-level span always covers the whole question.
        start = 0
        end = len_sent - 1
        span_labels = [{'span': (start, end), 'type': 'span'}]

        children = parse_tree.children(root)
        if len(children) == 0:
            s, t = int(root_node.span[0]), int(root_node.span[1])
            # NOTE(review): ``t`` is the exclusive end of the span, while
            # every other label uses an inclusive end - confirm whether
            # this should be ``t - 1``.
            span_labels.append({'span': (s, t), 'type': root_node.constant})
            if s > 0:
                span_labels.append({'span': (s, end), 'type': 'span'})
        else:
            child_1_start = children[0].data.span[0]
            if child_1_start > 0:
                span_labels.append({
                    'span': (child_1_start, end),
                    'type': 'span'
                })
            self._inner_write(span_labels, children, end, parse_tree)

        line['gold_spans'] = span_labels
        output_file.write(json.dumps(line) + '\n')