def test_truncate_window(self):
        """Check truncation to ``max_pieces=10`` under both offset conventions.

        The same sentence is indexed twice with ``truncate_long_sequences=True``:
        once with ``use_starting_offsets=True`` and once with
        ``use_starting_offsets=False``.  Both runs must yield the same
        wordpiece ids; only the offset of the two-piece word "quickest"
        differs between them.
        """
        words = BertPreTokenizer().tokenize(
            "the quickest quick brown fox jumped over the lazy dog"
        )

        vocabulary = Vocabulary()
        vocab_file = str(self.FIXTURES_ROOT / "bert" / "vocab.txt")

        # 16 = [CLS], 17 = [SEP]; exactly ten pieces survive the truncation.
        truncated_ids = [16, 2, 3, 4, 3, 5, 6, 8, 9, 17]

        # Offsets pointing at the first wordpiece of each word.
        start_indexer = PretrainedBertIndexer(
            vocab_file,
            truncate_long_sequences=True,
            use_starting_offsets=True,
            max_pieces=10,
        )
        start_indexed = start_indexer.tokens_to_indices(words, vocabulary)

        assert start_indexed["input_ids"] == truncated_ids
        assert start_indexed["offsets"] == [1, 2, 4, 5, 6, 7, 8]
        assert start_indexed["token_type_ids"] == [0] * 10

        # Offsets pointing at the last wordpiece of each word.
        end_indexer = PretrainedBertIndexer(
            vocab_file,
            truncate_long_sequences=True,
            use_starting_offsets=False,
            max_pieces=10,
        )
        end_indexed = end_indexer.tokens_to_indices(words, vocabulary)

        assert end_indexed["input_ids"] == truncated_ids
        assert end_indexed["offsets"] == [1, 3, 4, 5, 6, 7, 8]
Exemple #2
0
 def __init__(self, labels, domain_utils):
     """Set up tokenization helpers and index the lexicon phrases of ``labels``.

     ``labels`` and ``domain_utils`` are forwarded to / stored for
     ``_set_labels_wordpieces``, which is called last because it relies on
     the attributes assigned before it.
     """
     # Both helpers should stay aligned with the reader's tokenizer.
     self._tokenizer = BertPreTokenizer()
     self._token_indexers = PretrainedBertIndexer(pretrained_model='bert-base-uncased')

     # True when torch reports at least one CUDA device.
     gpu_count = torch.cuda.device_count()
     self._is_cude = gpu_count > 0

     self._domain_utils = domain_utils
     self._set_labels_wordpieces(labels)
 def test_do_lower_case(self):
     """Tokens listed in ``never_split`` keep their casing and stay whole.

     The expected output shows "[UNUSED0]" and "[UNK]" passing through
     unchanged, while the lower-case "[unused0]" is broken into bracket and
     word pieces.
     """
     # Author's note: BertPreTokenizer lowercases every token that is not in
     # `never_split` by default.
     pre_tokenizer = BertPreTokenizer(never_split=["[UNUSED0]"])
     text = "[UNUSED0] [UNK] [unused0]"
     produced = [piece.text for piece in pre_tokenizer.tokenize(text)]
     assert produced == ["[UNUSED0]", "[UNK]", "[", "unused0", "]"]
    def test_truncate_window_fit_two_wordpieces(self):
        """
        Checks both `use_starting_offsets` settings when the last word kept
        by the truncation is made of two wordpieces.
        """
        words = BertPreTokenizer().tokenize(
            "the quickest quick brown fox jumped over the quickest dog"
        )

        vocabulary = Vocabulary()
        vocab_file = str(self.FIXTURES_ROOT / "bert" / "vocab.txt")

        # 16 = [CLS], 17 = [SEP]; the 13-piece budget ends right after the
        # second "quickest" (pieces 3, 4).
        truncated_ids = [16, 2, 3, 4, 3, 5, 6, 8, 9, 2, 3, 4, 17]

        # Offsets pointing at the first wordpiece of each word.
        start_indexer = PretrainedBertIndexer(
            vocab_file,
            truncate_long_sequences=True,
            use_starting_offsets=True,
            max_pieces=13,
        )
        start_indexed = start_indexer.tokens_to_indices(words, vocabulary)

        assert start_indexed["input_ids"] == truncated_ids
        assert start_indexed["offsets"] == [1, 2, 4, 5, 6, 7, 8, 9, 10]
        assert start_indexed["token_type_ids"] == [0] * 13

        # Offsets pointing at the last wordpiece of each word.
        end_indexer = PretrainedBertIndexer(
            vocab_file,
            truncate_long_sequences=True,
            use_starting_offsets=False,
            max_pieces=13,
        )
        end_indexed = end_indexer.tokens_to_indices(words, vocabulary)

        assert end_indexed["input_ids"] == truncated_ids
        assert end_indexed["offsets"] == [1, 3, 4, 5, 6, 7, 8, 9, 11]
    def test_end_to_end(self):
        """Index a two-sentence batch and embed it with and without offsets.

        Verifies the padded wordpiece ids, the per-word offsets, and the output
        shapes of both the fixture embedder and a ``top_layer_only`` embedder.
        """
        pre_tokenizer = BertPreTokenizer()

        # wordpiece ids: the=2 quickest=3,4 quick=3 brown=5 fox=6 jumped=8
        #                over=9 the=2 lazy=14 dog=12
        first_words = pre_tokenizer.tokenize(
            "the quickest quick brown fox jumped over the lazy dog")

        # wordpiece ids: the=2 quick=3 brown=5 fox=6 jumped=8 over=9 the=2
        #                laziest=15,10,11 lazy=14 elmo=1
        second_words = pre_tokenizer.tokenize(
            "the quick brown fox jumped over the laziest lazy elmo")

        vocabulary = Vocabulary()

        instances = [
            Instance({"tokens": TextField(words, {"bert": self.token_indexer})})
            for words in (first_words, second_words)
        ]

        batch = Batch(instances)
        batch.index_instances(vocabulary)

        tensors = batch.as_tensor_dict(batch.get_padding_lengths())
        tokens = tensors["tokens"]
        piece_ids = tokens["bert"]
        word_offsets = tokens["bert-offsets"]

        # 16 = [CLS], 17 = [SEP]; the shorter row is padded with 0.
        assert piece_ids.tolist() == [
            [16, 2, 3, 4, 3, 5, 6, 8, 9, 2, 14, 12, 17, 0],
            [16, 2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17],
        ]

        assert word_offsets.tolist() == [
            [1, 3, 4, 5, 6, 7, 8, 9, 10, 11],
            [1, 2, 3, 4, 5, 6, 7, 10, 11, 12],
        ]

        # Without offsets: one vector per wordpiece position (14 of them,
        # i.e. [CLS] + 12 wordpieces + [SEP]).
        vectors = self.token_embedder(piece_ids)
        assert list(vectors.shape) == [2, 14, 12]

        # With offsets: one vector per original word (10 of them).
        vectors = self.token_embedder(piece_ids, offsets=word_offsets)
        assert list(vectors.shape) == [2, 10, 12]

        # Same checks with top_layer_only = True.
        top_layer_embedder = BertEmbedder(self.bert_model, top_layer_only=True)
        vectors = top_layer_embedder(piece_ids)
        assert list(vectors.shape) == [2, 14, 12]

        vectors = top_layer_embedder(piece_ids, offsets=word_offsets)
        assert list(vectors.shape) == [2, 10, 12]
 def __init__(self):
     """Create the tokenization helpers, an empty synonym table and the parser."""
     # Both helpers should stay aligned with the reader's tokenizer.
     self._tokenizer = BertPreTokenizer()
     self._token_indexers = PretrainedBertIndexer(pretrained_model='bert-base-uncased')

     # Realizations per constant, split by kind; filled in elsewhere.
     self._synonyms = {'arguments': {}, 'predicates': {}}
     # self._load_argumets()
     # self._enrich_synonyms()
     # self._enrich_synonyms_by_hand()

     # Parser for nested parenthesised expressions; double-quoted strings
     # are passed as the expression to ignore.
     quoted = pyparsing.dblQuotedString
     self._parser = pyparsing.nestedExpr('(', ')', ignoreExpr=quoted)
    def test_sliding_window(self):
        """Index one sentence with ``max_pieces=8`` and no truncation.

        The indexer is expected to emit several [CLS]...[SEP] windows; the
        embedder (also built with ``max_pieces=8``) is then called without
        offsets, with offsets, and with offsets plus token type ids.
        """
        words = BertPreTokenizer().tokenize(
            "the quickest quick brown fox jumped over the lazy dog")

        vocabulary = Vocabulary()

        vocab_file = self.FIXTURES_ROOT / "bert" / "vocab.txt"
        window_indexer = PretrainedBertIndexer(str(vocab_file),
                                               truncate_long_sequences=False,
                                               max_pieces=8)

        config_file = self.FIXTURES_ROOT / "bert" / "config.json"
        model = BertModel(BertConfig(str(config_file)))
        window_embedder = BertEmbedder(model, max_pieces=8)

        batch = Batch([
            Instance({"tokens": TextField(words, {"bert": window_indexer})})
        ])
        batch.index_instances(vocabulary)

        tensors = batch.as_tensor_dict(batch.get_padding_lengths())
        tokens = tensors["tokens"]

        # 16 = [CLS], 17 = [SEP]
        # Three overlapping windows, each wrapped in [CLS] ... [SEP]; the last
        # one is shorter than the other two.
        expected_ids = [
            16, 2, 3, 4, 3, 5, 6, 17,
            16, 3, 5, 6, 8, 9, 2, 17,
            16, 8, 9, 2, 14, 12, 17,
        ]
        assert tokens["bert"].tolist() == [expected_ids]
        assert tokens["bert-offsets"].tolist() == [
            [1, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        ]

        vectors = window_embedder(tokens["bert"])
        assert list(vectors.shape) == [1, 13, 12]

        # Offsets only, no token_type_ids.
        vectors = window_embedder(tokens["bert"],
                                  offsets=tokens["bert-offsets"])
        assert list(vectors.shape) == [1, 10, 12]

        # Offsets together with token_type_ids.
        vectors = window_embedder(tokens["bert"],
                                  offsets=tokens["bert-offsets"],
                                  token_type_ids=tokens["bert-type-ids"])
        assert list(vectors.shape) == [1, 10, 12]
    def test_end_to_end_with_higher_order_inputs(self):
        """Embed ``ListField``s of text fields and check the extra batch axis.

        The first instance holds one sentence, the second holds two, so every
        output shape carries a list dimension of size 2.
        """
        pre_tokenizer = BertPreTokenizer()

        sentences = (
            # ids: the=2 quickest=3,4 quick=3 brown=5 fox=6 jumped=8 over=9
            #      the=2 lazy=14 dog=12
            "the quickest quick brown fox jumped over the lazy dog",
            # ids: the=2 quick=3 brown=5 fox=6 jumped=8 over=9 the=2
            #      laziest=15,10,11 lazy=14 elmo=1
            "the quick brown fox jumped over the laziest lazy elmo",
            # ids: the=2 brown=5 laziest=15,10,11 fox=6
            "the brown laziest fox",
        )
        field_1, field_2, field_3 = (
            TextField(pre_tokenizer.tokenize(text), {"bert": self.token_indexer})
            for text in sentences
        )

        vocabulary = Vocabulary()

        batch = Batch([
            Instance({"tokens": ListField([field_1])}),
            Instance({"tokens": ListField([field_2, field_3])}),
        ])
        batch.index_instances(vocabulary)

        tensors = batch.as_tensor_dict(batch.get_padding_lengths(), verbose=True)
        tokens = tensors["tokens"]
        piece_ids = tokens["bert"]
        word_offsets = tokens["bert-offsets"]

        # Without offsets: 14 vectors per sentence
        # ([CLS] + 12 wordpieces + [SEP]).
        vectors = self.token_embedder(piece_ids)
        assert list(vectors.shape) == [2, 2, 14, 12]

        # With offsets: 10 vectors per sentence.
        vectors = self.token_embedder(piece_ids, offsets=word_offsets)
        assert list(vectors.shape) == [2, 2, 10, 12]

        # Same checks with top_layer_only = True.
        top_layer_embedder = BertEmbedder(self.bert_model, top_layer_only=True)
        vectors = top_layer_embedder(piece_ids)
        assert list(vectors.shape) == [2, 2, 14, 12]

        vectors = top_layer_embedder(piece_ids, offsets=word_offsets)
        assert list(vectors.shape) == [2, 2, 10, 12]
Exemple #9
0
class ZeroShotExtractor:
    """Flags candidate spans whose wordpieces equal a label's lexicon phrase.

    Lexicon phrases of every label are converted to wordpiece tuples once, at
    construction time; ``get_similarity_features`` then looks up sub-sequences
    of each sentence in that table.
    """

    def __init__(self, labels, domain_utils):
        """Build the tokenization helpers and index the phrases of ``labels``.

        ``labels`` is a mapping from label index to label name (it is iterated
        with ``.items()``); ``domain_utils`` must offer ``get_lexicon_phrase``.
        """
        # Both helpers should stay aligned with the reader's tokenizer.
        self._tokenizer = BertPreTokenizer()
        self._token_indexers = PretrainedBertIndexer(pretrained_model='bert-base-uncased')

        # True when torch reports at least one CUDA device.
        gpu_count = torch.cuda.device_count()
        self._is_cude = gpu_count > 0

        self._domain_utils = domain_utils
        self._set_labels_wordpieces(labels)

    def _get_wordpieces(self, text): # should match the reader method
        """Return the wordpieces of ``text`` as one flat tuple.

        Every token is lowercased unless it appears in the tokenizer's
        ``never_split`` collection, then split with the indexer's
        ``wordpiece_tokenizer``.
        """
        do_lowercase = True
        pieces = []
        for token in self._tokenizer.tokenize(text):
            word = token.text
            if do_lowercase and word not in self._tokenizer.never_split:
                word = word.lower()
            pieces.extend(self._token_indexers.wordpiece_tokenizer(word))
        return tuple(pieces)

    def _set_labels_wordpieces(self, labels):
        """Map the wordpiece tuple of each lexicon phrase to its label indices.

        The labels 'NO-LABEL' and 'span' are skipped.
        """
        self._num_labels = len(labels)
        self._labels_wordpieces = defaultdict(list)
        ignored = ('NO-LABEL', 'span')
        for index, label in labels.items():
            if label in ignored:
                continue
            for phrase in self._domain_utils.get_lexicon_phrase(label):
                key = self._get_wordpieces(phrase)
                self._labels_wordpieces[key].append(index)

    def get_similarity_features(self, batch_tokens, batch_spans):
        """Return a 0/1 tensor marking (sentence, span, label) lexicon matches.

        The result has shape ``(batch_spans.shape[0], batch_spans.shape[1],
        num_labels)``.  For every sentence, each window of one to five
        consecutive tokens is looked up in the phrase table; a hit on
        ``sentence[i:j]`` sets the entry of the span ``(i + 1, j)`` for each
        matching label.  A hit whose span is absent from ``batch_spans``
        raises ``KeyError``.
        """
        device = 'cuda' if self._is_cude else 'cpu'
        shape = [batch_spans.shape[0], batch_spans.shape[1], self._num_labels]
        similarities = torch.zeros(shape, dtype=torch.float32,
                                   requires_grad=False, device=device)

        for k, (sentence, spans) in enumerate(zip(batch_tokens, batch_spans)):
            sent_len = len(sentence)
            # Row index of every candidate span, keyed by its (start, end).
            span_to_ind = {tuple(span.tolist()): row for row, span in enumerate(spans)}
            for i in range(sent_len):
                # Windows of at most five tokens that stay inside the sentence.
                last_end = min(i + 5, sent_len)
                for j in range(i + 1, last_end + 1):
                    matched = self._labels_wordpieces.get(sentence[i:j])
                    if not matched:
                        continue
                    row = span_to_ind[(i + 1, j)]
                    for label in matched:
                        similarities[k, row, label] = 1.0
        return similarities
    def test_sliding_window_with_batch(self):
        """Run the windowed embedder on a batch of two differently long inputs.

        The second instance repeats the sentence three times.  Only the fact
        that the embedder returns something is asserted.
        """
        words = BertPreTokenizer().tokenize(
            "the quickest quick brown fox jumped over the lazy dog")

        vocabulary = Vocabulary()

        vocab_file = self.FIXTURES_ROOT / "bert" / "vocab.txt"
        window_indexer = PretrainedBertIndexer(str(vocab_file),
                                               truncate_long_sequences=False,
                                               max_pieces=8)

        config_file = self.FIXTURES_ROOT / "bert" / "config.json"
        model = BertModel(BertConfig(str(config_file)))
        window_embedder = BertEmbedder(model, max_pieces=8)

        short_instance = Instance(
            {"tokens": TextField(words, {"bert": window_indexer})})
        long_instance = Instance(
            {"tokens": TextField(words + words + words, {"bert": window_indexer})})

        batch = Batch([short_instance, long_instance])
        batch.index_instances(vocabulary)

        tensors = batch.as_tensor_dict(batch.get_padding_lengths())
        tokens = tensors["tokens"]

        # Offsets only, no token_type_ids.
        vectors = window_embedder(tokens["bert"],
                                  offsets=tokens["bert-offsets"])
        assert vectors is not None

        # Offsets together with token_type_ids.
        vectors = window_embedder(tokens["bert"],
                                  offsets=tokens["bert-offsets"],
                                  token_type_ids=tokens["bert-type-ids"])
        assert vectors is not None
Exemple #11
0
    def test_max_length(self):
        """Embed a 1000-word input; the test passes if no exception is raised."""
        vocab_size = len(self.token_indexer.vocab)
        embedder = BertEmbedder(BertModel(BertConfig(vocab_size)))

        long_text = "the " * 1000
        words = BertPreTokenizer().tokenize(long_text)

        vocabulary = Vocabulary()

        batch = Batch([
            Instance({"tokens": TextField(words, {"bert": self.token_indexer})})
        ])
        batch.index_instances(vocabulary)

        tensors = batch.as_tensor_dict(batch.get_padding_lengths())
        tokens = tensors["tokens"]

        # No assertion on the result: only the call itself is exercised.
        embedder(tokens["bert"], tokens["bert-offsets"])
Exemple #12
0
    def test_padding_for_equal_length_indices(self):
        """Index a sentence in which every word maps to exactly one wordpiece."""
        # wordpiece ids: the=2 quick=3 brown=5 fox=6 jumped=8 over=9 the=2
        #                lazy=14 dog=12
        words = BertPreTokenizer().tokenize(
            "the quick brown fox jumped over the lazy dog")

        vocabulary = Vocabulary()

        batch = Batch([
            Instance({"tokens": TextField(words, {"bert": self.token_indexer})})
        ])
        batch.index_instances(vocabulary)

        tensors = batch.as_tensor_dict(batch.get_padding_lengths())
        tokens = tensors["tokens"]

        # 16 and 17 frame the nine wordpieces; no padding id appears.
        assert tokens["bert"].tolist() == [[16, 2, 3, 5, 6, 8, 9, 2, 14, 12, 17]]

        # One offset per word, consecutive because each word is one piece.
        assert tokens["bert-offsets"].tolist() == [[1, 2, 3, 4, 5, 6, 7, 8, 9]]
    def test_truncate_window_dont_split_wordpieces(self):
        """
        Checks that truncation never cuts through a word that consists of
        two or more wordpieces.
        """
        words = BertPreTokenizer().tokenize(
            "the quickest quick brown fox jumped over the quickest dog"
        )

        vocabulary = Vocabulary()
        vocab_file = str(self.FIXTURES_ROOT / "bert" / "vocab.txt")

        # 16 = [CLS], 17 = [SEP]
        # Only 11 of the 12 allowed pieces are used: one more piece would fit,
        # but it would be the first half of the second "quickest".
        truncated_ids = [16, 2, 3, 4, 3, 5, 6, 8, 9, 2, 17]

        # Offsets pointing at the first wordpiece of each word.
        start_indexer = PretrainedBertIndexer(
            vocab_file,
            truncate_long_sequences=True,
            use_starting_offsets=True,
            max_pieces=12,
        )
        start_indexed = start_indexer.tokens_to_indices(words, vocabulary)

        assert start_indexed["input_ids"] == truncated_ids
        assert start_indexed["offsets"] == [1, 2, 4, 5, 6, 7, 8, 9]
        assert start_indexed["token_type_ids"] == [0] * 11

        # Offsets pointing at the last wordpiece of each word.
        end_indexer = PretrainedBertIndexer(
            vocab_file,
            truncate_long_sequences=True,
            use_starting_offsets=False,
            max_pieces=12,
        )
        end_indexed = end_indexer.tokens_to_indices(words, vocabulary)

        assert end_indexed["input_ids"] == truncated_ids
        assert end_indexed["offsets"] == [1, 3, 4, 5, 6, 7, 8, 9]
    def test_starting_ending_offsets(self):
        """Compare the default offsets with ``use_starting_offsets=True``.

        The wordpiece ids are the same in both runs; the offset of the
        three-piece word "laziest" is 10 by default and 8 with starting
        offsets.
        """
        # wordpiece ids: the=2 quick=3 brown=5 fox=6 jumped=8 over=9 the=2
        #                laziest=15,10,11 lazy=14 elmo=1
        words = BertPreTokenizer().tokenize(
            "the quick brown fox jumped over the laziest lazy elmo")

        vocabulary = Vocabulary()
        vocab_file = str(self.FIXTURES_ROOT / "bert" / "vocab.txt")

        # 16 = [CLS], 17 = [SEP]
        expected_ids = [16, 2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17]

        default_indexed = PretrainedBertIndexer(vocab_file).tokens_to_indices(
            words, vocabulary)

        assert default_indexed["input_ids"] == expected_ids
        assert default_indexed["offsets"] == [1, 2, 3, 4, 5, 6, 7, 10, 11, 12]

        start_indexer = PretrainedBertIndexer(vocab_file, use_starting_offsets=True)
        start_indexed = start_indexer.tokens_to_indices(words, vocabulary)

        assert start_indexed["input_ids"] == expected_ids
        assert start_indexed["offsets"] == [1, 2, 3, 4, 5, 6, 7, 8, 11, 12]
class TestBertPreTokenizer(AllenNlpTestCase):
    """Tests for how ``BertPreTokenizer`` treats bracketed special tokens."""

    def setUp(self):
        """Create a default ``BertPreTokenizer`` for the tests."""
        super().setUp()
        self.word_tokenizer = BertPreTokenizer()

    def test_never_split(self):
        """Known special tokens stay whole; "[unused0]" is split at brackets."""
        text = "[unused0] [UNK] [SEP] [PAD] [CLS] [MASK]"
        produced = [piece.text for piece in self.word_tokenizer.tokenize(text)]
        assert produced == [
            "[", "unused0", "]", "[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]",
        ]

    def test_do_lower_case(self):
        """A token passed via ``never_split`` keeps its casing and stays whole."""
        # Author's note: BertPreTokenizer lowercases every token that is not
        # in `never_split` by default.
        pre_tokenizer = BertPreTokenizer(never_split=["[UNUSED0]"])
        text = "[UNUSED0] [UNK] [unused0]"
        produced = [piece.text for piece in pre_tokenizer.tokenize(text)]
        assert produced == ["[UNUSED0]", "[UNK]", "[", "unused0", "]"]
class SpanMapper(object):
    def __init__(self):
        """Create the tokenization helpers, an empty synonym table and the parser."""
        # Both helpers should stay aligned with the reader's tokenizer.
        self._tokenizer = BertPreTokenizer()
        self._token_indexers = PretrainedBertIndexer(pretrained_model='bert-base-uncased')

        # Realizations per constant, split by kind; filled in elsewhere.
        self._synonyms = {'arguments': {}, 'predicates': {}}
        # self._load_argumets()
        # self._enrich_synonyms()
        # self._enrich_synonyms_by_hand()

        # Parser for nested parenthesised expressions; double-quoted strings
        # are passed as the expression to ignore.
        quoted = pyparsing.dblQuotedString
        self._parser = pyparsing.nestedExpr('(', ')', ignoreExpr=quoted)
        # self._parser.setParseAction(self._parse_action)

    def _parse_action(self, string, location, tokens) -> Tree:
        raise NotImplementedError

    def _get_wordpieces(self, text):  # should match the reader method
        tokens = self._tokenizer.tokenize(text)
        do_lowercase = True
        tokens_out = (token.text.lower() if do_lowercase
                      and token.text not in self._tokenizer.never_split else
                      token.text for token in tokens)
        wps = [[
            wordpiece
            for wordpiece in self._token_indexers.wordpiece_tokenizer(token)
        ] for token in tokens_out]
        wps_flat = [wordpiece for token in wps for wordpiece in token]
        return tuple(wps_flat)

    def _align_to_text(self, constant, type):
        spans = []
        realizations = self._synonyms[type][constant]
        num_tokens_text = len(self._tokens)
        for realization in realizations:
            num_tokens_real = len(realization)
            for begin in range(num_tokens_text - num_tokens_real + 1):
                end = begin + num_tokens_real
                if self._tokens[begin:end] == realization:
                    span = Span(height=1,
                                span=(begin, end),
                                content=self._tokens[begin:end],
                                constant=constant)
                    spans.append(span)
        # assert len(spans) == 1
        return spans

    def _align_filter_to_text(self, constants, type, hint_span=None):
        """Align a sequence of constants to the text and build their subtree.

        Each constant is aligned on its own with ``_align_to_text``; the
        per-constant candidates are combined by ``_find_contiguous_chain`` and
        the single remaining chain is passed to
        ``_contiguous_chain_to_subtree``, whose result is returned.

        ``hint_span`` is referenced only by the disabled (commented-out)
        disambiguation code below, so it currently has no effect.

        The method asserts that exactly one chain remains at the end.
        """
        spans_per_constant = []
        for i, constant in enumerate(constants):
            constant_spans = self._align_to_text(
                constant, type)  # find spans in text for a particular constant
            # if len(constants) == 1 and hint_span != None:
            # # if i == 0 and hint_span != None:
            #     constant_spans_ = self._choose_span_by_hint(constant_spans, hint_span)
            #     if constant_spans_ != constant_spans:
            #         print('here')
            #         constant_spans = constant_spans_
            # if len(constants) == 2 and hint_span != None and i == 1 and len(constant_spans) > 1:
            #     constant_spans__ = []
            #     for j, span in enumerate(constant_spans):
            #         # temporary bad way to disambiguate spans
            #         if span.span[0] < 4 and hint_span.span[0] <= 6:
            #             constant_spans__.append(span)
            #     if len(constant_spans__) > 0:
            #         print('here') # risky
            #         constant_spans = constant_spans__
            # assert len(constant_spans) >= 1
            spans_per_constant.append(constant_spans)
        contiguous_chains = self._find_contiguous_chain(spans_per_constant)
        # if len(contiguous_chains) > 1 and hint_span:
        #     contiguous_chains_ = self._filter_chains_with_hint(contiguous_chains, hint_span)
        #     if contiguous_chains_ != contiguous_chains:
        #         print('here') # risky
        #     contiguous_chains = contiguous_chains_
        # Exactly two candidate chains: the ambiguity is resolved across calls.
        # When nothing is stashed, chain 0 is used now and chain 1 is stored in
        # ``self._possible_chains``; when something is stashed, the stashed
        # chain is used instead and the stash is cleared.
        # NOTE(review): ``self._possible_chains`` is not assigned in the visible
        # ``__init__`` - presumably initialised elsewhere; confirm.
        if len(contiguous_chains) == 2:
            if not self._possible_chains:
                # NOTE(review): ``next_option`` is computed but never read here.
                next_option = tuple(
                    [span.constant for span in contiguous_chains[1]])
                self._possible_chains = contiguous_chains[1]
                contiguous_chains = [contiguous_chains[0]]
            else:
                contiguous_chains = [self._possible_chains]
                self._possible_chains = None
        # Debug marker printed right before the assertion below would fail.
        if len(contiguous_chains) != 1:
            print('here')
        assert len(contiguous_chains) == 1
        contiguous_chain = contiguous_chains[0]
        return self._contiguous_chain_to_subtree(contiguous_chain)

    def _filter_chains_with_hint(self, contiguous_chains, hint_span):
        # min delta from span edges. Sets a high values if span is within the hint_span
        delta_from_hint = [
            min(abs(chain[0].span[0] - hint_span.span[1]),
                abs(chain[-1].span[1] - hint_span.span[0]))
            if not self._is_span_contained(chain[0], hint_span) else 10000
            for chain in contiguous_chains
        ]
        min_value = min(delta_from_hint)
        min_chains = []
        for i, chain in enumerate(contiguous_chains):
            if delta_from_hint[i] == min_value:
                min_chains.append(chain)
        # if min_value == 0 and len(min_spans) > 1:  # take the last span if all min_values are 0
        #     min_spans = [sorted(spans, key=lambda span: span.span[0], reverse=True)[0]]
        return min_chains

    def _filter_sub_chains(self, contiguous_chains):
        """Filter chains that are actually a part of a larger chain"""
        full_contiguous_chains = []
        tokens = self._tokens
        for chain in contiguous_chains:
            start = chain[0].span[0]
            end = chain[-1].span[1]
            if start > 0:  # there is at least one token before 'start'
                if tuple(tokens[start - 1:start]) in self._filter_args:
                    continue
            if start > 1:  # there are at least two tokens before 'start'
                if tuple(tokens[start - 2:start]) in self._filter_args:
                    continue
            if tuple(self._tokens[end:end + 1]) in self._filter_args:
                continue
            if tuple(self._tokens[end:end + 2]) in self._filter_args:
                continue
            full_contiguous_chains.append(chain)
        return full_contiguous_chains

    def _find_contiguous_chain(self, spans_per_constant: List[List[Span]]):
        contiguous_chains = []
        combinations = list(product(*spans_per_constant))
        for comb in combinations:
            if all([
                    s.span[1] == comb[i + 1].span[0]
                    for i, s in enumerate(comb[:-1])
            ]):  # check if contiguous
                if not any([
                        self._is_span_contained(sub_span, span)
                        for sub_span in comb for span in self._decided_spans
                ]):  # check if span wasn't decided before
                    contiguous_chains.append(comb)
                else:
                    print('here')
        if len(contiguous_chains) > 1:
            contiguous_chains_ = self._filter_sub_chains(contiguous_chains)
            if contiguous_chains != contiguous_chains_:
                print('here')
            contiguous_chains = contiguous_chains_
        return contiguous_chains

    def _contiguous_chain_to_subtree(self, contiguous_chain: List[Span]):
        """Turn a chain of adjacent spans into a nested tree and return it.

        The chain is first appended to ``self._decided_spans``.  For every
        span except the last, a parent node covering "this span up to the end
        of the chain" is created below the previous parent, with the span
        itself as its child.  The last span is attached below the innermost
        parent (or becomes the only node of a one-element chain).
        Node identifiers have the form "<start>-<end>".
        """
        self._decided_spans += contiguous_chain
        tree = Tree()
        last_span = contiguous_chain[-1]
        chain_end = last_span.span[1]
        parent_height = len(contiguous_chain) - 1

        innermost = None  # identifier of the most recently created parent
        for left_child in contiguous_chain[:-1]:
            # parent covering everything from this span to the chain's end
            start = left_child.span[0]
            parent_id = '{}-{}'.format(start, chain_end)
            covering = Span(height=parent_height,
                            span=(start, chain_end),
                            content=self._tokens[start:chain_end],
                            constant=None)
            tree.create_node(identifier=parent_id, data=covering, parent=innermost)
            innermost = parent_id

            # the span itself hangs below that parent
            child_id = '{}-{}'.format(left_child.span[0], left_child.span[1])
            tree.create_node(identifier=child_id,
                             data=left_child,
                             parent=parent_id)

        # the final span closes the chain
        last_id = '{}-{}'.format(last_span.span[0], last_span.span[1])
        tree.create_node(identifier=last_id,
                         data=last_span,
                         parent=innermost)

        return tree

    def _join_trees(self, subtree_1: Tree, subtree_2: Tree):
        """Create a new tree whose root spans both subtrees and holds them.

        The root runs from the start encoded in ``subtree_1``'s root
        identifier to the end encoded in ``subtree_2``'s root identifier
        (identifiers look like "<start>-<end>"); both subtrees are pasted
        below it, ``subtree_1`` first.
        """
        joined = Tree()
        first_bounds = subtree_1.root.split('-')
        second_bounds = subtree_2.root.split('-')
        start, end = int(first_bounds[0]), int(second_bounds[1])

        root_id = '{}-{}'.format(start, end)
        root_span = Span(height=100,
                         span=(start, end),
                         content=self._tokens[start:end],
                         constant=None)
        joined.create_node(identifier=root_id, data=root_span)
        for child in (subtree_1, subtree_2):
            joined.paste(nid=root_id, new_tree=child)
        return joined

    def _combine_trees(self, subtree_1: Tree, subtree_2: Tree):
        subtree_2.paste(nid=subtree_2.root, new_tree=subtree_1)
        return subtree_2

    def _join_binary_predicate_tree(self,
                                    predicate: Tree,
                                    arg_1: Tree,
                                    arg_2: Tree,
                                    allow_arg_switch: bool = True):
        predicate_start = int(predicate.root.split('-')[0])
        arg_1_start = int(arg_1.root.split('-')[0])
        arg_2_start = int(arg_2.root.split('-')[0])

        # if arg_2 is not in the middle
        if not (arg_2_start > predicate_start and arg_2_start < arg_1_start
                ) and not (arg_2_start < predicate_start
                           and arg_2_start > arg_1_start):
            if predicate_start < arg_1_start:  # predicate is left to arg_1
                join_1 = self._join_trees(predicate, arg_1)
            else:
                join_1 = self._join_trees(arg_1, predicate)

            if predicate_start < arg_2_start:  # predicate is left to arg_2
                join_2 = self._join_trees(join_1, arg_2)
            else:
                join_2 = self._join_trees(arg_2, join_1)
            return join_2
        else:
            if not allow_arg_switch:  # in this case we do not allow arg_2 to be in a span with the predicate
                raise Exception(
                    'Argument switch is not allowed={}'.format(predicate))
            if predicate_start < arg_1_start:  # predicate is left to arg_1, and arg_2 is in the middle
                join_1 = self._join_trees(predicate, arg_2)
                join_2 = self._join_trees(join_1, arg_1)
            else:
                join_1 = self._join_trees(arg_2, predicate)
                join_2 = self._join_trees(arg_1, join_1)
            return join_2

    def _join_unary_predicate_tree(self, predicate: Tree, arg: Tree):
        predicate_start = int(predicate.root.split('-')[0])
        arg_start = int(arg.root.split('-')[0])

        if predicate_start < arg_start:  # predicate is left to arg_1
            join_tree = self._join_trees(predicate, arg)
        else:
            join_tree = self._join_trees(arg, predicate)

        return join_tree

    def _filter_contained_spans(self, spans):
        spans_ = sorted(spans,
                        key=lambda span: span.span[1] - span.span[0],
                        reverse=True)  # sort according to span length
        if spans_ != spans:
            print('here')
        spans = spans_
        filtered_spans = []
        for span in spans:
            for broad_span in filtered_spans:
                if self._is_span_contained(
                        span, broad_span):  # span contained in broad_span
                    break
            else:
                filtered_spans.append(span)
        return filtered_spans

    def _is_span_contained(self, span_1: Span, span_2: Span):
        return set(range(span_1.span[0], span_1.span[1])).issubset(
            set(range(span_2.span[0], span_2.span[1])))

    def _is_spans_intersect(self, span_1: Span, span_2: Span):
        return len(
            set(range(span_1.span[0], span_1.span[1])).intersection(
                set(range(span_2.span[0], span_2.span[1])))) > 0

    def _choose_span_by_hint(self, spans, hint_span):
        """Chooses that span that is closest to the hint span. The hint span is the one the selected span should be close to."""

        # min delta from span edges. Sets a high values if span is within the hint_span
        delta_from_hint = [
            min(abs(span.span[0] -
                    hint_span.span[1]), abs(span.span[1] - hint_span.span[0]))
            if not self._is_span_contained(span, hint_span) else 10000
            for span in spans
        ]
        min_value = min(delta_from_hint)
        min_spans = []
        for i, span in enumerate(spans):
            if delta_from_hint[i] == min_value:
                min_spans.append(span)
        if min_value == 0 and len(
                min_spans) > 1:  # take the last span if all min_values are 0
            min_spans = [
                sorted(spans, key=lambda span: span.span[0], reverse=True)[0]
            ]  # risky
        return min_spans

    def _get_tree_from_constant(self,
                                constant,
                                type,
                                hint_span=None,
                                constant_prefix=None):
        """Align ``constant`` to exactly one text span and wrap it in a tree.

        Candidate spans come from ``_align_to_text``. They are reduced with
        ``_filter_contained_spans``, and candidates lying inside a span in
        ``self._decided_spans`` are discarded. When exactly two candidates
        remain, the first is used now and the second is parked in
        ``self._possible_constants`` for the next request for the same
        constant.

        Args:
            constant: The program constant to align.
            type: Forwarded unchanged to ``_align_to_text``. The name shadows
                the builtin but is kept for callers passing it by keyword.
            hint_span: Unused; hint-based disambiguation is disabled.
            constant_prefix: When truthy, the chosen span's ``constant`` is
                rewritten to ``'<prefix>#<constant>'``.

        Returns:
            A ``Tree`` holding a single node whose identifier is
            ``'<start>-<end>'`` and whose data is the chosen span.

        Raises:
            AssertionError: If the constant does not resolve to exactly one
                span. Raised explicitly so the check survives ``python -O``.
        """
        spans = self._filter_contained_spans(
            self._align_to_text(constant, type))
        # Positions inside an already decided span are no longer available.
        spans = [
            span for span in spans
            if not any(
                self._is_span_contained(span, decided)
                for decided in self._decided_spans)
        ]
        if len(spans) == 2:
            if constant not in self._possible_constants:
                # Use the first candidate now, keep the second for later.
                self._possible_constants[constant] = spans[1]
                spans = [spans[0]]
            else:
                spans = [self._possible_constants.pop(constant)]
        if len(spans) != 1:
            message = 'spans for {} are {}'.format(constant, spans)
            print(message)
            raise AssertionError(message)

        span = spans[0]
        if constant_prefix:
            span.constant = '{}#{}'.format(constant_prefix, span.constant)
        self._decided_spans.append(span)

        constant_tree = Tree()
        constant_tree.create_node(
            identifier='{}-{}'.format(span.span[0], span.span[1]), data=span)
        return constant_tree

    def is_valid_tree(self, parse_tree: Tree):
        """Return True when no node of ``parse_tree`` violates span nesting.

        Args:
            parse_tree: Tree whose nodes are yielded by ``expand_tree()``.

        Returns:
            ``False`` if ``_is_violating_node`` flags at least one node,
            ``True`` otherwise (including for a tree without nodes).
        """
        # Deliberately no short-circuit: every node is checked so that
        # ``_is_violating_node`` prints its diagnostic for each offender.
        violations = [
            self._is_violating_node(node, parse_tree)
            for node in parse_tree.expand_tree()
        ]
        return not any(violations)

    def is_projective_tree(self, parse_tree: Tree):
        """Return True when no node of ``parse_tree`` has more than two children.

        Args:
            parse_tree: Tree exposing ``expand_tree()`` and ``children(node)``.

        Returns:
            ``False`` if some node has three or more children, ``True``
            otherwise (including for a tree without nodes).
        """
        has_too_many_children = [
            len(parse_tree.children(node)) > 2
            for node in parse_tree.expand_tree()
        ]
        return not any(has_too_many_children)

    def _is_violating_node(self, node, parse_tree):
        """Checks id a node is violated - if its child's span is not contained in its parent span, or intersect another child."""
        node_span = parse_tree.get_node(node).data
        for child in parse_tree.children(node):
            child_span = child.data
            if not self._is_span_contained(
                    child_span, node_span):  # not contained in parent's span
                print('node {} is not contained in parent {}'.format(
                    child_span.to_string, node_span.to_string))
                return True
            for child_other in parse_tree.children(node):
                child_other_span = child_other.data
                if child != child_other and self._is_spans_intersect(
                        child_span,
                        child_other_span):  # intersects another span
                    print('node {} intersectes node {}'.format(
                        child_span.to_string, child_other_span.to_string))
                    return True
        return False

    def map_prog_to_tree(self, question, program):
        """Parse ``program`` against ``question`` and return the parse result.

        The program text is first rewritten so that every ``name (`` becomes
        ``( name``, and all commas are removed. Per-call state
        (``_program``, ``_tokens``, ``_tree``, ``_decided_spans``) is reset
        before parsing.

        Args:
            question: Question text, tokenized with ``_get_wordpieces``.
            program: Program string to parse.

        Returns:
            The first element of ``self._parser.parseString(...)``.
        """
        prefix_form = re.sub(r'(\w+) \(', r'( \1', program)
        self._program = prefix_form.replace(',', '')
        self._tokens = self._get_wordpieces(question)
        self._tree = Tree()
        self._decided_spans = []
        parsed = self._parser.parseString(self._program)
        return parsed[0]

        # # print the program tree
        # executor.parser.setParseAction(_parse_action_tree)
        # tree_parse = executor.parser.parseString(program)[0]
        # print('parse_tree=')
        # pprint(tree_parse)

    # def pprint(node, tab=""):
    #     if isinstance(node, str):
    #         print(tab + u"┗━ " + str(node))
    #         return
    #     print(tab + u"┗━ " + str(node.value))
    #     for child in node.children:
    #         pprint(child, tab + "    ")

    def _parse_action_tree(string, location, tokens):
        from collections import namedtuple
        Node = namedtuple("Node", ["value", "children"])
        node = Node(value=tokens[0][0], children=tokens[0][1:])
        return node

    def _get_aligned_span(self, subtree: Tree):
        """Gets hint span for aligning ambiguous constants (e.g., 'left' appears twice)"""
        return subtree.get_node(subtree.root).data

    def _get_first_argument_to_join(self, predicate_tree, arg1_tree,
                                    arg2_tree):
        predicate_span_start = int(predicate_tree.root.split('-')[0])
        arg1_span_start = int(arg1_tree.root.split('-')[0])
        arg2_span_start = int(arg2_tree.root.split('-')[0])
        if predicate_span_start < arg2_span_start < arg1_span_start or predicate_span_start > arg2_span_start > arg1_span_start:
            return arg2_tree, arg1_tree
        else:
            return arg1_tree, arg2_tree

    def _get_details(self, child, span_labels):
        data = child.data
        start = data.span[0]
        end = data.span[1] - 1
        span = (data.span[0], data.span[1] - 1)
        type = data.constant if data.constant else 'span'
        is_span = type == 'span'
        if not is_span:
            span_labels.append({'span': span, 'type': type})
        return start, end, is_span

    def _adjust_end(self, start, end, adjusted_end, is_span, span_labels):
        if end < adjusted_end - 1:
            span_labels.append({
                'span': (start, adjusted_end - 1),
                'type': 'span'
            })
        else:
            if is_span:
                span_labels.append({'span': (start, end), 'type': 'span'})

    def _inner_write(self, span_labels, children, end, parse_tree):
        children.sort(key=lambda c: c.data.span[0])
        start_1, end_1, is_span_1 = self._get_details(children[0], span_labels)
        start_2, end_2, is_span_2 = self._get_details(children[1], span_labels)
        if len(children) > 2:
            start_3, end_3, is_span_3 = self._get_details(
                children[2], span_labels)

        self._adjust_end(start_1, end_1, start_2, is_span_1, span_labels)

        if len(children) > 2:
            self._adjust_end(start_2, end_2, start_3, is_span_2, span_labels)
            self._adjust_end(start_3, end_3, end + 1, is_span_3, span_labels)
            # if end_2 < start_3 - 1:
            #     span_labels.append({'span': (start_2, start_3 - 1), 'type': 'span'})
            # if end_3 < end:
            #     span_labels.append({'span': (start_3, end), 'type': 'span'})
        else:
            self._adjust_end(start_2, end_2, end + 1, is_span_2, span_labels)
            # if end_2 < end:
            #     span_labels.append({'span': (start_2, end), 'type': 'span'})

        children_1 = parse_tree.children(children[0].identifier)
        if len(children_1) > 0:
            self._inner_write(span_labels, children_1, start_2 - 1, parse_tree)
        children_2 = parse_tree.children(children[1].identifier)
        if len(children) > 2:
            children_3 = parse_tree.children(children[1].identifier)
            if len(children_2) > 0:
                self._inner_write(span_labels, children_2, start_3 - 1,
                                  parse_tree)
            if len(children_3) > 0:
                self._inner_write(span_labels, children_3, end, parse_tree)
        else:
            if len(children_2) > 0:
                self._inner_write(span_labels, children_2, end, parse_tree)

    def write_to_output(self, line, parse_tree, output_file):
        """Attach gold span labels to ``line`` and write it as one JSON line.

        The whole sentence ``(0, last token index)`` is always labelled as a
        plain ``'span'``. If the root has no children, the root's own span
        and constant are added. Otherwise the labels of the children are
        collected through ``_inner_write``. In both cases an extra
        ``'span'`` label running to the end of the sentence is added when
        the labelled content does not start at position 0.

        Args:
            line: Dict with a ``'question'`` key; gains a ``'gold_spans'``
                key holding the collected labels.
            parse_tree: Tree exposing ``root``, ``get_node`` and ``children``.
            output_file: Writable text stream; receives ``json.dumps(line)``
                followed by a newline.
        """
        end = len(self._get_wordpieces(line['question'])) - 1
        span_labels = [{'span': (0, end), 'type': 'span'}]

        root = parse_tree.root
        root_span = parse_tree.get_node(root).data
        children = parse_tree.children(root)

        if not children:
            s, t = int(root_span.span[0]), int(root_span.span[1])
            # NOTE(review): ``t`` is written as-is, whereas ``_get_details``
            # writes ``span[1] - 1``; confirm whether this end should be
            # inclusive as well.
            span_labels.append({'span': (s, t), 'type': root_span.constant})
            if s > 0:
                span_labels.append({'span': (s, end), 'type': 'span'})
        else:
            # NOTE(review): assumes ``children[0]`` is the leftmost child;
            # the list is only sorted later, inside ``_inner_write``.
            child_1_start = children[0].data.span[0]
            if child_1_start > 0:
                span_labels.append({
                    'span': (child_1_start, end),
                    'type': 'span'
                })
            self._inner_write(span_labels, children, end, parse_tree)

        line['gold_spans'] = span_labels
        output_file.write(json.dumps(line) + '\n')
    def test_sliding_window(self):
        """Index a sentence longer than ``max_pieces`` without truncation.

        With ``truncate_long_sequences=False`` and ``max_pieces=10`` the
        expected ids form two consecutive windows, each wrapped in
        ``[CLS]``/``[SEP]``. The literal ``[SEP]`` inside the sentence must
        still switch the token type ids from 0 to 1.
        """
        sentence = "the quickest quick brown [SEP] jumped over the lazy dog"
        tokens = BertPreTokenizer().tokenize(sentence)

        vocab = Vocabulary()
        indexer = PretrainedBertIndexer(
            str(self.FIXTURES_ROOT / "bert" / "vocab.txt"),
            truncate_long_sequences=False,
            use_starting_offsets=False,
            max_pieces=10,
        )
        indexed = indexer.tokens_to_indices(tokens, vocab)

        # 16 = [CLS], 17 = [SEP]
        # 1 full window + 1 half window with start/end tokens
        # [CLS] the quick est quick brown [SEP] jumped over [SEP]
        first_window_ids = [16, 2, 3, 4, 3, 5, 17, 8, 9, 17]
        # [CLS] brown [SEP] jumped over the lazy dog [SEP]
        second_window_ids = [16, 5, 17, 8, 9, 2, 14, 12, 17]
        assert indexed["input_ids"] == first_window_ids + second_window_ids
        assert indexed["offsets"] == [1, 3, 4, 5, 6, 7, 8, 9, 10, 11]

        # The extra [SEP]s shouldn't pollute the token-type-ids
        # [CLS] the quick est quick brown [SEP] jumped over [SEP]
        first_window_types = [0] * 7 + [1] * 3
        # [CLS] brown [SEP] jumped over the lazy dog [SEP]
        second_window_types = [0] * 3 + [1] * 6
        assert indexed["token_type_ids"] == (first_window_types +
                                             second_window_types)
 def setUp(self):
     """Run the parent class's ``setUp`` and create a ``BertPreTokenizer``."""
     super().setUp()
     # Stored on the instance as ``self.word_tokenizer``.
     self.word_tokenizer = BertPreTokenizer()