Example #1
    def test_never_lowercase(self):
        # Our default tokenizer doesn't handle lowercasing.
        tokenizer = WordTokenizer()

        #            2 15 10 11  6
        sentence = "the laziest fox"

        tokens = tokenizer.tokenize(sentence)
        tokens.append(Token("[PAD]"))  # have to do this b/c tokenizer splits it in three

        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
        token_indexer = PretrainedBertIndexer(str(vocab_path), do_lowercase=True)

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # PAD should get recognized and not lowercased      # [PAD]
        assert indexed_tokens["bert"] == [16, 2, 15, 10, 11, 6, 0, 17]

        # Unless we manually override the set of tokens that are never lowercased
        token_indexer = PretrainedBertIndexer(str(vocab_path), do_lowercase=True, never_lowercase=())
        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # now PAD should get lowercased and be UNK          # [UNK]
        assert indexed_tokens["bert"] == [16, 2, 15, 10, 11, 6, 1, 17]
Example #2
 def __init__(self, labels, domain_utils):
     self._tokenizer = BertPreTokenizer()  # should align with reader's tokenizer
     self._token_indexers = PretrainedBertIndexer(
         pretrained_model='bert-base-uncased')  # should align with reader's tokenizer
     self._is_cuda = torch.cuda.device_count() > 0
     self._domain_utils = domain_utils
     self._set_labels_wordpieces(labels)
Example #3
    def test_sliding_window(self):
        tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())

        sentence = "the quickest quick brown [SEP] jumped over the lazy dog"
        tokens = tokenizer.tokenize(sentence)

        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
        token_indexer = PretrainedBertIndexer(str(vocab_path),
                                              truncate_long_sequences=False,
                                              use_starting_offsets=False,
                                              max_pieces=10)

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # 16 = [CLS], 17 = [SEP]
        # 1 full window + 1 half window with start/end tokens
                                        # [CLS] the quick est quick brown [SEP] jumped over [SEP]
        assert indexed_tokens["bert"] == [16,   2,  3,    4,  3,    5,    17,   8,     9,   17,
                                        # [CLS] brown [SEP] jumped over the lazy dog [SEP]
                                          16,   5,    17,   8,     9,   2,  14,  12, 17]
        assert indexed_tokens["bert-offsets"] == [1, 3, 4, 5, 6, 7, 8, 9, 10, 11]

        # The extra [SEP]s shouldn't pollute the token-type-ids
                                                 # [CLS] the quick est quick brown [SEP] jumped over [SEP]
        assert indexed_tokens["bert-type-ids"] == [0,    0,  0,    0,  0,    0,    0,    1,     1,   1,
                                                 # [CLS] brown [SEP] jumped over the lazy dog [SEP]
                                                   0,    0,    0,    1,     1,   1,  1,   1,  1]
Example #4
 def batch_to_ids(stncs, tgt_flag=False):
     """
     Convert a list of sentences into ids that the token indexer accepts.
     :param stncs: [['I', 'Like', 'you'], ['Yes']]
     :param tgt_flag: indicates whether the input is a target sentence; if it is,
                      use only the previous words as context and neglect the last word
     :return ids: indices to feed into the embedder
     """
     tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())
     tokens = tokenizer.tokenize(stncs)
     vocab = Vocabulary()
     vocab_path = ""
     token_indexer = PretrainedBertIndexer(str(vocab_path))
     ids = token_indexer.tokens_to_indices(tokens, vocab, "bert")
     if tgt_flag:
         ids = ids[:, :-1, :]  # neglect the last word
         b_size, _len, dim = ids.shape
         expand_ids = torch.zeros(b_size * _len,
                                  _len,
                                  dim,
                                  dtype=torch.long)
         for i in range(1, _len + 1):
             expand_ids[b_size * (i - 1):b_size * i, :i, :] = ids[:, :i, :]
         return expand_ids
     return ids
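
The `tgt_flag` branch above builds, for a target sequence of length `_len`, a batch of `_len` progressively longer prefixes, so that each position only sees the previous words as context. Below is a minimal standalone sketch of just that expansion step, using a hypothetical toy tensor (PyTorch only, made-up values rather than real indexer output):

import torch

# hypothetical toy "ids" tensor: batch size 1, length 3, feature dim 2
ids = torch.arange(1, 7, dtype=torch.long).view(1, 3, 2)
b_size, _len, dim = ids.shape
expand_ids = torch.zeros(b_size * _len, _len, dim, dtype=torch.long)
for i in range(1, _len + 1):
    # row block i-1 keeps only the first i positions; the rest stay zero-padded
    expand_ids[b_size * (i - 1):b_size * i, :i, :] = ids[:, :i, :]
print(expand_ids[0])  # only position 0 is filled
print(expand_ids[2])  # all 3 positions are filled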
Example #5
    def test_indexes_empty_sequence(self):
        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
        token_indexer = PretrainedBertIndexer(str(vocab_path))

        indexed_tokens = token_indexer.tokens_to_indices([], vocab, "bert")
        assert indexed_tokens == {
                'bert': [16, 17],           # [CLS], [SEP]
                'bert-offsets': [],         # no tokens => no offsets
                'bert-type-ids': [0, 0],    # just 0s for start and end
                'mask': []                  # no tokens => no mask
        }
Example #6
 def __init__(self):
     self._tokenizer = BertPreTokenizer(
     )  # should align with reader's tokenizer
     self._token_indexers = PretrainedBertIndexer(
         pretrained_model='bert-base-uncased'
     )  # should align with reader's tokenizer
     self._synonyms = {'arguments': {}, 'predicates': {}}
     # self._load_argumets()
     # self._enrich_synonyms()
     # self._enrich_synonyms_by_hand()
     self._parser = pyparsing.nestedExpr(
         '(', ')', ignoreExpr=pyparsing.dblQuotedString)
Example #7
    def test_indexes_empty_sequence(self):
        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / "bert" / "vocab.txt"
        token_indexer = PretrainedBertIndexer(str(vocab_path))

        indexed_tokens = token_indexer.tokens_to_indices([], vocab)
        assert indexed_tokens == {
            "input_ids": [16, 17],  # [CLS], [SEP]
            "offsets": [],  # no tokens => no offsets
            "token_type_ids": [0, 0],  # just 0s for start and end
            "mask": [],  # no tokens => no mask
        }
Example #8
 def __init__(self,
              max_span_width: int,
              token_indexers: Dict[str, TokenIndexer] = None,
              lazy: bool = False) -> None:
     super().__init__(lazy)
     self._max_span_width = max_span_width
     self._token_indexers = {
         "tokens": PretrainedBertIndexer("bert-base-cased",
                                         do_lowercase=False)
     }
     self.token_indexer = PretrainedBertIndexer("bert-base-cased",
                                                do_lowercase=False)
Example #9
    def test_read_from_file(self, lazy):
        reader = MRPCReader(tokenizer=WordTokenizer(word_splitter=JustSpacesWordSplitter()),
                            token_indexers={"bert":
                                            PretrainedBertIndexer(pretrained_model=self.BERT_VOCAB_PATH)},
                            lazy=lazy,
                            skip_label_indexing=False,
                            mode='merge')
        instances = reader.read(
            str(self.FIXTURES_ROOT / 'mrpc_dev.tsv'))
        instances = ensure_list(instances)

        instance1 = {"tokens": "He said the foodservice pie business doesn 't fit the company 's long-term growth strategy .".split() + ["[SEP]"] +
                     "\" The foodservice pie business does not fit our long-term growth strategy .".split(),
                     "label": '1'}

        instance2 = {"tokens": "Magnarelli said Racicot hated the Iraqi regime and looked forward to using his long years of training in the war .".split() + ["[SEP]"] +
                     "His wife said he was \" 100 percent behind George Bush \" and looked forward to using his years of training in the war .".split(),
                     "label": '0'}

        instance3 = {"tokens": "The dollar was at 116.92 yen against the yen , flat on the session , and at 1.2891 against the Swiss franc , also flat .".split() + ["[SEP]"] +
                     "The dollar was at 116.78 yen JPY = , virtually flat on the session , and at 1.2871 against the Swiss franc CHF = , down 0.1 percent .".split(),
                     "label": '0'}

        for instance, expected_instance in zip(instances, [instance1, instance2, instance3]):
            fields = instance.fields
            assert [
                t.text for t in fields["tokens"].tokens] == expected_instance["tokens"]
            assert fields["label"].label == expected_instance["label"]
Example #10
    def test_sliding_window_with_batch(self):
        tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())

        sentence = "the quickest quick brown fox jumped over the lazy dog"
        tokens = tokenizer.tokenize(sentence)

        vocab = Vocabulary()

        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
        token_indexer = PretrainedBertIndexer(str(vocab_path), truncate_long_sequences=False, max_pieces=8)

        config_path = self.FIXTURES_ROOT / 'bert' / 'config.json'
        config = BertConfig(str(config_path))
        bert_model = BertModel(config)
        token_embedder = BertEmbedder(bert_model, max_pieces=8)

        instance = Instance({"tokens": TextField(tokens, {"bert": token_indexer})})
        instance2 = Instance({"tokens": TextField(tokens + tokens + tokens, {"bert": token_indexer})})

        batch = Batch([instance, instance2])
        batch.index_instances(vocab)

        padding_lengths = batch.get_padding_lengths()
        tensor_dict = batch.as_tensor_dict(padding_lengths)
        tokens = tensor_dict["tokens"]
        bert_vectors = token_embedder(tokens["bert"], offsets=tokens["bert-offsets"])
        assert bert_vectors is not None
Example #11
    def test_read(self, lazy):
        reader = GLUESST2DatasetReader(
            tokenizer=WordTokenizer(word_splitter=BertBasicWordSplitter()),
            token_indexers={'bert': PretrainedBertIndexer(
                pretrained_model=self.BERT_VOCAB_PATH)},
            skip_label_indexing=False
        )
        instances = reader.read(
            str(self.FIXTURES_ROOT / 'dev.tsv'))
        instances = ensure_list(instances)
        example = instances[0]
        tokens = [t.text for t in example.fields['tokens']]
        label = example.fields['label'].label
        print(label)
        print(tokens)
        batch = Batch(instances)
        vocab = Vocabulary.from_instances(instances)
        batch.index_instances(vocab)
        padding_lengths = batch.get_padding_lengths()
        tensor_dict = batch.as_tensor_dict(padding_lengths)
        tokens = tensor_dict["tokens"]

        print(tokens['mask'].tolist()[0])
        print(tokens["bert"].tolist()[0])
        print([vocab.get_token_from_index(i, "bert")
               for i in tokens["bert"].tolist()[0]])
        print(len(tokens['bert'][0]))
        print(tokens["bert-offsets"].tolist()[0])
        print(tokens['bert-type-ids'].tolist()[0])
Example #12
    def test_squad_with_unwordpieceable_passage(self):

        tokenizer = SpacyTokenizer()

        token_indexer = PretrainedBertIndexer("bert-base-uncased")

        passage1 = (
            "There were four major HDTV systems tested by SMPTE in the late 1970s, "
            "and in 1979 an SMPTE study group released A Study of High Definition Television Systems:"
        )
        question1 = "Who released A Study of High Definition Television Systems?"

        passage2 = (
            "Broca, being what today would be called a neurosurgeon, "
            "had taken an interest in the pathology of speech. He wanted "
            "to localize the difference between man and the other animals, "
            "which appeared to reside in speech. He discovered the speech "
            "center of the human brain, today called Broca's area after him. "
            "His interest was mainly in Biological anthropology, but a German "
            "philosopher specializing in psychology, Theodor Waitz, took up the "
            "theme of general and social anthropology in his six-volume work, "
            "entitled Die Anthropologie der Naturvölker, 1859–1864. The title was "
            """soon translated as "The Anthropology of Primitive Peoples". """
            "The last two volumes were published posthumously.")
        question2 = "What did Broca discover in the human brain?"

        from allennlp.data.dataset_readers.reading_comprehension.util import (
            make_reading_comprehension_instance, )

        instance1 = make_reading_comprehension_instance(
            tokenizer.tokenize(question1),
            tokenizer.tokenize(passage1),
            {"bert": token_indexer},
            passage1,
        )

        instance2 = make_reading_comprehension_instance(
            tokenizer.tokenize(question2),
            tokenizer.tokenize(passage2),
            {"bert": token_indexer},
            passage2,
        )

        vocab = Vocabulary()

        batch = Batch([instance1, instance2])
        batch.index_instances(vocab)

        padding_lengths = batch.get_padding_lengths()
        tensor_dict = batch.as_tensor_dict(padding_lengths)
        qtokens = tensor_dict["question"]
        ptokens = tensor_dict["passage"]

        config = BertConfig(len(token_indexer.vocab))
        model = BertModel(config)
        embedder = BertEmbedder(model)

        _ = embedder(ptokens["bert"], offsets=ptokens["bert-offsets"])
        _ = embedder(qtokens["bert"], offsets=qtokens["bert-offsets"])
Example #13
def multitask_learning():
    # load datasetreader 
    # Save logging to a local file
    # Multitasking
    log.getLogger().addHandler(log.FileHandler(directory+"/log.log"))

    lr = 0.00001
    batch_size = 2
    epochs = 10 
    max_seq_len = 512
    max_span_width = 30
    #token_indexer = BertIndexer(pretrained_model="bert-base-uncased", max_pieces=max_seq_len, do_lowercase=True,)
    token_indexer = PretrainedBertIndexer("bert-base-cased", do_lowercase=False)
    conll_reader = ConllCorefBertReader(max_span_width = max_span_width, token_indexers = {"tokens": token_indexer})
    swag_reader = SWAGDatasetReader(tokenizer=token_indexer.wordpiece_tokenizer,lazy=True, token_indexers=token_indexer)
    EMBEDDING_DIM = 1024
    HIDDEN_DIM = 200
    conll_datasets, swag_datasets = load_datasets(conll_reader, swag_reader, directory)
    conll_vocab = Vocabulary()
    swag_vocab = Vocabulary()
    conll_iterator = BasicIterator(batch_size=batch_size)
    conll_iterator.index_with(conll_vocab)

    swag_vocab = Vocabulary()
    swag_iterator = BasicIterator(batch_size=batch_size)
    swag_iterator.index_with(swag_vocab)


    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

    bert_embedder = PretrainedBertEmbedder(pretrained_model="bert-base-cased",top_layer_only=True, requires_grad=True)

    word_embedding = BasicTextFieldEmbedder({"tokens": bert_embedder}, allow_unmatched_keys=True)
    BERT_DIM = word_embedding.get_output_dim()

    seq2seq = PytorchSeq2SeqWrapper(torch.nn.LSTM(BERT_DIM, HIDDEN_DIM, batch_first=True, bidirectional=True))
    seq2vec = PytorchSeq2VecWrapper(torch.nn.LSTM(BERT_DIM, HIDDEN_DIM, batch_first=True, bidirectional=True))
    mention_feedforward = FeedForward(input_dim = 2336, num_layers = 2, hidden_dims = 150, activations = torch.nn.ReLU())
    antecedent_feedforward = FeedForward(input_dim = 7776, num_layers = 2, hidden_dims = 150, activations = torch.nn.ReLU())
    model1 = CoreferenceResolver(vocab=conll_vocab,
                                 text_field_embedder=word_embedding,
                                 context_layer=seq2seq,
                                 mention_feedforward=mention_feedforward,
                                 antecedent_feedforward=antecedent_feedforward,
                                 feature_size=768,
                                 max_span_width=max_span_width,
                                 spans_per_word=0.4,
                                 max_antecedents=250,
                                 lexical_dropout=0.2)

    model2 = SWAGExampleModel(vocab=swag_vocab, text_field_embedder=word_embedding, phrase_encoder=seq2vec)
    optimizer1 = optim.Adam(model1.parameters(), lr=lr)
    optimizer2 = optim.Adam(model2.parameters(), lr=lr)

    swag_train_iterator = swag_iterator(swag_datasets[0], num_epochs=1, shuffle=True)
    conll_train_iterator = conll_iterator(conll_datasets[0], num_epochs=1, shuffle=True)
    swag_val_iterator = swag_iterator(swag_datasets[1], num_epochs=1, shuffle=True)
    conll_val_iterator = conll_iterator(conll_datasets[1], num_epochs=1, shuffle=True)
    task_infos = {"swag": {"model": model2, "optimizer": optimizer2, "loss": 0.0, "iterator": swag_iterator, "train_data": swag_datasets[0], "val_data": swag_datasets[1], "num_train": len(swag_datasets[0]), "num_val": len(swag_datasets[1]), "lr": lr, "score": {"accuracy":0.0}}, \
                    "conll": {"model": model1, "iterator": conll_iterator, "loss": 0.0, "val_data": conll_datasets[1], "train_data": conll_datasets[0], "optimizer": optimizer1, "num_train": len(conll_datasets[0]), "num_val": len(conll_datasets[1]),"lr": lr, "score": {"coref_prediction": 0.0, "coref_recall": 0.0, "coref_f1": 0.0,"mention_recall": 0.0}}}
    USE_GPU = 1
    trainer = MultiTaskTrainer(
        task_infos=task_infos, 
        num_epochs=epochs,
        serialization_dir=directory + "saved_models/multitask/"
    ) 
    metrics = trainer.train()
Example #14
class ZeroShotExtractor:

    def __init__(self, labels, domain_utils):
        self._tokenizer = BertPreTokenizer()  # should align with reader's tokenizer
        self._token_indexers = PretrainedBertIndexer(
            pretrained_model='bert-base-uncased')  # should align with reader's tokenizer
        self._is_cuda = torch.cuda.device_count() > 0
        self._domain_utils = domain_utils
        self._set_labels_wordpieces(labels)

    def _get_wordpieces(self, text): # should match the reader method
        tokens = self._tokenizer.tokenize(text)
        do_lowercase = True
        tokens_out = (
            token.text.lower()
            if do_lowercase and token.text not in self._tokenizer.never_split
            else token.text
            for token in tokens
        )
        wps = [
            [wordpiece for wordpiece in self._token_indexers.wordpiece_tokenizer(token)]
            for token in tokens_out
        ]
        wps_flat = [wordpiece for token in wps for wordpiece in token]
        return tuple(wps_flat)

    def _set_labels_wordpieces(self, labels):
        self._num_labels = len(labels)
        self._labels_wordpieces = defaultdict(list)
        for index, label in labels.items():
            if label == 'NO-LABEL' or label == 'span':
                continue
            lexicon_phrases = self._domain_utils.get_lexicon_phrase(label)
            for lexicon_phrase in lexicon_phrases:
                self._labels_wordpieces[self._get_wordpieces(lexicon_phrase)].append(index)

    def get_similarity_features(self, batch_tokens, batch_spans):
        device = 'cuda' if self._is_cuda else 'cpu'
        similarities = torch.zeros([batch_spans.shape[0], batch_spans.shape[1], self._num_labels], dtype=torch.float32,
                                   requires_grad=False, device=device)

        for k, (sentence, spans) in enumerate(zip(batch_tokens, batch_spans)):
            sent_len = len(sentence)
            span_to_ind = {}
            for i, span in enumerate(spans):
                span_to_ind[tuple(span.tolist())] = i
            for i in range(sent_len):
                for j in range(i+1, i+6):
                    if j > sent_len:
                        break
                    labels = self._labels_wordpieces.get(sentence[i:j])
                    if labels:
                        start = i + 1
                        end = j
                        for label in labels:
                            similarities[k, span_to_ind[(start, end)], label] = 1.0
        return similarities
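
To make the indexing conventions in `get_similarity_features` easier to follow, here is a toy, self-contained walk-through of the same matching loop (hypothetical wordpieces, labels, and spans; span boundaries are shifted by one because position 0 is reserved for [CLS]):

import torch
from collections import defaultdict

# hypothetical lexicon: wordpiece tuple -> list of label indices
labels_wordpieces = defaultdict(list)
labels_wordpieces[("cheap",)].append(0)
labels_wordpieces[("red", "wine")].append(1)
num_labels = 2

sentence = ("find", "cheap", "red", "wine")        # wordpieces of one sentence
spans = torch.tensor([[[1, 1], [2, 2], [3, 4]]])   # candidate spans, 1-based to leave room for [CLS]

similarities = torch.zeros(spans.shape[0], spans.shape[1], num_labels)
for k, (sent, sent_spans) in enumerate(zip([sentence], spans)):
    span_to_ind = {tuple(span.tolist()): i for i, span in enumerate(sent_spans)}
    for i in range(len(sent)):
        for j in range(i + 1, min(i + 6, len(sent) + 1)):   # phrases of up to 5 wordpieces
            for label in labels_wordpieces.get(sent[i:j], []):
                similarities[k, span_to_ind[(i + 1, j)], label] = 1.0

print(similarities[0])  # span (2, 2) -> label 0, span (3, 4) -> label 1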
Example #15
    def test_token_type_ids(self):
        tokenizer = WordTokenizer()

        sentence = "the laziest  fox"

        tokens = tokenizer.tokenize(sentence)
        #           2   15 10 11  6   17    2   15 10 11  6
        #           the laziest   fox [SEP] the laziest   fox
        tokens = tokens + [Token("[SEP]")] + tokens  # have to do this b/c tokenizer splits `[SEP]` in three

        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
        token_indexer = PretrainedBertIndexer(str(vocab_path))

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        #                                          [CLS] 2, 15, 10, 11, 6, 17, 2  15, 10, 11, 6, [SEP]
        assert indexed_tokens["bert-type-ids"] == [0,    0, 0,  0,  0,  0, 0,  1, 1,  1,  1,  1, 1]  #pylint: disable=bad-whitespace
Example #16
    def setUp(self):
        super().setUp()

        vocab_path = self.FIXTURES_ROOT / "bert" / "vocab.txt"
        self.token_indexer = PretrainedBertIndexer(str(vocab_path))

        config_path = self.FIXTURES_ROOT / "bert" / "config.json"
        config = BertConfig.from_json_file(str(config_path))
        self.bert_model = BertModel(config)
        self.token_embedder = BertEmbedder(self.bert_model)
Example #17
    def setUp(self):
        super().setUp()

        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
        self.token_indexer = PretrainedBertIndexer(str(vocab_path))

        config_path = self.FIXTURES_ROOT / 'bert' / 'config.json'
        config = BertConfig(str(config_path))
        self.bert_model = BertModel(config)
        self.token_embedder = BertEmbedder(self.bert_model)
Example #18
    def test_token_type_ids(self):
        tokenizer = SpacyTokenizer()

        sentence = "the laziest  fox"

        tokens = tokenizer.tokenize(sentence)
        #           2   15 10 11  6   17    2   15 10 11  6
        #           the laziest   fox [SEP] the laziest   fox
        tokens = (
            tokens + [Token("[SEP]")] + tokens
        )  # have to do this b/c tokenizer splits `[SEP]` in three

        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / "bert" / "vocab.txt"
        token_indexer = PretrainedBertIndexer(str(vocab_path))

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab)

        #                                          [CLS] 2, 15, 10, 11, 6, 17, 2  15, 10, 11, 6, [SEP]
        assert indexed_tokens["token_type_ids"] == [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
Example #19
def train_only_swag():
    # load datasetreader 
    # Save logging to a local file
    # Multitasking
    log.getLogger().addHandler(log.FileHandler(directory+"/log.log"))

    lr = 0.00001
    batch_size = 2
    epochs = 100
    max_seq_len = 512
    max_span_width = 30
    #token_indexer = BertIndexer(pretrained_model="bert-base-uncased", max_pieces=max_seq_len, do_lowercase=True,)
    token_indexer = PretrainedBertIndexer("bert-base-cased", do_lowercase=False)
    swag_reader = SWAGDatasetReader(tokenizer=token_indexer.wordpiece_tokenizer,lazy=True, token_indexers=token_indexer)
    EMBEDDING_DIM = 1024
    HIDDEN_DIM = 200
    swag_datasets = load_swag(swag_reader, directory)
    swag_vocab = Vocabulary()

    swag_vocab = Vocabulary()
    swag_iterator = BasicIterator(batch_size=batch_size)
    swag_iterator.index_with(swag_vocab)

    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

    bert_embedder = PretrainedBertEmbedder(pretrained_model="bert-base-cased",top_layer_only=True, requires_grad=True)

    word_embedding = BasicTextFieldEmbedder({"tokens": bert_embedder}, allow_unmatched_keys=True)
    BERT_DIM = word_embedding.get_output_dim()
    seq2vec = PytorchSeq2VecWrapper(torch.nn.LSTM(BERT_DIM, HIDDEN_DIM, batch_first=True, bidirectional=True))
    mention_feedforward = FeedForward(input_dim = 2336, num_layers = 2, hidden_dims = 150, activations = torch.nn.ReLU())
    antecedent_feedforward = FeedForward(input_dim = 7776, num_layers = 2, hidden_dims = 150, activations = torch.nn.ReLU())

    model = SWAGExampleModel(vocab=swag_vocab, text_field_embedder=word_embedding, phrase_encoder=seq2vec)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    USE_GPU = 1
    val_iterator = swag_iterator(swag_datasets[1], num_epochs=1, shuffle=True)
    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        iterator=swag_iterator,
        validation_iterator=swag_iterator,
        train_dataset=swag_datasets[0],
        validation_dataset=swag_datasets[1],
        validation_metric="+accuracy",
        cuda_device=0 if USE_GPU else -1,
        serialization_dir=directory + "saved_models/current_run_model_state_swag",
        num_epochs=epochs,
    )

    metrics = trainer.train()
    # save the model
    with open(directory + "saved_models/current_run_model_state", 'wb') as f:
        torch.save(model.state_dict(), f)
Example #20
    def test_do_lowercase(self):
        # Our default tokenizer doesn't handle lowercasing.
        tokenizer = WordTokenizer()

        # Quick is UNK because of capitalization
        #           2   1     5     6   8      9    2  15 10 11 14   1
        sentence = "the Quick brown fox jumped over the laziest lazy elmo"
        tokens = tokenizer.tokenize(sentence)

        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
        token_indexer = PretrainedBertIndexer(str(vocab_path),
                                              do_lowercase=False)

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # Quick should get 1 == OOV
        assert indexed_tokens["bert"] == [
            16, 2, 1, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17
        ]

        # Does lowercasing by default
        token_indexer = PretrainedBertIndexer(str(vocab_path))
        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # Now Quick should get indexed correctly as 3 ( == "quick")
        assert indexed_tokens["bert"] == [
            16, 2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17
        ]
Example #21
    def test_never_lowercase(self):
        # Our default tokenizer doesn't handle lowercasing.
        tokenizer = WordTokenizer()

        #            2 15 10 11  6
        sentence = "the laziest fox"

        tokens = tokenizer.tokenize(sentence)
        tokens.append(
            Token("[PAD]"))  # have to do this b/c tokenizer splits it in three

        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
        token_indexer = PretrainedBertIndexer(str(vocab_path),
                                              do_lowercase=True)

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # PAD should get recognized and not lowercased      # [PAD]
        assert indexed_tokens["bert"] == [16, 2, 15, 10, 11, 6, 0, 17]

        # Unless we manually override the set of tokens that are never lowercased
        token_indexer = PretrainedBertIndexer(str(vocab_path),
                                              do_lowercase=True,
                                              never_lowercase=())
        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # now PAD should get lowercased and be UNK          # [UNK]
        assert indexed_tokens["bert"] == [16, 2, 15, 10, 11, 6, 1, 17]
Example #22
    def test_starting_ending_offsets(self):
        tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())

        #           2   3     5     6   8      9    2  15 10 11 14   1
        sentence = "the quick brown fox jumped over the laziest lazy elmo"
        tokens = tokenizer.tokenize(sentence)

        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
        token_indexer = PretrainedBertIndexer(str(vocab_path))

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # 16 = [CLS], 17 = [SEP]
        assert indexed_tokens["bert"] == [
            16, 2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17
        ]
        assert indexed_tokens["bert-offsets"] == [
            1, 2, 3, 4, 5, 6, 7, 10, 11, 12
        ]

        token_indexer = PretrainedBertIndexer(str(vocab_path),
                                              use_starting_offsets=True)

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        assert indexed_tokens["bert"] == [
            16, 2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17
        ]
        assert indexed_tokens["bert-offsets"] == [
            1, 2, 3, 4, 5, 6, 7, 8, 11, 12
        ]
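
For reference, here is a small self-contained illustration of what the ending vs. starting offsets above select. The wordpiece strings are assumptions (the fixture vocab is not shown); only the positions and offset values mirror the asserts in this test, where "laziest" splits into three wordpieces at positions 8-10:

# "laz"/"##ie"/"##st" are assumed splits of "laziest"; offsets match the test above
wordpieces = ["[CLS]", "the", "quick", "brown", "fox", "jumped", "over", "the",
              "laz", "##ie", "##st", "lazy", "elmo", "[SEP]"]
ending_offsets = [1, 2, 3, 4, 5, 6, 7, 10, 11, 12]    # default: last wordpiece of each token
starting_offsets = [1, 2, 3, 4, 5, 6, 7, 8, 11, 12]   # use_starting_offsets=True: first wordpiece
print([wordpieces[i] for i in ending_offsets])    # [..., 'the', '##st', 'lazy', 'elmo']
print([wordpieces[i] for i in starting_offsets])  # [..., 'the', 'laz', 'lazy', 'elmo']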
Example #23
    def test_truncate_window_fit_two_wordpieces(self):
        """
        Tests that both `use_starting_offsets` options work properly when the
        last word in the truncated sentence consists of two wordpieces.
        """

        tokenizer = BertPreTokenizer()

        sentence = "the quickest quick brown fox jumped over the quickest dog"
        tokens = tokenizer.tokenize(sentence)

        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / "bert" / "vocab.txt"
        token_indexer = PretrainedBertIndexer(
            str(vocab_path), truncate_long_sequences=True, use_starting_offsets=True, max_pieces=13
        )

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab)

        # 16 = [CLS], 17 = [SEP]
        # truncated to a single window with start/end tokens
        assert indexed_tokens["input_ids"] == [16, 2, 3, 4, 3, 5, 6, 8, 9, 2, 3, 4, 17]
        assert indexed_tokens["offsets"] == [1, 2, 4, 5, 6, 7, 8, 9, 10]
        assert indexed_tokens["token_type_ids"] == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

        token_indexer = PretrainedBertIndexer(
            str(vocab_path), truncate_long_sequences=True, use_starting_offsets=False, max_pieces=13
        )

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab)

        # 16 = [CLS], 17 = [SEP]
        # truncated to a single window with start/end tokens
        assert indexed_tokens["input_ids"] == [16, 2, 3, 4, 3, 5, 6, 8, 9, 2, 3, 4, 17]
        assert indexed_tokens["offsets"] == [1, 3, 4, 5, 6, 7, 8, 9, 11]
Example #24
    def test_truncate_window(self):
        tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())

        sentence = "the quickest quick brown fox jumped over the lazy dog"
        tokens = tokenizer.tokenize(sentence)

        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
        token_indexer = PretrainedBertIndexer(str(vocab_path),
                                              truncate_long_sequences=True,
                                              use_starting_offsets=True,
                                              max_pieces=10)

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # 16 = [CLS], 17 = [SEP]
        # truncated to a single window with start/end tokens
        assert indexed_tokens["bert"] == [16, 2, 3, 4, 3, 5, 6, 8, 9, 17]
        assert indexed_tokens["bert-offsets"] == [1, 2, 4, 5, 6, 7, 8]
        assert indexed_tokens["bert-type-ids"] == [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        ]

        token_indexer = PretrainedBertIndexer(str(vocab_path),
                                              truncate_long_sequences=True,
                                              use_starting_offsets=False,
                                              max_pieces=10)

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # 16 = [CLS], 17 = [SEP]
        # truncated to a single window with start/end tokens
        assert indexed_tokens["bert"] == [16, 2, 3, 4, 3, 5, 6, 8, 9, 17]
        assert indexed_tokens["bert-offsets"] == [1, 3, 4, 5, 6, 7, 8]
Example #25
    def test_starting_ending_offsets(self):
        tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())

        #           2   3     5     6   8      9    2  15 10 11 14   1
        sentence = "the quick brown fox jumped over the laziest lazy elmo"
        tokens = tokenizer.tokenize(sentence)

        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
        token_indexer = PretrainedBertIndexer(str(vocab_path))

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        assert indexed_tokens["bert"] == [2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1]
        assert indexed_tokens["bert-offsets"] == [0, 1, 2, 3, 4, 5, 6, 9, 10, 11]

        token_indexer = PretrainedBertIndexer(str(vocab_path), use_starting_offsets=True)

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        assert indexed_tokens["bert"] == [2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1]
        assert indexed_tokens["bert-offsets"] == [0, 1, 2, 3, 4, 5, 6, 7, 10, 11]
Example #26
    def test_sliding_window(self):
        tokenizer = BertPreTokenizer()

        sentence = "the quickest quick brown fox jumped over the lazy dog"
        tokens = tokenizer.tokenize(sentence)

        vocab = Vocabulary()

        vocab_path = self.FIXTURES_ROOT / "bert" / "vocab.txt"
        token_indexer = PretrainedBertIndexer(str(vocab_path),
                                              truncate_long_sequences=False,
                                              max_pieces=8)

        config_path = self.FIXTURES_ROOT / "bert" / "config.json"
        config = BertConfig(str(config_path))
        bert_model = BertModel(config)
        token_embedder = BertEmbedder(bert_model, max_pieces=8)

        instance = Instance(
            {"tokens": TextField(tokens, {"bert": token_indexer})})

        batch = Batch([instance])
        batch.index_instances(vocab)

        padding_lengths = batch.get_padding_lengths()
        tensor_dict = batch.as_tensor_dict(padding_lengths)
        tokens = tensor_dict["tokens"]

        # 16 = [CLS], 17 = [SEP]
        # 1 full window + 1 half window with start/end tokens
        assert tokens["bert"].tolist() == [[
            16, 2, 3, 4, 3, 5, 6, 17, 16, 3, 5, 6, 8, 9, 2, 17, 16, 8, 9, 2,
            14, 12, 17
        ]]
        assert tokens["bert-offsets"].tolist() == [[
            1, 3, 4, 5, 6, 7, 8, 9, 10, 11
        ]]

        bert_vectors = token_embedder(tokens["bert"])
        assert list(bert_vectors.shape) == [1, 13, 12]

        # Testing without token_type_ids
        bert_vectors = token_embedder(tokens["bert"],
                                      offsets=tokens["bert-offsets"])
        assert list(bert_vectors.shape) == [1, 10, 12]

        # Testing with token_type_ids
        bert_vectors = token_embedder(tokens["bert"],
                                      offsets=tokens["bert-offsets"],
                                      token_type_ids=tokens["bert-type-ids"])
        assert list(bert_vectors.shape) == [1, 10, 12]
Example #27
    def test_do_lowercase(self):
        # Our default tokenizer doesn't handle lowercasing.
        tokenizer = WordTokenizer()

        # Quick is UNK because of capitalization
        #           2   1     5     6   8      9    2  15 10 11 14   1
        sentence = "the Quick brown fox jumped over the laziest lazy elmo"
        tokens = tokenizer.tokenize(sentence)

        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
        token_indexer = PretrainedBertIndexer(str(vocab_path), do_lowercase=False)

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # Quick should get 1 == OOV
        assert indexed_tokens["bert"] == [16, 2, 1, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17]

        # Does lowercasing by default
        token_indexer = PretrainedBertIndexer(str(vocab_path))
        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # Now Quick should get indexed correctly as 3 ( == "quick")
        assert indexed_tokens["bert"] == [16, 2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17]
Example #28
    def test_read(self, lazy):
        reader = SnliReader(
            tokenizer=WordTokenizer(word_splitter=BertBasicWordSplitter()),
            token_indexers={
                'bert':
                PretrainedBertIndexer(pretrained_model=self.BERT_VOCAB_PATH)
            },
        )

        instances = reader.read(
            str(self.FIXTURES_ROOT / 'snli_1.0_sample.jsonl'))
        instances = ensure_list(instances)
        example = instances[0]
        tokens = [t.text for t in example.fields['tokens'].tokens]
        label = example.fields['label'].label
        weight = example.fields['weight'].weight
        assert label == 'neutral'
        assert weight == 1
        assert instances[1].fields['weight'].weight == 0.5
        assert instances[2].fields['weight'].weight == 1
        assert tokens == [
            'a', 'person', 'on', 'a', 'horse', 'jumps', 'over', 'a', 'broken',
            'down', 'airplane', '.', '[SEP]', 'a', 'person', 'is', 'training',
            'his', 'horse', 'for', 'a', 'competition', '.'
        ]
        batch = Batch(instances)
        vocab = Vocabulary.from_instances(instances)
        batch.index_instances(vocab)

        padding_lengths = batch.get_padding_lengths()
        tensor_dict = batch.as_tensor_dict(padding_lengths)
        tokens = tensor_dict["tokens"]

        print(tokens['mask'].tolist()[0])
        print(tokens["bert"].tolist()[0])
        print([
            vocab.get_token_from_index(i, "bert")
            for i in tokens["bert"].tolist()[0]
        ])
        print(len(tokens['bert'][0]))
        print(tokens["bert-offsets"].tolist()[0])
        print(tokens['bert-type-ids'].tolist()[0])
Example #29
    def test_sliding_window_with_batch(self):
        tokenizer = BertPreTokenizer()

        sentence = "the quickest quick brown fox jumped over the lazy dog"
        tokens = tokenizer.tokenize(sentence)

        vocab = Vocabulary()

        vocab_path = self.FIXTURES_ROOT / "bert" / "vocab.txt"
        token_indexer = PretrainedBertIndexer(str(vocab_path),
                                              truncate_long_sequences=False,
                                              max_pieces=8)

        config_path = self.FIXTURES_ROOT / "bert" / "config.json"
        config = BertConfig.from_json_file(str(config_path))
        bert_model = BertModel(config)
        token_embedder = BertEmbedder(bert_model, max_pieces=8)

        instance = Instance(
            {"tokens": TextField(tokens, {"bert": token_indexer})})
        instance2 = Instance({
            "tokens":
            TextField(tokens + tokens + tokens, {"bert": token_indexer})
        })

        batch = Batch([instance, instance2])
        batch.index_instances(vocab)

        padding_lengths = batch.get_padding_lengths()
        tensor_dict = batch.as_tensor_dict(padding_lengths)
        tokens = tensor_dict["tokens"]["bert"]

        # Testing without token_type_ids
        bert_vectors = token_embedder(tokens["input_ids"],
                                      offsets=tokens["offsets"])
        assert bert_vectors is not None

        # Testing with token_type_ids
        bert_vectors = token_embedder(tokens["input_ids"],
                                      offsets=tokens["offsets"],
                                      token_type_ids=tokens["token_type_ids"])
        assert bert_vectors is not None
Example #30
 def word_embeddings(self):
     words = re.split(r'\W+', self.text)
     Text = ' '.join(words)

     tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())

     tokens = tokenizer.tokenize(Text)
     vocab = Vocabulary()
     token_indexer = PretrainedBertIndexer('bert-base-uncased')

     instance = Instance({"tokens": TextField(tokens, {'bert': token_indexer})})
     batch = Batch([instance])
     batch.index_instances(vocab)

     padding_lengths = batch.get_padding_lengths()
     tensor_dict = batch.as_tensor_dict(padding_lengths)

     Tokens = tensor_dict["tokens"]

     model = PretrainedBertEmbedder('bert-base-uncased')
     bert_vectors = model(Tokens["bert"])
     return bert_vectors
Example #31
    def test_truncate_window_dont_split_wordpieces(self):
        """
        Tests that the sentence is not truncated inside a word with two or
        more wordpieces.
        """

        tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())

        sentence = "the quickest quick brown fox jumped over the quickest dog"
        tokens = tokenizer.tokenize(sentence)

        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
        token_indexer = PretrainedBertIndexer(str(vocab_path),
                                              truncate_long_sequences=True,
                                              use_starting_offsets=True,
                                              max_pieces=12)

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # 16 = [CLS], 17 = [SEP]
        # truncated to a single window with start/end tokens
        assert indexed_tokens["bert"] == [16, 2, 3, 4, 3, 5, 6, 8, 9, 2, 17]
        # We could fit one more piece here, but we don't, to avoid a cut
        # in the middle of a word
        assert indexed_tokens["bert-offsets"] == [1, 2, 4, 5, 6, 7, 8, 9]
        assert indexed_tokens["bert-type-ids"] == [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        ]

        token_indexer = PretrainedBertIndexer(str(vocab_path),
                                              truncate_long_sequences=True,
                                              use_starting_offsets=False,
                                              max_pieces=12)

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # 16 = [CLS], 17 = [SEP]
        # truncated to a single window with start/end tokens
        assert indexed_tokens["bert"] == [16, 2, 3, 4, 3, 5, 6, 8, 9, 2, 17]
        # We could fit one more piece here, but we don't, to avoid a cut
        # in the middle of a word
        assert indexed_tokens["bert-offsets"] == [1, 3, 4, 5, 6, 7, 8, 9]
Example #32
    def test_starting_ending_offsets(self):
        tokenizer = BertPreTokenizer()

        #           2   3     5     6   8      9    2  15 10 11 14   1
        sentence = "the quick brown fox jumped over the laziest lazy elmo"
        tokens = tokenizer.tokenize(sentence)

        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / "bert" / "vocab.txt"
        token_indexer = PretrainedBertIndexer(str(vocab_path))

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab)

        # 16 = [CLS], 17 = [SEP]
        assert indexed_tokens["input_ids"] == [16, 2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17]
        assert indexed_tokens["offsets"] == [1, 2, 3, 4, 5, 6, 7, 10, 11, 12]

        token_indexer = PretrainedBertIndexer(str(vocab_path), use_starting_offsets=True)

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab)

        assert indexed_tokens["input_ids"] == [16, 2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17]
        assert indexed_tokens["offsets"] == [1, 2, 3, 4, 5, 6, 7, 8, 11, 12]
Example #33
class SpanMapper(object):
    def __init__(self):
        self._tokenizer = BertPreTokenizer(
        )  # should align with reader's tokenizer
        self._token_indexers = PretrainedBertIndexer(
            pretrained_model='bert-base-uncased'
        )  # should align with reader's tokenizer
        self._synonyms = {'arguments': {}, 'predicates': {}}
        # self._load_argumets()
        # self._enrich_synonyms()
        # self._enrich_synonyms_by_hand()
        self._parser = pyparsing.nestedExpr(
            '(', ')', ignoreExpr=pyparsing.dblQuotedString)
        # self._parser.setParseAction(self._parse_action)

    def _parse_action(self, string, location, tokens) -> Tree:
        raise NotImplementedError

    def _get_wordpieces(self, text):  # should match the reader method
        tokens = self._tokenizer.tokenize(text)
        do_lowercase = True
        tokens_out = (token.text.lower() if do_lowercase
                      and token.text not in self._tokenizer.never_split else
                      token.text for token in tokens)
        wps = [[
            wordpiece
            for wordpiece in self._token_indexers.wordpiece_tokenizer(token)
        ] for token in tokens_out]
        wps_flat = [wordpiece for token in wps for wordpiece in token]
        return tuple(wps_flat)

    def _align_to_text(self, constant, type):
        spans = []
        realizations = self._synonyms[type][constant]
        num_tokens_text = len(self._tokens)
        for realization in realizations:
            num_tokens_real = len(realization)
            for begin in range(num_tokens_text - num_tokens_real + 1):
                end = begin + num_tokens_real
                if self._tokens[begin:end] == realization:
                    span = Span(height=1,
                                span=(begin, end),
                                content=self._tokens[begin:end],
                                constant=constant)
                    spans.append(span)
        # assert len(spans) == 1
        return spans

    def _align_filter_to_text(self, constants, type, hint_span=None):
        spans_per_constant = []
        for i, constant in enumerate(constants):
            constant_spans = self._align_to_text(
                constant, type)  # find spans in text for a particular constant
            # if len(constants) == 1 and hint_span != None:
            # # if i == 0 and hint_span != None:
            #     constant_spans_ = self._choose_span_by_hint(constant_spans, hint_span)
            #     if constant_spans_ != constant_spans:
            #         print('here')
            #         constant_spans = constant_spans_
            # if len(constants) == 2 and hint_span != None and i == 1 and len(constant_spans) > 1:
            #     constant_spans__ = []
            #     for j, span in enumerate(constant_spans):
            #         # temporary bad way to disambiguate spans
            #         if span.span[0] < 4 and hint_span.span[0] <= 6:
            #             constant_spans__.append(span)
            #     if len(constant_spans__) > 0:
            #         print('here') # risky
            #         constant_spans = constant_spans__
            # assert len(constant_spans) >= 1
            spans_per_constant.append(constant_spans)
        contiguous_chains = self._find_contiguous_chain(spans_per_constant)
        # if len(contiguous_chains) > 1 and hint_span:
        #     contiguous_chains_ = self._filter_chains_with_hint(contiguous_chains, hint_span)
        #     if contiguous_chains_ != contiguous_chains:
        #         print('here') # risky
        #     contiguous_chains = contiguous_chains_
        if len(contiguous_chains) == 2:
            if not self._possible_chains:
                next_option = tuple(
                    [span.constant for span in contiguous_chains[1]])
                self._possible_chains = contiguous_chains[1]
                contiguous_chains = [contiguous_chains[0]]
            else:
                contiguous_chains = [self._possible_chains]
                self._possible_chains = None
        if len(contiguous_chains) != 1:
            print('here')
        assert len(contiguous_chains) == 1
        contiguous_chain = contiguous_chains[0]
        return self._contiguous_chain_to_subtree(contiguous_chain)

    def _filter_chains_with_hint(self, contiguous_chains, hint_span):
        # min delta from span edges. Sets a high value if the span is within the hint_span
        delta_from_hint = [
            min(abs(chain[0].span[0] - hint_span.span[1]),
                abs(chain[-1].span[1] - hint_span.span[0]))
            if not self._is_span_contained(chain[0], hint_span) else 10000
            for chain in contiguous_chains
        ]
        min_value = min(delta_from_hint)
        min_chains = []
        for i, chain in enumerate(contiguous_chains):
            if delta_from_hint[i] == min_value:
                min_chains.append(chain)
        # if min_value == 0 and len(min_spans) > 1:  # take the last span if all min_values are 0
        #     min_spans = [sorted(spans, key=lambda span: span.span[0], reverse=True)[0]]
        return min_chains

    def _filter_sub_chains(self, contiguous_chains):
        """Filter chains that are actually a part of a larger chain"""
        full_contiguous_chains = []
        tokens = self._tokens
        for chain in contiguous_chains:
            start = chain[0].span[0]
            end = chain[-1].span[1]
            if start > 0:  # there is at least one token before 'start'
                if tuple(tokens[start - 1:start]) in self._filter_args:
                    continue
            if start > 1:  # there are at least two tokens before 'start'
                if tuple(tokens[start - 2:start]) in self._filter_args:
                    continue
            if tuple(self._tokens[end:end + 1]) in self._filter_args:
                continue
            if tuple(self._tokens[end:end + 2]) in self._filter_args:
                continue
            full_contiguous_chains.append(chain)
        return full_contiguous_chains

    def _find_contiguous_chain(self, spans_per_constant: List[List[Span]]):
        contiguous_chains = []
        combinations = list(product(*spans_per_constant))
        for comb in combinations:
            if all([
                    s.span[1] == comb[i + 1].span[0]
                    for i, s in enumerate(comb[:-1])
            ]):  # check if contiguous
                if not any([
                        self._is_span_contained(sub_span, span)
                        for sub_span in comb for span in self._decided_spans
                ]):  # check if span wasn't decided before
                    contiguous_chains.append(comb)
                else:
                    print('here')
        if len(contiguous_chains) > 1:
            contiguous_chains_ = self._filter_sub_chains(contiguous_chains)
            if contiguous_chains != contiguous_chains_:
                print('here')
            contiguous_chains = contiguous_chains_
        return contiguous_chains

    def _contiguous_chain_to_subtree(self, contiguous_chain: List[Span]):
        self._decided_spans += contiguous_chain
        tree = Tree()
        stack = []
        for i in range(len(contiguous_chain) - 1):
            # make parent
            start = contiguous_chain[i].span[0]
            end = contiguous_chain[-1].span[1]
            span = Span(height=len(contiguous_chain) - 1,
                        span=(start, end),
                        content=self._tokens[start:end],
                        constant=None)
            parent = stack[-1] if len(stack) > 0 else None
            identifier = '{}-{}'.format(start, end)
            tree.create_node(identifier=identifier, data=span, parent=parent)
            stack.append(identifier)

            # make left child
            span_lc = contiguous_chain[i]
            identifier_lc = '{}-{}'.format(span_lc.span[0], span_lc.span[1])
            tree.create_node(identifier=identifier_lc,
                             data=span_lc,
                             parent=identifier)

        # make last right child
        span_rc = contiguous_chain[-1]
        identifier_rc = '{}-{}'.format(span_rc.span[0], span_rc.span[1])
        tree.create_node(identifier=identifier_rc,
                         data=span_rc,
                         parent=stack[-1] if stack else None)

        return tree  # return top most identifier

    def _join_trees(self, subtree_1: Tree, subtree_2: Tree):

        top_tree = Tree()
        arg_1_span = subtree_1.root.split('-')
        arg_2_span = subtree_2.root.split('-')
        start = int(arg_1_span[0])
        end = int(arg_2_span[1])
        identifier = '{}-{}'.format(start, end)
        span = Span(height=100,
                    span=(start, end),
                    content=self._tokens[start:end],
                    constant=None)
        top_tree.create_node(identifier=identifier, data=span)
        top_tree.paste(nid=identifier, new_tree=subtree_1)
        top_tree.paste(nid=identifier, new_tree=subtree_2)
        return top_tree

    def _combine_trees(self, subtree_1: Tree, subtree_2: Tree):
        subtree_2.paste(nid=subtree_2.root, new_tree=subtree_1)
        return subtree_2

    def _join_binary_predicate_tree(self,
                                    predicate: Tree,
                                    arg_1: Tree,
                                    arg_2: Tree,
                                    allow_arg_switch: bool = True):
        predicate_start = int(predicate.root.split('-')[0])
        arg_1_start = int(arg_1.root.split('-')[0])
        arg_2_start = int(arg_2.root.split('-')[0])

        # if arg_2 is not in the middle
        if not (arg_2_start > predicate_start and arg_2_start < arg_1_start
                ) and not (arg_2_start < predicate_start
                           and arg_2_start > arg_1_start):
            if predicate_start < arg_1_start:  # predicate is left to arg_1
                join_1 = self._join_trees(predicate, arg_1)
            else:
                join_1 = self._join_trees(arg_1, predicate)

            if predicate_start < arg_2_start:  # predicate is left to arg_2
                join_2 = self._join_trees(join_1, arg_2)
            else:
                join_2 = self._join_trees(arg_2, join_1)
            return join_2
        else:
            if not allow_arg_switch:  # in this case we do not allow arg_2 to be in a span with the predicate
                raise Exception(
                    'Argument switch is not allowed={}'.format(predicate))
            if predicate_start < arg_1_start:  # predicate is left to arg_1, and arg_2 is in the middle
                join_1 = self._join_trees(predicate, arg_2)
                join_2 = self._join_trees(join_1, arg_1)
            else:
                join_1 = self._join_trees(arg_2, predicate)
                join_2 = self._join_trees(arg_1, join_1)
            return join_2

    def _join_unary_predicate_tree(self, predicate: Tree, arg: Tree):
        predicate_start = int(predicate.root.split('-')[0])
        arg_start = int(arg.root.split('-')[0])

        if predicate_start < arg_start:  # predicate is left to arg_1
            join_tree = self._join_trees(predicate, arg)
        else:
            join_tree = self._join_trees(arg, predicate)

        return join_tree

    def _filter_contained_spans(self, spans):
        spans_ = sorted(spans,
                        key=lambda span: span.span[1] - span.span[0],
                        reverse=True)  # sort according to span length
        if spans_ != spans:
            print('here')
        spans = spans_
        filtered_spans = []
        for span in spans:
            for broad_span in filtered_spans:
                if self._is_span_contained(
                        span, broad_span):  # span contained in broad_span
                    break
            else:
                filtered_spans.append(span)
        return filtered_spans

    def _is_span_contained(self, span_1: Span, span_2: Span):
        return set(range(span_1.span[0], span_1.span[1])).issubset(
            set(range(span_2.span[0], span_2.span[1])))

    def _is_spans_intersect(self, span_1: Span, span_2: Span):
        return len(
            set(range(span_1.span[0], span_1.span[1])).intersection(
                set(range(span_2.span[0], span_2.span[1])))) > 0

    def _choose_span_by_hint(self, spans, hint_span):
        """Chooses that span that is closest to the hint span. The hint span is the one the selected span should be close to."""

        # min delta from span edges. Sets a high value if the span is within the hint_span
        delta_from_hint = [
            min(abs(span.span[0] -
                    hint_span.span[1]), abs(span.span[1] - hint_span.span[0]))
            if not self._is_span_contained(span, hint_span) else 10000
            for span in spans
        ]
        min_value = min(delta_from_hint)
        min_spans = []
        for i, span in enumerate(spans):
            if delta_from_hint[i] == min_value:
                min_spans.append(span)
        if min_value == 0 and len(
                min_spans) > 1:  # take the last span if all min_values are 0
            min_spans = [
                sorted(spans, key=lambda span: span.span[0], reverse=True)[0]
            ]  # risky
        return min_spans

    def _get_tree_from_constant(self,
                                constant,
                                type,
                                hint_span=None,
                                constant_prefix=None):
        spans = self._align_to_text(constant, type)
        spans_ = self._filter_contained_spans(spans)
        if spans != spans_:
            print('here')  # debug marker: filtering removed contained spans
        spans = spans_
        if len(spans) != 1:
            print('here')  # debug marker: the constant aligns to several (or no) candidate spans
        # if len(spans) > 1 and hint_span:
        #     spans__ = self._choose_span_by_hint(spans, hint_span)
        #     if spans != spans__:
        #         print('here')
        #     spans = spans__
        # drop spans that are already covered by previously decided spans
        spans___ = []
        for span in spans:
            if not any([
                    self._is_span_contained(span, decided_span)
                    for decided_span in self._decided_spans
            ]):
                spans___.append(span)
        if spans___ != spans:
            print('here')  # debug marker: some candidate spans were already claimed
        spans = spans___
        if len(spans) == 2:
            # the same constant aligns to two spans (e.g. it appears twice in the question):
            # the first occurrence takes the first span and remembers the second, and a later
            # occurrence of the same constant consumes the remembered span
            if constant not in self._possible_constants:
                self._possible_constants[constant] = spans[1]
                spans = [spans[0]]
            else:
                spans = [self._possible_constants[constant]]
                del self._possible_constants[constant]
        if len(spans) != 1:
            print('spans for {} are {}'.format(constant, spans))
        assert len(spans) == 1
        span = spans[0]
        if constant_prefix:
            span.constant = '{}#{}'.format(constant_prefix, span.constant)
        self._decided_spans.append(span)
        identifier = '{}-{}'.format(span.span[0], span.span[1])
        constant_tree = Tree()
        constant_tree.create_node(identifier=identifier, data=span)
        return constant_tree
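    # Resulting shape: a constant aligned to wordpieces [3, 5) becomes a single-node Tree whose
    # identifier is '3-5' and whose data is the chosen Span, with the constant relabelled as
    # '<prefix>#<constant>' when constant_prefix is given.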

    def is_valid_tree(self, parse_tree: Tree):
        is_violating = [
            self._is_violating_node(node, parse_tree)
            for node in parse_tree.expand_tree()
        ]
        if any(is_violating):
            print('here')  # debug marker: the tree contains a violating node
        return not any(is_violating)

    def is_projective_tree(self, parse_tree: Tree):
        # a tree is considered projective here if every node has at most two children
        is_violating = [
            len(parse_tree.children(node)) > 2
            for node in parse_tree.expand_tree()
        ]
        if any(is_violating):
            print('here')  # debug marker: some node has more than two children
        return not any(is_violating)

    def _is_violating_node(self, node, parse_tree):
        """Checks id a node is violated - if its child's span is not contained in its parent span, or intersect another child."""
        node_span = parse_tree.get_node(node).data
        for child in parse_tree.children(node):
            child_span = child.data
            if not self._is_span_contained(
                    child_span, node_span):  # not contained in parent's span
                print('node {} is not contained in parent {}'.format(
                    child_span.to_string, node_span.to_string))
                return True
            for child_other in parse_tree.children(node):
                child_other_span = child_other.data
                if child != child_other and self._is_spans_intersect(
                        child_span,
                        child_other_span):  # intersects another span
                    print('node {} intersects node {}'.format(
                        child_span.to_string, child_other_span.to_string))
                    return True
        return False
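    # Example: a node spanning tokens [0, 6) with a child spanning [2, 8) is violating (the child
    # is not contained), and two children spanning [1, 4) and [3, 6) are violating as well,
    # because they share token 3.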

    def map_prog_to_tree(self, question, program):
        program = re.sub(r'(\w+) \(', r'( \1', program)
        self._program = program.replace(',', '')
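        # e.g. 'answer ( state ( loc_2 ( m0 ) ) )' becomes '( answer ( state ( loc_2 m0 ) ) )':
        # each 'name (' is rewritten to '( name' and commas are dropped, producing a LISP-style
        # s-expression for self._parser (a pyparsing grammar, judging by parseString) to consume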
        self._tokens = self._get_wordpieces(question)
        self._tree = Tree()
        self._decided_spans = []
        parse_result = self._parser.parseString(self._program)[0]
        return parse_result

        # # print the program tree
        # executor.parser.setParseAction(_parse_action_tree)
        # tree_parse = executor.parser.parseString(program)[0]
        # print('parse_tree=')
        # pprint(tree_parse)

    # def pprint(node, tab=""):
    #     if isinstance(node, str):
    #         print(tab + u"┗━ " + str(node))
    #         return
    #     print(tab + u"┗━ " + str(node.value))
    #     for child in node.children:
    #         pprint(child, tab + "    ")

    @staticmethod
    def _parse_action_tree(string, location, tokens):
        """pyparsing parse action that wraps each matched group in a simple Node(value, children)."""
        from collections import namedtuple
        Node = namedtuple("Node", ["value", "children"])
        node = Node(value=tokens[0][0], children=tokens[0][1:])
        return node
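    # A minimal, self-contained sketch (not the project's actual grammar; pp.nestedExpr and the
    # local Node namedtuple are illustrative assumptions) of how a parse action with this
    # (string, location, tokens) signature can turn an s-expression into nested Node objects:
    #
    #     import pyparsing as pp
    #     from collections import namedtuple
    #
    #     Node = namedtuple("Node", ["value", "children"])
    #
    #     def to_node(string, location, tokens):
    #         return Node(value=tokens[0][0], children=list(tokens[0][1:]))
    #
    #     expr = pp.nestedExpr()        # matches balanced '( ... )' groups
    #     expr.setParseAction(to_node)  # fires for every group, innermost first
    #
    #     tree = expr.parseString("( answer ( state m0 ) )")[0]
    #     assert tree.value == 'answer' and tree.children[0].value == 'state'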

    def _get_aligned_span(self, subtree: Tree):
        """Gets hint span for aligning ambiguous constants (e.g., 'left' appears twice)"""
        return subtree.get_node(subtree.root).data

    def _get_first_argument_to_join(self, predicate_tree, arg1_tree,
                                    arg2_tree):
        predicate_span_start = int(predicate_tree.root.split('-')[0])
        arg1_span_start = int(arg1_tree.root.split('-')[0])
        arg2_span_start = int(arg2_tree.root.split('-')[0])
        # if arg_2 starts between the predicate and arg_1, it should be joined to the predicate first
        if (predicate_span_start < arg2_span_start < arg1_span_start
                or predicate_span_start > arg2_span_start > arg1_span_start):
            return arg2_tree, arg1_tree
        else:
            return arg1_tree, arg2_tree

    def _get_details(self, child, span_labels):
        data = child.data
        start = data.span[0]
        end = data.span[1] - 1
        span = (data.span[0], data.span[1] - 1)
        type = data.constant if data.constant else 'span'
        is_span = type == 'span'
        if not is_span:
            span_labels.append({'span': span, 'type': type})
        return start, end, is_span
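    # Note: Span.span is end-exclusive, so _get_details converts it to an inclusive pair for the
    # output labels; e.g. data.span == (3, 6) yields start == 3 and end == 5.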

    def _adjust_end(self, start, end, adjusted_end, is_span, span_labels):
        if end < adjusted_end - 1:
            span_labels.append({
                'span': (start, adjusted_end - 1),
                'type': 'span'
            })
        else:
            if is_span:
                span_labels.append({'span': (start, end), 'type': 'span'})
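    # Example: for a child covering inclusive tokens (0, 2) whose right sibling starts at 5
    # (adjusted_end == 5), the gap up to token 4 is absorbed into a single 'span' label (0, 4);
    # a child that already reaches adjusted_end - 1 is emitted as-is (only if it is a 'span').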

    def _inner_write(self, span_labels, children, end, parse_tree):
        children.sort(key=lambda c: c.data.span[0])
        start_1, end_1, is_span_1 = self._get_details(children[0], span_labels)
        start_2, end_2, is_span_2 = self._get_details(children[1], span_labels)
        if len(children) > 2:
            start_3, end_3, is_span_3 = self._get_details(
                children[2], span_labels)

        self._adjust_end(start_1, end_1, start_2, is_span_1, span_labels)

        if len(children) > 2:
            self._adjust_end(start_2, end_2, start_3, is_span_2, span_labels)
            self._adjust_end(start_3, end_3, end + 1, is_span_3, span_labels)
            # if end_2 < start_3 - 1:
            #     span_labels.append({'span': (start_2, start_3 - 1), 'type': 'span'})
            # if end_3 < end:
            #     span_labels.append({'span': (start_3, end), 'type': 'span'})
        else:
            self._adjust_end(start_2, end_2, end + 1, is_span_2, span_labels)
            # if end_2 < end:
            #     span_labels.append({'span': (start_2, end), 'type': 'span'})

        children_1 = parse_tree.children(children[0].identifier)
        if len(children_1) > 0:
            self._inner_write(span_labels, children_1, start_2 - 1, parse_tree)
        children_2 = parse_tree.children(children[1].identifier)
        if len(children) > 2:
            children_3 = parse_tree.children(children[2].identifier)
            if len(children_2) > 0:
                self._inner_write(span_labels, children_2, start_3 - 1,
                                  parse_tree)
            if len(children_3) > 0:
                self._inner_write(span_labels, children_3, end, parse_tree)
        else:
            if len(children_2) > 0:
                self._inner_write(span_labels, children_2, end, parse_tree)
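    # Overall, _inner_write walks the tree top-down: at each node it sorts the children by start
    # token, records a labelled span for every constant (non-'span') child, pads each child up to
    # the start of its right sibling (or up to `end`) with plain 'span' labels, and then recurses
    # into children that themselves have children.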

    def write_to_output(self, line, parse_tree, output_file):
        tokens = self._get_wordpieces(line['question'])

        if line['question'] == "what state borders michigan ?":
            print()  # leftover debug hook for inspecting this specific example
        len_sent = len(tokens)

        span_labels = []

        # type = 'span'
        # span_labels.append({'span': (0, len_sent-1), 'type': type})
        root = parse_tree.root
        root_node = parse_tree.get_node(root).data

        # root_start, root_end = root.split('-')
        # root_start = int(root_start)
        # root_end = int(root_end)
        #
        # if root_start > 0:
        #
        start = 0
        end = len_sent - 1
        type = 'span'
        span_labels.append({'span': (start, end), 'type': type})

        children = parse_tree.children(root)

        if len(children) == 0:
            s, t = (int(root_node.span[0]), int(root_node.span[1]))
            type = root_node.constant
            span_labels.append({'span': (s, t), 'type': type})
            if s > 0:
                span_labels.append({'span': (s, end), 'type': 'span'})
        else:
            child_1_start = children[0].data.span[0]
            if child_1_start > 0:
                span_labels.append({
                    'span': (child_1_start, end),
                    'type': 'span'
                })
            self._inner_write(span_labels, parse_tree.children(root), end,
                              parse_tree)

        # while (len(parse_tree.children(root)) > 0):
        #     children = parse_tree.children(root)
        #     data_1 = children[0].data
        #     span_1 = (data_1.span[0], data_1.span[1] - 1)
        #     data_2 = children[1].data
        #     span_2 = (data_2.span[0], data_2.span[1] - 1)
        #     print()

        # for i, node in enumerate(parse_tree.expand_tree()):
        #     data = parse_tree.get_node(node).data
        #     span = (data.span[0], data.span[1]-1)  # move to inclusive spans
        #     if i==0:
        #         left_extra = None
        #         if span[0] > 0:
        #             left_extra = (0, span[0]-1)
        #         right_extra = None
        #         if span[1] < len_sent-1:
        #             left_right = (span[1]+1, len_sent-1)
        #
        #     type = data.constant if data.constant else 'span'
        #     span_labels.append({'span': span, 'type': type})
        line['gold_spans'] = span_labels
        json_str = json.dumps(line)
        output_file.write(json_str + '\n')
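    # A minimal usage sketch (hypothetical driver: `mapper`, `dataset_lines` and the file name are
    # assumptions), where each line is a dict with 'question' and 'program' keys and
    # map_prog_to_tree returns the Tree built by the parse actions:
    #
    #     with open('questions_with_spans.jsonl', 'w') as output_file:
    #         for line in dataset_lines:
    #             parse_tree = mapper.map_prog_to_tree(line['question'], line['program'])
    #             if mapper.is_valid_tree(parse_tree) and mapper.is_projective_tree(parse_tree):
    #                 mapper.write_to_output(line, parse_tree, output_file)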