Example #1
    def test_padding_for_equal_length_indices(self):
        tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())

        #            2   3     5     6   8      9    2   14   12
        sentence = "the quick brown fox jumped over the lazy dog"
        tokens = tokenizer.tokenize(sentence)

        vocab = Vocabulary()

        instance = Instance({"tokens": TextField(tokens, {"bert": self.token_indexer})})

        batch = Batch([instance])
        batch.index_instances(vocab)

        padding_lengths = batch.get_padding_lengths()
        tensor_dict = batch.as_tensor_dict(padding_lengths)
        tokens = tensor_dict["tokens"]

        assert tokens["bert"].tolist() == [
                [2, 3, 5, 6, 8, 9, 2, 14, 12]
        ]

        assert tokens["bert-offsets"].tolist() == [
                [0, 1, 2, 3, 4, 5, 6, 7, 8]
        ]
Example #2
 def test_passes_through_correctly(self):
     tokenizer = WordTokenizer(start_tokens=['@@', '%%'], end_tokens=['^^'])
     sentence = "this (sentence) has 'crazy' \"punctuation\"."
     tokens = [t.text for t in tokenizer.tokenize(sentence)]
     expected_tokens = ["@@", "%%", "this", "(", "sentence", ")", "has", "'", "crazy", "'", "\"",
                        "punctuation", "\"", ".", "^^"]
     assert tokens == expected_tokens
Example #3
    def test_never_lowercase(self):
        # Our default tokenizer doesn't handle lowercasing.
        tokenizer = WordTokenizer()

        #            2 15 10 11  6
        sentence = "the laziest fox"

        tokens = tokenizer.tokenize(sentence)
        tokens.append(Token("[PAD]"))  # appended as a single Token because the tokenizer would split "[PAD]" into three

        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
        token_indexer = PretrainedBertIndexer(str(vocab_path), do_lowercase=True)

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # PAD should get recognized and not lowercased      # [PAD]
        assert indexed_tokens["bert"] == [16, 2, 15, 10, 11, 6, 0, 17]

        # Unless we manually override never_lowercase
        token_indexer = PretrainedBertIndexer(str(vocab_path), do_lowercase=True, never_lowercase=())
        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # now PAD should get lowercased and be UNK          # [UNK]
        assert indexed_tokens["bert"] == [16, 2, 15, 10, 11, 6, 1, 17]
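
A plain-Python sketch of the never_lowercase behaviour exercised above; the special-token set below is an assumption based on BERT's usual special tokens, not code taken from PretrainedBertIndexer.

never_lowercase = {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"}
tokens = ["The", "laziest", "Fox", "[PAD]"]

# Tokens in the never-lowercase set pass through unchanged; everything else is lowercased.
kept = [t if t in never_lowercase else t.lower() for t in tokens]
assert kept == ["the", "laziest", "fox", "[PAD]"]

# With never_lowercase=() even "[PAD]" is lowercased, so it no longer matches the
# vocabulary entry and is indexed as [UNK], as asserted above.
kept = [t.lower() for t in tokens]
assert kept == ["the", "laziest", "fox", "[pad]"]
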
Example #4
def search(tables_directory: str,
           input_examples_file: str,
           output_path: str,
           max_path_length: int,
           max_num_logical_forms: int,
           use_agenda: bool,
           output_separate_files: bool) -> None:
    data = [wikitables_util.parse_example_line(example_line) for example_line in
            open(input_examples_file)]
    tokenizer = WordTokenizer()
    if output_separate_files and not os.path.exists(output_path):
        os.makedirs(output_path)
    if not output_separate_files:
        output_file_pointer = open(output_path, "w")
    for instance_data in data:
        utterance = instance_data["question"]
        question_id = instance_data["id"]
        if utterance.startswith('"') and utterance.endswith('"'):
            utterance = utterance[1:-1]
        # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
        table_file = instance_data["table_filename"].replace("csv", "tagged")
        target_list = instance_data["target_values"]
        tokenized_question = tokenizer.tokenize(utterance)
        table_file = f"{tables_directory}/{table_file}"
        context = TableQuestionContext.read_from_file(table_file, tokenized_question)
        world = WikiTablesVariableFreeWorld(context)
        walker = ActionSpaceWalker(world, max_path_length=max_path_length)
        correct_logical_forms = []
        if use_agenda:
            agenda = world.get_agenda()
            all_logical_forms = walker.get_logical_forms_with_agenda(agenda=agenda,
                                                                     max_num_logical_forms=10000)
        else:
            all_logical_forms = walker.get_all_logical_forms(max_num_logical_forms=10000)
        for logical_form in all_logical_forms:
            if world.evaluate_logical_form(logical_form, target_list):
                correct_logical_forms.append(logical_form)
        if output_separate_files and correct_logical_forms:
            with gzip.open(f"{output_path}/{question_id}.gz", "wt") as output_file_pointer:
                for logical_form in correct_logical_forms:
                    print(logical_form, file=output_file_pointer)
        elif not output_separate_files:
            print(f"{question_id} {utterance}", file=output_file_pointer)
            if use_agenda:
                print(f"Agenda: {agenda}", file=output_file_pointer)
            if not correct_logical_forms:
                print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
            for logical_form in correct_logical_forms[:max_num_logical_forms]:
                print(logical_form, file=output_file_pointer)
            print(file=output_file_pointer)
    if not output_separate_files:
        output_file_pointer.close()
Example #5
 def test_batch_tokenization(self):
     tokenizer = WordTokenizer()
     sentences = ["This is a sentence",
                  "This isn't a sentence.",
                  "This is the 3rd sentence."
                  "Here's the 'fourth' sentence."]
     batch_tokenized = tokenizer.batch_tokenize(sentences)
     separately_tokenized = [tokenizer.tokenize(sentence) for sentence in sentences]
     assert len(batch_tokenized) == len(separately_tokenized)
     for batch_sentence, separate_sentence in zip(batch_tokenized, separately_tokenized):
         assert len(batch_sentence) == len(separate_sentence)
         for batch_word, separate_word in zip(batch_sentence, separate_sentence):
             assert batch_word.text == separate_word.text
Example #6
    def test_squad_with_unwordpieceable_passage(self):
        # pylint: disable=line-too-long
        tokenizer = WordTokenizer()

        token_indexer = PretrainedBertIndexer("bert-base-uncased")

        passage1 = ("There were four major HDTV systems tested by SMPTE in the late 1970s, "
                    "and in 1979 an SMPTE study group released A Study of High Definition Television Systems:")
        question1 = "Who released A Study of High Definition Television Systems?"

        passage2 = ("Broca, being what today would be called a neurosurgeon, "
                    "had taken an interest in the pathology of speech. He wanted "
                    "to localize the difference between man and the other animals, "
                    "which appeared to reside in speech. He discovered the speech "
                    "center of the human brain, today called Broca's area after him. "
                    "His interest was mainly in Biological anthropology, but a German "
                    "philosopher specializing in psychology, Theodor Waitz, took up the "
                    "theme of general and social anthropology in his six-volume work, "
                    "entitled Die Anthropologie der Naturvölker, 1859–1864. The title was "
                    """soon translated as "The Anthropology of Primitive Peoples". """
                    "The last two volumes were published posthumously.")
        question2 = "What did Broca discover in the human brain?"

        from allennlp.data.dataset_readers.reading_comprehension.util import make_reading_comprehension_instance

        instance1 = make_reading_comprehension_instance(tokenizer.tokenize(question1),
                                                        tokenizer.tokenize(passage1),
                                                        {"bert": token_indexer},
                                                        passage1)

        instance2 = make_reading_comprehension_instance(tokenizer.tokenize(question2),
                                                        tokenizer.tokenize(passage2),
                                                        {"bert": token_indexer},
                                                        passage2)

        vocab = Vocabulary()

        batch = Batch([instance1, instance2])
        batch.index_instances(vocab)

        padding_lengths = batch.get_padding_lengths()
        tensor_dict = batch.as_tensor_dict(padding_lengths)
        qtokens = tensor_dict["question"]
        ptokens = tensor_dict["passage"]

        config = BertConfig(len(token_indexer.vocab))
        model = BertModel(config)
        embedder = BertEmbedder(model)

        _ = embedder(ptokens["bert"], offsets=ptokens["bert-offsets"])
        _ = embedder(qtokens["bert"], offsets=qtokens["bert-offsets"])
Example #7
    def test_end_to_end(self):
        tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())

        #            2   3    4   3     5     6   8      9    2   14   12
        sentence1 = "the quickest quick brown fox jumped over the lazy dog"
        tokens1 = tokenizer.tokenize(sentence1)

        #            2   3     5     6   8      9    2  15 10 11 14   1
        sentence2 = "the quick brown fox jumped over the laziest lazy elmo"
        tokens2 = tokenizer.tokenize(sentence2)

        vocab = Vocabulary()

        instance1 = Instance({"tokens": TextField(tokens1, {"bert": self.token_indexer})})
        instance2 = Instance({"tokens": TextField(tokens2, {"bert": self.token_indexer})})

        batch = Batch([instance1, instance2])
        batch.index_instances(vocab)

        padding_lengths = batch.get_padding_lengths()
        tensor_dict = batch.as_tensor_dict(padding_lengths)
        tokens = tensor_dict["tokens"]

        assert tokens["bert"].tolist() == [
                [2, 3, 4, 3, 5, 6, 8, 9, 2, 14, 12, 0],
                [2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1]
        ]

        assert tokens["bert-offsets"].tolist() == [
                [0, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                [0, 1, 2, 3, 4, 5, 6, 9, 10, 11]
        ]

        # No offsets, should get 12 vectors back.
        bert_vectors = self.token_embedder(tokens["bert"])
        assert list(bert_vectors.shape) == [2, 12, 12]

        # Offsets, should get 10 vectors back.
        bert_vectors = self.token_embedder(tokens["bert"], offsets=tokens["bert-offsets"])
        assert list(bert_vectors.shape) == [2, 10, 12]

        ## Now try top_layer_only = True
        tlo_embedder = BertEmbedder(self.bert_model, top_layer_only=True)
        bert_vectors = tlo_embedder(tokens["bert"])
        assert list(bert_vectors.shape) == [2, 12, 12]

        bert_vectors = tlo_embedder(tokens["bert"], offsets=tokens["bert-offsets"])
        assert list(bert_vectors.shape) == [2, 10, 12]
Example #8
    def setUp(self):
        self.tokenizer = WordTokenizer(SpacyWordSplitter(pos_tags=True))
        self.utterance = self.tokenizer.tokenize("where is mersin?")
        self.token_indexers = {"tokens": SingleIdTokenIndexer("tokens")}

        json = {
                'question': self.utterance,
                'columns': ['Name in English', 'Location in English'],
                'cells': [['Paradeniz', 'Mersin'],
                          ['Lake Gala', 'Edirne']]
                }
        self.graph = TableQuestionKnowledgeGraph.read_from_json(json)
        self.vocab = Vocabulary()
        self.name_index = self.vocab.add_token_to_namespace("name", namespace='tokens')
        self.in_index = self.vocab.add_token_to_namespace("in", namespace='tokens')
        self.english_index = self.vocab.add_token_to_namespace("english", namespace='tokens')
        self.location_index = self.vocab.add_token_to_namespace("location", namespace='tokens')
        self.paradeniz_index = self.vocab.add_token_to_namespace("paradeniz", namespace='tokens')
        self.mersin_index = self.vocab.add_token_to_namespace("mersin", namespace='tokens')
        self.lake_index = self.vocab.add_token_to_namespace("lake", namespace='tokens')
        self.gala_index = self.vocab.add_token_to_namespace("gala", namespace='tokens')
        self.negative_one_index = self.vocab.add_token_to_namespace("-1", namespace='tokens')
        self.zero_index = self.vocab.add_token_to_namespace("0", namespace='tokens')
        self.one_index = self.vocab.add_token_to_namespace("1", namespace='tokens')

        self.oov_index = self.vocab.get_token_index('random OOV string', namespace='tokens')
        self.edirne_index = self.oov_index
        self.field = KnowledgeGraphField(self.graph, self.utterance, self.token_indexers, self.tokenizer)

        super(KnowledgeGraphFieldTest, self).setUp()
Example #9
 def test_stems_and_filters_correctly(self):
     tokenizer = WordTokenizer.from_params(Params({'word_stemmer': {'type': 'porter'},
                                                   'word_filter': {'type': 'stopwords'}}))
     sentence = "this (sentence) has 'crazy' \"punctuation\"."
     expected_tokens = ["sentenc", "ha", "crazi", "punctuat"]
     tokens = [t.text for t in tokenizer.tokenize(sentence)]
     assert tokens == expected_tokens
Example #10
 def test_char_span_to_token_span_handles_easy_cases(self):
     # These are _inclusive_ spans, on both sides.
     tokenizer = WordTokenizer()
     passage = "On January 7, 2012, Beyoncé gave birth to her first child, a daughter, Blue Ivy " +\
         "Carter, at Lenox Hill Hospital in New York. Five months later, she performed for four " +\
         "nights at Revel Atlantic City's Ovation Hall to celebrate the resort's opening, her " +\
         "first performances since giving birth to Blue Ivy."
     tokens = tokenizer.tokenize(passage)
     offsets = [(t.idx, t.idx + len(t.text)) for t in tokens]
     # "January 7, 2012"
     token_span = util.char_span_to_token_span(offsets, (3, 18))[0]
     assert token_span == (1, 4)
     # "Lenox Hill Hospital"
     token_span = util.char_span_to_token_span(offsets, (91, 110))[0]
     assert token_span == (22, 24)
     # "Lenox Hill Hospital in New York."
     token_span = util.char_span_to_token_span(offsets, (91, 123))[0]
     assert token_span == (22, 28)
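
A self-contained sketch (plain Python, whitespace tokenization, shortened ASCII passage) of the bookkeeping the test above relies on: each token is mapped to its (start, end) character span, and a character span is then mapped to the inclusive range of tokens it overlaps. This only approximates util.char_span_to_token_span; it is not that function's implementation.

passage = "On January 7, 2012, Beyonce gave birth"
offsets, idx = [], 0
for word in passage.split():
    idx = passage.index(word, idx)
    offsets.append((idx, idx + len(word)))   # character span of each token
    idx += len(word)

def char_span_to_token_span(offsets, char_span):
    # first token ending after the span starts, last token starting before the span ends
    start = next(i for i, (s, e) in enumerate(offsets) if e > char_span[0])
    end = max(i for i, (s, e) in enumerate(offsets) if s < char_span[1])
    return start, end

# "January 7, 2012" covers characters 3-18 and tokens 1-3 here (whitespace
# tokenization keeps "7," and "2012," whole, unlike the spaCy-based test).
assert char_span_to_token_span(offsets, (3, 18)) == (1, 3)
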
Example #11
    def test_max_length(self):
        config = BertConfig(len(self.token_indexer.vocab))
        model = BertModel(config)
        embedder = BertEmbedder(model)

        tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())
        sentence = "the " * 1000
        tokens = tokenizer.tokenize(sentence)

        vocab = Vocabulary()

        instance = Instance({"tokens": TextField(tokens, {"bert": self.token_indexer})})

        batch = Batch([instance])
        batch.index_instances(vocab)

        padding_lengths = batch.get_padding_lengths()
        tensor_dict = batch.as_tensor_dict(padding_lengths)
        tokens = tensor_dict["tokens"]
        embedder(tokens["bert"], tokens["bert-offsets"])
Example #12
    def test_predicate_consolidation(self):
        """
        Test whether the predictor can correctly consolidate multiword
        predicates.
        """
        tokenizer = WordTokenizer(word_splitter=SpacyWordSplitter(pos_tags=True))

        sent_tokens = tokenizer.tokenize("In December, John decided to join the party.")

        # Emulate predications - for both "decided" and "join"
        predictions = [['B-ARG2', 'I-ARG2', 'O', 'B-ARG0', 'B-V', 'B-ARG1', 'I-ARG1', \
                        'I-ARG1', 'I-ARG1', 'O'],
                       ['O', 'O', 'O', 'B-ARG0', 'B-BV', 'I-BV', 'B-V', 'B-ARG1', \
                        'I-ARG1', 'O']]
        # Consolidate
        pred_dict = consolidate_predictions(predictions, sent_tokens)

        # Check that only "decided to join" is left
        assert len(pred_dict) == 1
        tags = list(pred_dict.values())[0]
        assert get_predicate_text(sent_tokens, tags) == "decided to join"
Example #13
    def test_more_than_two_overlapping_predicates(self):
        """
        Test whether the predictor can correctly consolidate multiword
        predicates.
        """
        tokenizer = WordTokenizer(word_splitter=SpacyWordSplitter(pos_tags=True))

        sent_tokens = tokenizer.tokenize("John refused to consider joining the club.")

        # Emulate predications - for "refused" and "consider" and "joining"
        predictions = [['B-ARG0', 'B-V', 'B-ARG1', 'I-ARG1', 'I-ARG1', 'I-ARG1', 'I-ARG1', 'O'],\
                       ['B-ARG0', 'B-BV', 'I-BV', 'B-V', 'B-ARG1', 'I-ARG1', 'I-ARG1', 'O'],\
                       ['B-ARG0', 'B-BV', 'I-BV', 'I-BV', 'B-V', 'B-ARG1', 'I-ARG1', 'O']]

        # Consolidate
        pred_dict = consolidate_predictions(predictions, sent_tokens)

        # Check that only "refused to consider to join" is left
        assert len(pred_dict) == 1
        tags = list(pred_dict.values())[0]
        assert get_predicate_text(sent_tokens, tags) == "refused to consider joining"
Example #14
    def test_starting_ending_offsets(self):
        tokenizer = WordTokenizer(word_splitter=BertBasicWordSplitter())

        #           2   3     5     6   8      9    2  15 10 11 14   1
        sentence = "the quick brown fox jumped over the laziest lazy elmo"
        tokens = tokenizer.tokenize(sentence)

        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
        token_indexer = PretrainedBertIndexer(str(vocab_path))

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        assert indexed_tokens["bert"] == [2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1]
        assert indexed_tokens["bert-offsets"] == [0, 1, 2, 3, 4, 5, 6, 9, 10, 11]

        token_indexer = PretrainedBertIndexer(str(vocab_path), use_starting_offsets=True)

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        assert indexed_tokens["bert"] == [2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1]
        assert indexed_tokens["bert-offsets"] == [0, 1, 2, 3, 4, 5, 6, 7, 10, 11]
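
A plain-Python sketch of how the two offset styles asserted above can be derived from the number of wordpieces each original token produces; the piece counts mirror this example's sentence, where only "laziest" splits into three wordpieces.

pieces_per_word = [1, 1, 1, 1, 1, 1, 1, 3, 1, 1]

starting_offsets, ending_offsets, position = [], [], 0
for count in pieces_per_word:
    starting_offsets.append(position)            # index of the word's first wordpiece
    ending_offsets.append(position + count - 1)  # index of the word's last wordpiece
    position += count

assert ending_offsets == [0, 1, 2, 3, 4, 5, 6, 9, 10, 11]    # default behaviour
assert starting_offsets == [0, 1, 2, 3, 4, 5, 6, 7, 10, 11]  # use_starting_offsets=True
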
Example #15
 def test_char_span_to_token_span_handles_hard_cases(self):
     # An earlier version of the code had a hard time when the answer was the last token in the
     # passage.  This tests that case, on the instance that used to fail.
     tokenizer = WordTokenizer()
     passage = "Beyonc\u00e9 is believed to have first started a relationship with Jay Z " +\
         "after a collaboration on \"'03 Bonnie & Clyde\", which appeared on his seventh " +\
         "album The Blueprint 2: The Gift & The Curse (2002). Beyonc\u00e9 appeared as Jay " +\
         "Z's girlfriend in the music video for the song, which would further fuel " +\
         "speculation of their relationship. On April 4, 2008, Beyonc\u00e9 and Jay Z were " +\
         "married without publicity. As of April 2014, the couple have sold a combined 300 " +\
         "million records together. The couple are known for their private relationship, " +\
         "although they have appeared to become more relaxed in recent years. Beyonc\u00e9 " +\
         "suffered a miscarriage in 2010 or 2011, describing it as \"the saddest thing\" " +\
         "she had ever endured. She returned to the studio and wrote music in order to cope " +\
         "with the loss. In April 2011, Beyonc\u00e9 and Jay Z traveled to Paris in order " +\
         "to shoot the album cover for her 4, and unexpectedly became pregnant in Paris."
     start = 912
     end = 912 + len("Paris.")
     tokens = tokenizer.tokenize(passage)
      offsets = [(t.idx, t.idx + len(t.text)) for t in tokens]
     token_span = util.char_span_to_token_span(offsets, (start, end))[0]
     assert token_span == (184, 185)
Example #16
def read(fn: str) -> List[Extraction]:
    tokenizer = WordTokenizer(word_splitter = SpacyWordSplitter(pos_tags=True))
    prev_sent = []

    with open(fn) as fin:
        for line in tqdm(fin):
            data = line.strip().split('\t')
            confidence = data[0]
            if not all(data[2:5]):
                # Make sure that all required elements are present
                continue
            arg1, rel, args2 = map(parse_element,
                                   data[2:5])

            # Exactly one subject and one relation
            # and at least one object
            if ((len(rel) == 1) and \
                (len(arg1) == 1) and \
                (len(args2) >= 1)):
                sent = data[5]
                cur_ex = Extraction(sent = sent,
                                    toks = tokenizer.tokenize(sent),
                                    arg1 = arg1[0],
                                    rel = rel[0],
                                    args2 = args2,
                                    confidence = confidence)


                # Decide whether to append or yield
                if (not prev_sent) or (prev_sent[0].sent == sent):
                    prev_sent.append(cur_ex)
                else:
                    yield prev_sent
                    prev_sent = [cur_ex]
    if prev_sent:
        # Yield last element
        yield prev_sent
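
A minimal plain-Python sketch of the grouping pattern used in read() above: consecutive extractions that share a sentence are buffered, and the buffer is yielded whenever the sentence changes. The (sentence, confidence) tuples stand in for Extraction objects and are purely illustrative.

def group_by_sentence(rows):
    prev_group = []
    for sent, confidence in rows:
        if (not prev_group) or (prev_group[0][0] == sent):
            prev_group.append((sent, confidence))
        else:
            yield prev_group
            prev_group = [(sent, confidence)]
    if prev_group:
        # Yield the last buffered sentence
        yield prev_group

rows = [("s1", 0.9), ("s1", 0.7), ("s2", 0.8)]
assert list(group_by_sentence(rows)) == [[("s1", 0.9), ("s1", 0.7)], [("s2", 0.8)]]
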
Example #17
    def test_do_lowercase(self):
        # Our default tokenizer doesn't handle lowercasing.
        tokenizer = WordTokenizer()

        # Quick is UNK because of capitalization
        #           2   1     5     6   8      9    2  15 10 11 14   1
        sentence = "the Quick brown fox jumped over the laziest lazy elmo"
        tokens = tokenizer.tokenize(sentence)

        vocab = Vocabulary()
        vocab_path = self.FIXTURES_ROOT / 'bert' / 'vocab.txt'
        token_indexer = PretrainedBertIndexer(str(vocab_path), do_lowercase=False)

        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # Quick should get 1 == OOV
        assert indexed_tokens["bert"] == [16, 2, 1, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17]

        # Does lowercasing by default
        token_indexer = PretrainedBertIndexer(str(vocab_path))
        indexed_tokens = token_indexer.tokens_to_indices(tokens, vocab, "bert")

        # Now Quick should get indexed correctly as 3 ( == "quick")
        assert indexed_tokens["bert"] == [16, 2, 3, 5, 6, 8, 9, 2, 15, 10, 11, 14, 1, 17]
Example #18
 def __init__(self,
              tokenizer: Tokenizer = None,
              token_indexers: Dict[str, TokenIndexer] = None,
              lazy: bool = False,
              passage_length_limit: int = None,
              question_length_limit: int = None,
              skip_when_all_empty: List[str] = None,
              instance_format: str = "drop",
              relaxed_span_match_for_finding_labels: bool = True) -> None:
     super().__init__(lazy)
     self._tokenizer = tokenizer or WordTokenizer()
     self._token_indexers = token_indexers or {
         "tokens": SingleIdTokenIndexer()
     }
     self.passage_length_limit = passage_length_limit
     self.question_length_limit = question_length_limit
     self.skip_when_all_empty = skip_when_all_empty if skip_when_all_empty is not None else []
     for item in self.skip_when_all_empty:
         assert item in ["passage_span", "question_span", "addition_subtraction", "counting"], \
             f"Unsupported skip type: {item}"
     self.instance_format = instance_format
     self.relaxed_span_match_for_finding_labels = relaxed_span_match_for_finding_labels
Example #19
File: drop.py Project: MyPaperCode/RAIN
 def __init__(self,
              tokenizer: Tokenizer = None,
              token_indexers: Dict[str, TokenIndexer] = None,
              lazy: bool = False,
              passage_length_limit: int = None,
              question_length_limit: int = None,
              skip_when_all_empty: List[str] = None,
              instance_format: str = "drop",
              bert_pretrain_model: str = None,
              implicit_number: List[int] = None,
              relaxed_span_match_for_finding_labels: bool = True) -> None:
     super().__init__(lazy)
      self._tokenizer = tokenizer or WordTokenizer()
     self.bert_tokenizer = BertTokenizer.from_pretrained(bert_pretrain_model).wordpiece_tokenizer.tokenize
     
     self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
     self.passage_length_limit = passage_length_limit
     self.question_length_limit = question_length_limit
     self.skip_when_all_empty = skip_when_all_empty if skip_when_all_empty is not None else []
     self.instance_format = instance_format
     self.relaxed_span_match_for_finding_labels = relaxed_span_match_for_finding_labels
     self.implicit_number = implicit_number  
     self.implicit_tokens = [Token(str(number)) for number in self.implicit_number]
Example #20
 def __init__(self,
              tokenizer: Tokenizer = None,
              source_token_indexers: Dict[str, TokenIndexer] = None,
              target_token_indexers: Dict[str, TokenIndexer] = None,
              source_max_tokens: int = 400,
              target_max_tokens: int = 100,
              separate_namespaces: bool = False,
              target_namespace: str = "target_tokens",
              save_copy_fields: bool = False,
              save_pgn_fields: bool = False) -> None:
     if not tokenizer:
         tokenizer = WordTokenizer(word_splitter=SimpleWordSplitter())
     super().__init__(
         tokenizer=tokenizer,
         source_token_indexers=source_token_indexers,
         target_token_indexers=target_token_indexers,
         source_max_tokens=source_max_tokens,
         target_max_tokens=target_max_tokens,
         separate_namespaces=separate_namespaces,
         target_namespace=target_namespace,
         save_copy_fields=save_copy_fields,
         save_pgn_fields=save_pgn_fields
     )
Example #21
 def __init__(self,
              target_namespace: str,
              source_tokenizer: Tokenizer = None,
              target_tokenizer: Tokenizer = None,
              source_token_indexers: Dict[str, TokenIndexer] = None,
              lazy: bool = False) -> None:
     super().__init__(lazy)
     archive = load_archive('./temp/bidaf_baseline/model.tar.gz')
     self.predictor = Predictor.from_archive(archive, 'sharc_predictor')
     self._target_namespace = target_namespace
     self._source_tokenizer = source_tokenizer or WordTokenizer()
     self._target_tokenizer = target_tokenizer or self._source_tokenizer
     self._source_token_indexers = source_token_indexers or {
         "tokens": SingleIdTokenIndexer()
     }
     if "tokens" not in self._source_token_indexers or \
             not isinstance(self._source_token_indexers["tokens"], SingleIdTokenIndexer):
         raise ConfigurationError(
             "CopyNetDatasetReader expects 'source_token_indexers' to contain "
             "a 'single_id' token indexer called 'tokens'.")
     self._target_token_indexers: Dict[str, TokenIndexer] = {
         "tokens": SingleIdTokenIndexer(namespace=self._target_namespace)
     }
Example #22
    def __init__(self,
                 tokens_per_instance: int = None,
                 tokenizer: Tokenizer = None,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 lazy: bool = False) -> None:
        super().__init__(lazy)
        self._tokenizer = tokenizer or WordTokenizer()
        self._token_indexers = token_indexers or {
            "tokens": SingleIdTokenIndexer()
        }
        self._tokens_per_instance = tokens_per_instance

        # No matter how you want to represent the input, we'll always represent the output as a
        # single token id.  This code lets you learn a language model that concatenates word
        # embeddings with character-level encoders, in order to predict the word token that comes
        # next.
        self._output_indexer: Dict[str, TokenIndexer] = None
        for name, indexer in self._token_indexers.items():
            if isinstance(indexer, SingleIdTokenIndexer):
                self._output_indexer = {name: indexer}
                break
        else:
            self._output_indexer = {"tokens": SingleIdTokenIndexer()}
Example #23
 def __init__(
     self,
     target_namespace: str,
     source_tokenizer: Tokenizer = None,
     target_tokenizer: Tokenizer = None,
     source_token_indexers: Dict[str, TokenIndexer] = None,
     lazy: bool = False,
 ) -> None:
     super().__init__(lazy)
     self._target_namespace = target_namespace
     self._source_tokenizer = source_tokenizer or WordTokenizer()
     self._target_tokenizer = target_tokenizer or self._source_tokenizer
     self._source_token_indexers = source_token_indexers or {
         "tokens": SingleIdTokenIndexer()
     }
     if "tokens" not in self._source_token_indexers or not isinstance(
             self._source_token_indexers["tokens"], SingleIdTokenIndexer):
         raise ConfigurationError(
             "CopyNetDatasetReader expects 'source_token_indexers' to contain "
             "a 'single_id' token indexer called 'tokens'.")
     self._target_token_indexers: Dict[str, TokenIndexer] = {
         "tokens": SingleIdTokenIndexer(namespace=self._target_namespace)
     }
Example #24
class DropWorldTest(AllenNlpTestCase):
    def setUp(self):
        super().setUp()
        self.tokenizer = WordTokenizer()
        self.tokens = self.tokenizer.tokenize(
            "how many points did the redskins score in the final "
            "two minutes of the game?")
        context = ParagraphQuestionContext.read_from_file(
            "fixtures/data/tables/sample_paragraph.tagged", self.tokens)
        self.world = DropWorld(context)

    def test_get_agenda(self):
        assert self.world.get_agenda() == [
            '<p,n> -> count_structures', 's -> string:point',
            's -> string:redskin', 's -> string:score', 's -> string:two',
            's -> string:game'
        ]

    def test_world_with_empty_paragraph(self):
        context = ParagraphQuestionContext.read_from_file(
            "fixtures/data/tables/empty_paragraph.tagged", self.tokens)
        # We're just confirming that creating a world with an empty context does not throw an error.
        DropWorld(context)
Example #25
class MyReader(DatasetReader):
    """
    Just reads in a text file and sticks each line
    in a ``TextField`` with the specified name.
    """
    def __init__(self, field_name: str) -> None:
        super().__init__()
        self.field_name = field_name
        self.tokenizer = WordTokenizer()
        self.token_indexers: Dict[str, TokenIndexer] = {
            "tokens": SingleIdTokenIndexer()
        }

    def text_to_instance(self, sentence: str) -> Instance:  # type: ignore
        # pylint: disable=arguments-differ
        tokens = self.tokenizer.tokenize(sentence)
        return Instance(
            {self.field_name: TextField(tokens, self.token_indexers)})

    def _read(self, file_path: str):
        with open(file_path) as data_file:
            for line in data_file:
                yield self.text_to_instance(line)
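
A brief usage sketch for MyReader, assuming the AllenNLP imports used by the class above are available; the file contents are made up for this example.

import os
import tempfile

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("The first sentence.\n")
    f.write("The second sentence.\n")
    path = f.name

reader = MyReader(field_name="sentence")
instances = list(reader.read(path))
assert len(instances) == 2
print(instances[0].fields["sentence"].tokens)
os.remove(path)
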
Example #26
 def __init__(self,
              source_tokenizer: Tokenizer = None,
              target_tokenizer: Tokenizer = None,
              task_token_indexers: Dict[str, TokenIndexer] = None,
              domain_token_indexers: Dict[str, TokenIndexer] = None,
              source_token_indexers: Dict[str, TokenIndexer] = None,
              target_token_indexers: Dict[str, TokenIndexer] = None,
              source_add_start_token: bool = True,
              lazy: bool = False) -> None:
     super().__init__(lazy)
     self._source_tokenizer = source_tokenizer or WordTokenizer()
     self._target_tokenizer = target_tokenizer or self._source_tokenizer
     self._task_token_indexers = task_token_indexers or {
         "task_token": SingleIdTokenIndexer()
     }
     self._domain_token_indexers = domain_token_indexers or {
         "domain_token": SingleIdTokenIndexer()
     }
     self._source_token_indexers = source_token_indexers or {
         "tokens": SingleIdTokenIndexer()
     }
     self._target_token_indexers = target_token_indexers or self._source_token_indexers
     self._source_add_start_token = source_add_start_token
Example #27
 def __init__(self,
              tokenizer: Tokenizer = None,
              token_indexers: Dict[str, TokenIndexer] = None,
              para_limit: int = 2250,
              sent_limit: int = 75,
              word_piece_limit: int = 142,
              context_limit: int = 20,
              training: bool = False,
              filter_compare_q: bool = False,
              chain: str = 'rb',
              lazy: bool = False) -> None:
     super().__init__(lazy)
     self._tokenizer = tokenizer or WordTokenizer()
     self._token_indexers = token_indexers or {
         'tokens': SingleIdTokenIndexer()
     }
     self._para_limit = para_limit
     self._sent_limit = sent_limit
     self._context_limit = context_limit
     self._word_piece_limit = word_piece_limit
     self._filter_compare_q = filter_compare_q
     self.chain = chain
     self.training = training
Example #28
 def __init__(self,
              source_tokenizer: Tokenizer = None,
              target_tokenizer: Tokenizer = None,
              source_token_indexers: Dict[str, TokenIndexer] = None,
              target_token_indexers: Dict[str, TokenIndexer] = None,
              source_add_start_token: bool = True,
              delimiter: str = "\t",
              source_max_tokens: Optional[int] = None,
              target_max_tokens: Optional[int] = None,
              lazy: bool = False) -> None:
     super().__init__(lazy)
     self._source_tokenizer = source_tokenizer or WordTokenizer()
     self._target_tokenizer = target_tokenizer or self._source_tokenizer
     self._source_token_indexers = source_token_indexers or {
         "tokens": SingleIdTokenIndexer()
     }
     self._target_token_indexers = target_token_indexers or self._source_token_indexers
     self._source_add_start_token = source_add_start_token
     self._delimiter = delimiter
     self._source_max_tokens = source_max_tokens
     self._target_max_tokens = target_max_tokens
     self._source_max_exceeded = 0
     self._target_max_exceeded = 0
Example #29
 def __init__(self,
              token_indexers: Dict[str, TokenIndexer] = None,
              tokenizer: Tokenizer = None,
              max_sequence_length: int = None,
              ignore_labels: bool = False,
              sample: int = None,
              skip_label_indexing: bool = False,
              lazy: bool = False) -> None:
     super().__init__(lazy=lazy,
                      token_indexers=token_indexers,
                      tokenizer=tokenizer,
                      max_sequence_length=max_sequence_length,
                      skip_label_indexing=skip_label_indexing)
     self._tokenizer = tokenizer or WordTokenizer()
     self._sample = sample
     self._max_sequence_length = max_sequence_length
     self._ignore_labels = ignore_labels
     self._skip_label_indexing = skip_label_indexing
     self._token_indexers = token_indexers or {
         'tokens': SingleIdTokenIndexer()
     }
     if self._segment_sentences:
         self._sentence_segmenter = SpacySentenceSplitter()
Example #30
 def __init__(self,
              max_source_length: int = 400,
              max_target_length: int = 100,
              tokenizer: Tokenizer = None,
              token_indexers: Dict[str, TokenIndexer] = None,
              lowercase_tokens: bool = False,
              lazy: bool = True,
              max_to_read=np.inf) -> None:
     super().__init__(lazy)
     self.lowercase_tokens = lowercase_tokens
     self.max_source_length = max_source_length
     self.max_target_length = max_target_length
     self.max_to_read = max_to_read
     self._tokenizer = tokenizer or WordTokenizer(
         word_splitter=JustSpacesWordSplitter())
     self._token_indexers = token_indexers or {
         "tokens": SingleIdTokenIndexer()
     }
     if "tokens" not in self._token_indexers or \
             not isinstance(self._token_indexers["tokens"], SingleIdTokenIndexer):
         raise ConfigurationError(
             "CNNDmailDatasetReader expects 'token_indexers' to contain "
             "a 'single_id' token indexer called 'tokens'.")
Example #31
File: kaldi.py Project: Chung-I/tsm-rnnt
    def __init__(self,
                 shard_size: int,
                 lexicon_path: str,
                 transcript_path: str,
                 input_stack_rate: int = 1,
                 model_stack_rate: int = 1,
                 target_tokenizer: Tokenizer = None,
                 target_token_indexers: Dict[str, TokenIndexer] = None,
                 target_add_start_end_token: bool = False,
                 delimiter: str = "\t",
                 lazy: bool = False) -> None:
        super().__init__(lazy)
        transcript_files = glob.glob(transcript_path)
        self.transcripts: Dict[str, str] = {}
        for transcript_file in transcript_files:
            with open(transcript_file) as f:
                for line in f.read().splitlines():
                    end, start = re.search(r'\s+', line).span()
                    self.transcripts[line[:end]] = line[start:]

        self.lexicon: Dict[str, str] = {}
        with open(lexicon_path) as f:
            for line in f.read().splitlines():
                end, start = re.search(r'\s+', line).span()
                self.lexicon[line[:end]] = line[start:]

        self.cc = OpenCC('s2t')
        self.w2p = word_to_phones(self.lexicon)
        self._target_tokenizer = target_tokenizer or WordTokenizer()
        self._target_token_indexers = target_token_indexers or {
            "tokens": SingleIdTokenIndexer()
        }
        self._delimiter = delimiter
        self._shard_size = shard_size
        self.input_stack_rate = input_stack_rate
        self.model_stack_rate = model_stack_rate
        self._target_add_start_end_token = target_add_start_end_token
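
A small plain-Python illustration of the span() trick used above to split each line at its first whitespace run: the start of the whitespace match is where the key ends, and its end is where the value begins. The line content is invented.

import re

line = "utt_001\tthis is the transcript"
end, start = re.search(r'\s+', line).span()
assert line[:end] == "utt_001"
assert line[start:] == "this is the transcript"
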
Example #32
    def __init__(self,
                 lazy: bool = False,
                 tokenizer: Tokenizer = None,
                 incl_target: bool = True,
                 reverse_right_text: bool = True,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 sentiment_mapper: Dict[int, str] = None):
        '''
        This dataset reader can also be used in conjunction with the augmented 
        iterator.

        :param incl_target: Whether to include the target word(s) in the left 
                            and right contexts. By default this is True as 
                            this is what the original TDLSTM method specified.
        :param reverse_right_text: Whether the right context (the target, if
                                   included, plus all text to its right) should
                                   be returned tokenised in reverse order, from
                                   the right-most token back to the left-most
                                   token, which would be the first token of the
                                   target if the target is included. This is
                                   required to reproduce the single-layer LSTM
                                   method of TDLSTM; if a bi-directional LSTM
                                   encoder is used to encode the right text,
                                   this parameter does not matter and it is
                                   quicker to set it to False.
        :param sentiment_mapper: If not given, maps the -1, 0, 1 labels to
                                 `negative`, `neutral`, and `positive`
                                 respectively.
        '''
        super().__init__(lazy)
        self._tokenizer = tokenizer or WordTokenizer()
        self.incl_target = incl_target
        self.reverse_right_text = reverse_right_text
        self._token_indexers = token_indexers or \
                               {"tokens": SingleIdTokenIndexer()}
        self.sentiment_mapper = sentiment_mapper or \
                                {-1: 'negative', 0: 'neutral', 1: 'positive'}
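
A plain-Python illustration of the reverse_right_text behaviour described in the docstring above; the sentence, target, and context split are invented and only approximate what the reader does internally.

tokens = ["the", "food", "was", "great", "but", "the", "service", "was", "slow"]
target_index = tokens.index("service")

left = tokens[:target_index + 1]   # left context, target included (incl_target=True)
right = tokens[target_index:]      # right context, target included

# With reverse_right_text=True the right context is fed right-to-left, so a
# left-to-right LSTM sees the right-most token first and the target last.
assert left == ["the", "food", "was", "great", "but", "the", "service"]
assert list(reversed(right)) == ["slow", "was", "service"]
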
Example #33
    def __init__(self,
                 negative_sentence_selection: str = "paragraph",
                 tokenizer: Tokenizer = None,
                 token_indexers: Dict[str, TokenIndexer] = None) -> None:
        self._tokenizer = tokenizer or WordTokenizer()
        self._token_indexers = token_indexers or {
            'tokens': SingleIdTokenIndexer()
        }
        self._negative_sentence_selection_methods = negative_sentence_selection.split(
            ",")

        # Initializing some data structures here that will be useful when reading a file.
        # Maps sentence strings to sentence indices
        self._sentence_to_id: Dict[str, int] = {}
        # Maps sentence indices to sentence strings
        self._id_to_sentence: Dict[int, str] = {}
        # Maps paragraph ids to lists of contained sentence ids
        self._paragraph_sentences: Dict[int, List[int]] = {}
        # Maps sentence ids to the containing paragraph id.
        self._sentence_paragraph_map: Dict[int, int] = {}
        # Maps question strings to question indices
        self._question_to_id: Dict[str, int] = {}
        # Maps question indices to question strings
        self._id_to_question: Dict[int, str] = {}
Example #34
def test_iterator():
    indexer = StaticFasttextTokenIndexer(
        model_path="./data/fasttext_embedding.model",
        model_params_path="./data/fasttext_embedding.model.params")

    loader = MenionsLoader(
        category_mapping_file='./data/test_category_mapping.json',
        token_indexers={"tokens": indexer},
        tokenizer=WordTokenizer(word_splitter=FastSplitter()))

    vocab = Vocabulary.from_params(Params({"directory_path":
                                           "./data/vocab2/"}))

    iterator = BasicIterator(batch_size=32)

    iterator.index_with(vocab)

    limit = 50
    for _ in tqdm.tqdm(iterator(loader.read('./data/train_data_aa.tsv'),
                                num_epochs=1),
                       mininterval=2):
        limit -= 1
        if limit <= 0:
            break
Example #35
    def __init__(self,
                 target_namespace: str,
                 span_predictor_model,
                 source_tokenizer: Tokenizer = None,
                 target_tokenizer: Tokenizer = None,
                 source_token_indexers: Dict[str, TokenIndexer] = None,
                 lazy: bool = False,
                 add_rule=True,
                 embed_span=True,
                 add_question=True,
                 add_followup_ques=True) -> None:
        super().__init__(lazy)
        self._target_namespace = target_namespace
        self._source_tokenizer = source_tokenizer or WordTokenizer()
        self._target_tokenizer = target_tokenizer or self._source_tokenizer
        self._source_token_indexers = source_token_indexers or {
            "tokens": SingleIdTokenIndexer()
        }
        self.add_rule = add_rule
        self.embed_span = embed_span
        self.add_question = add_question
        self.add_followup_ques = add_followup_ques
        if "tokens" not in self._source_token_indexers or \
                not isinstance(self._source_token_indexers["tokens"], SingleIdTokenIndexer):
            raise ConfigurationError(
                "CopyNetDatasetReader expects 'source_token_indexers' to contain "
                "a 'single_id' token indexer called 'tokens'.")
        self._target_token_indexers: Dict[str, TokenIndexer] = {
            "tokens": SingleIdTokenIndexer(namespace=self._target_namespace)
        }

        archive = load_archive(span_predictor_model)
        self.dataset_reader = DatasetReader.from_params(
            archive.config.duplicate()["dataset_reader"])
        self.span_predictor = Predictor.from_archive(archive,
                                                     'sharc_predictor')
Example #36
 def __init__(self,
              lazy=False,
              tables_directory=None,
              dpd_output_directory=None,
              max_dpd_logical_forms=10,
              sort_dpd_logical_forms=True,
              max_dpd_tries=20,
              keep_if_no_dpd=False,
              tokenizer=None,
              question_token_indexers=None,
              table_token_indexers=None,
              use_table_for_vocab=False,
              linking_feature_extractors=None,
              include_table_metadata=False,
              max_table_tokens=None,
              output_agendas=False):
     super(WikiTablesDatasetReader, self).__init__(lazy=lazy)
     self._tables_directory = tables_directory
     self._dpd_output_directory = dpd_output_directory
     self._max_dpd_logical_forms = max_dpd_logical_forms
     self._sort_dpd_logical_forms = sort_dpd_logical_forms
     self._max_dpd_tries = max_dpd_tries
     self._keep_if_no_dpd = keep_if_no_dpd
     self._tokenizer = tokenizer or WordTokenizer(
         SpacyWordSplitter(pos_tags=True))
     self._question_token_indexers = question_token_indexers or {
         u"tokens": SingleIdTokenIndexer()
     }
     self._table_token_indexers = table_token_indexers or self._question_token_indexers
     self._use_table_for_vocab = use_table_for_vocab
     self._linking_feature_extractors = linking_feature_extractors
     self._include_table_metadata = include_table_metadata
     self._basic_types = set(
         unicode(type_) for type_ in wt_types.BASIC_TYPES)
     self._max_table_tokens = max_table_tokens
     self._output_agendas = output_agendas
Example #37
    def __init__(self,
                 skip_empty: bool = False,
                 downsample_negative: float = 0.05,
                 downsample_all: float = 0.1,
                 simplified: bool = True,
                 skip_toplevel_answer_candidates: bool = True,
                 maxlen: int = 450,
                 classes_to_ignore: List[str] = None,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 lazy: bool = True):
        if not simplified:
            raise ConfigurationError(
                'Only simplified version of natural questions is allowed')
        super(NaturalQuestionsDatasetReader, self).__init__(lazy=lazy)
        self._tokenizer = WordTokenizer(word_splitter=JustSpacesWordSplitter())
        self._token_indexers = token_indexers or {
            'tokens': SingleIdTokenIndexer()
        }
        self._skip_empty = skip_empty
        self._maxlen = maxlen

        self._downsample_negative = downsample_negative
        self._skip_toplevel_answer_candidates = skip_toplevel_answer_candidates
        self._downsample_all = downsample_all
Example #38
 def __init__(self,
              tokenizer: Tokenizer = None,
              token_indexers: Dict[str, TokenIndexer] = None,
              delimiter: str = ',',
              testing: bool = False,
              max_sequence_length: int = None,
              lazy: bool = False) -> None:
     """
      DatasetReader for a text classification task. Reads data from a CSV file
      whose header names the text and label columns, e.g.:
      label   text
      sad    i like it.
      :param tokenizer: the tokenizer to use
     :param token_indexers:
     :param delimiter:
     :param testing:
     :param max_sequence_length:
     :param lazy:
     """
     super().__init__(lazy)
     self._tokenizer = tokenizer or WordTokenizer()
     self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
     self._delimiter = delimiter
     self.testing = testing
     self._max_sequence_length = max_sequence_length
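
A minimal sketch of the delimited file format the docstring above describes, using only the standard library; the sample rows are invented.

import csv
import io

sample = "label,text\nsad,i like it.\nhappy,what a great day\n"
rows = list(csv.DictReader(io.StringIO(sample), delimiter=","))
assert rows[0]["label"] == "sad"
assert rows[0]["text"] == "i like it."
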
Example #39
    def __init__(self,
                 tokenizer: Callable[[str], List[str]] = None,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 target_token_indexers: Dict[str, TokenIndexer] = None,
                 predicates: List[str] = None,
                 ontology_types: List[str] = None):
        super().__init__(lazy=False)

        self.tokenizer = tokenizer or WordTokenizer()
        self.token_indexers = token_indexers or {
            "tokens": SingleIdTokenIndexer()
        }
        self.target_token_indexers = target_token_indexers or {
            "tokens": SingleIdTokenIndexer(namespace='target_tokens')
        }

        self.predicates = [deurify_predicate(p) for p in predicates]
        self.original_predicates = predicates
        self.unique_predicates = list(set(self.predicates))
        self.ontology_types = ontology_types
        self.executor = StubExecutor()
        context = LCQuADContext(self.executor, [], ['ENT_1', 'ENT_2'],
                                self.unique_predicates)
        self.language = LCQuADLanguage(context)
Example #40
 def __init__(self,
              source_tokenizer: Tokenizer = None,
              target_tokenizer: Tokenizer = None,
              source_token_indexers: Dict[str, TokenIndexer] = None,
              upos_token_indexers: Dict[str, TokenIndexer] = None,
              ner_token_indexers: Dict[str, TokenIndexer] = None,
              chunk_token_indexers: Dict[str, TokenIndexer] = None,
              source_add_start_token: bool = True,
              lazy: bool = False) -> None:
     super().__init__(lazy)
     self._source_tokenizer = source_tokenizer or WordTokenizer()
     self._target_tokenizer = target_tokenizer or self._source_tokenizer
     self._source_token_indexers = source_token_indexers or {
         "tokens": SingleIdTokenIndexer()
     }
     self._upos_token_indexers = upos_token_indexers or self._source_token_indexers
     self._ner_token_indexers = ner_token_indexers or self._source_token_indexers
     self._chunk_token_indexers = chunk_token_indexers or self._source_token_indexers
     self._task_to_indexers = {
         'upos': self._upos_token_indexers,
         'ner': self._ner_token_indexers,
         'chunk': self._chunk_token_indexers
     }
     self._source_add_start_token = source_add_start_token
Example #41
    def __init__(
        self,
        tokenizer: Tokenizer = None,
        source_token_indexers: Dict[str, TokenIndexer] = None,
        target_token_indexers: Dict[str, TokenIndexer] = None,
        source_max_tokens: int = 400,
        target_max_tokens: int = 100,
        separate_namespaces: bool = False,
        target_namespace: str = "target_tokens",
        save_copy_fields: bool = False,
        save_pgn_fields: bool = False,
    ) -> None:
        super().__init__(lazy=True)

        assert (save_pgn_fields or save_copy_fields
                or (not save_pgn_fields and not save_copy_fields))

        self._source_max_tokens = source_max_tokens
        self._target_max_tokens = target_max_tokens

        self._tokenizer = tokenizer or WordTokenizer(
            word_splitter=SimpleWordSplitter())

        tokens_indexer = {"tokens": SingleIdTokenIndexer()}
        self._source_token_indexers = source_token_indexers or tokens_indexer
        self._target_token_indexers = target_token_indexers or tokens_indexer

        self._save_copy_fields = save_copy_fields
        self._save_pgn_fields = save_pgn_fields
        self._target_namespace = "tokens"
        if separate_namespaces:
            self._target_namespace = target_namespace
            second_tokens_indexer = {
                "tokens": SingleIdTokenIndexer(namespace=target_namespace)
            }
            self._target_token_indexers = target_token_indexers or second_tokens_indexer
Example #42
    def __init__(
        self,
        tokenizer: Tokenizer = None,
        source_token_indexers: Dict[str, TokenIndexer] = None,
        target_token_indexers: Dict[str, TokenIndexer] = None,
        source_max_tokens : int = 400,
        target_max_tokens : int = 100,
        separate_namespaces: bool = False, # for what?
        target_namespace: str = 'target_tokens',
        save_copy_fields: bool = False,
        save_pgn_fields: bool = False,
        lazy: bool = False
        ) -> None:
        
        super().__init__(lazy)

        assert save_pgn_fields or save_copy_fields or (not save_copy_fields and not save_pgn_fields)

        self.source_max_tokens = source_max_tokens
        self.target_max_tokens = target_max_tokens

        self.tokenizer = tokenizer or WordTokenizer(word_splitter=SimpleWordSplitter())

        tokens_indexer = {'tokens':SingleIdTokenIndexer()}

        self.source_token_indexers = source_token_indexers or tokens_indexer
        self.target_token_indexers = target_token_indexers or tokens_indexer

        self.save_copy_fields = save_copy_fields
        self.save_pgn_fields = save_pgn_fields

        self.target_namespace = 'tokens'
        if separate_namespaces:
            self.target_namespace = target_namespace
            second_tokens_indexer = {'tokens':SingleIdTokenIndexer(namespace=target_namespace)}
            self.target_token_indexers = target_token_indexers or second_tokens_indexer
Example #43
    def __init__(
            self,
            lazy: bool = False,
            sample: int = -1,
            lf_syntax: str = None,
            replace_world_entities: bool = False,
            align_world_extractions: bool = False,
            gold_world_extractions: bool = False,
            tagger_only: bool = False,
            denotation_only: bool = False,
            world_extraction_model: Optional[str] = None,
            skip_attributes_regex: Optional[str] = None,
            entity_bits_mode: Optional[str] = None,
            entity_types: Optional[List[str]] = None,
            lexical_cues: List[str] = None,
            tokenizer: Tokenizer = None,
            question_token_indexers: Dict[str, TokenIndexer] = None) -> None:
        super().__init__(lazy=lazy)
        self._tokenizer = tokenizer or WordTokenizer()
        self._question_token_indexers = question_token_indexers or {
            "tokens": SingleIdTokenIndexer()
        }
        self._entity_token_indexers = self._question_token_indexers
        self._sample = sample
        self._replace_world_entities = replace_world_entities
        self._lf_syntax = lf_syntax
        self._entity_bits_mode = entity_bits_mode
        self._align_world_extractions = align_world_extractions
        self._gold_world_extractions = gold_world_extractions
        self._entity_types = entity_types
        self._tagger_only = tagger_only
        self._denotation_only = denotation_only
        self._skip_attributes_regex = None
        if skip_attributes_regex is not None:
            self._skip_attributes_regex = re.compile(skip_attributes_regex)
        self._lexical_cues = lexical_cues

        # Recording of entities in categories relevant for tagging
        all_entities = {}
        all_entities["world"] = ["world1", "world2"]
        # TODO: Clarify this into an appropriate parameter
        self._collapse_tags = ["world"]

        self._all_entities = None
        if entity_types is not None:
            if self._entity_bits_mode == "collapsed":
                self._all_entities = entity_types
            else:
                self._all_entities = [
                    e for t in entity_types for e in all_entities[t]
                ]

        logger.info(f"all_entities = {self._all_entities}")

        # Base world, depending on LF syntax only
        self._knowledge_graph = KnowledgeGraph(
            entities={"placeholder"},
            neighbors={},
            entity_text={"placeholder": "placeholder"})
        self._world = QuarelWorld(self._knowledge_graph, self._lf_syntax)

        # Decide dynamic entities, if any
        self._dynamic_entities: Dict[str, str] = dict()
        self._use_attr_entities = False
        if "_attr_entities" in lf_syntax:
            self._use_attr_entities = True
            qr_coeff_sets = self._world.qr_coeff_sets
            for qset in qr_coeff_sets:
                for attribute in qset:
                    if (self._skip_attributes_regex is not None
                            and self._skip_attributes_regex.search(attribute)):
                        continue
                    # Get text associated with each entity, both from entity identifier and
                    # associated lexical cues, if any
                    entity_strings = [
                        words_from_entity_string(attribute).lower()
                    ]
                    if self._lexical_cues is not None:
                        for key in self._lexical_cues:
                            if attribute in LEXICAL_CUES[key]:
                                entity_strings += LEXICAL_CUES[key][attribute]
                    self._dynamic_entities["a:" + attribute] = " ".join(
                        entity_strings)

        # Update world to include dynamic entities
        if self._use_attr_entities:
            logger.info(f"dynamic_entities = {self._dynamic_entities}")
            neighbors: Dict[str, List[str]] = {
                key: []
                for key in self._dynamic_entities
            }
            self._knowledge_graph = KnowledgeGraph(
                entities=set(self._dynamic_entities.keys()),
                neighbors=neighbors,
                entity_text=self._dynamic_entities)
            self._world = QuarelWorld(self._knowledge_graph, self._lf_syntax)

        self._stemmer = PorterStemmer().stemmer

        self._world_tagger_extractor = None
        self._extract_worlds = False
        if world_extraction_model is not None:
            logger.info("Loading world tagger model...")
            self._extract_worlds = True
            self._world_tagger_extractor = WorldTaggerExtractor(
                world_extraction_model)
            logger.info("Done loading world tagger model!")

        # Convenience regex for recognizing attributes
        self._attr_regex = re.compile(r"""\((\w+) (high|low|higher|lower)""")
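The attribute regex compiled above only needs to pull an attribute name and its direction keyword out of a logical-form fragment. A minimal standalone check of that behaviour (the fragment below is illustrative, not taken from the dataset) might look like this:

import re

# Same pattern as self._attr_regex above: captures the attribute name and the
# direction keyword from a fragment such as "(friction high world1)".
attr_regex = re.compile(r"\((\w+) (high|low|higher|lower)")

match = attr_regex.search("(friction high world1)")
assert match is not None
assert match.groups() == ("friction", "high")
# Note: because "high" precedes "higher" in the alternation, a fragment
# containing "higher" would also be captured as "high".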
예제 #44
0
 def __init__(self,
              tokenizer: Tokenizer = None,
              token_indexers: Dict[str, TokenIndexer] = None,
              lazy: bool = False,
              max_pieces: int = 512,
              max_count: int = 10,
              max_spans: int = 10,
              max_numbers_expression: int = 2,
              answer_type: List[str] = None,
              use_validated: bool = True,
              wordpiece_numbers: bool = True,
              number_tokenizer: Tokenizer = None,
              custom_word_to_num: bool = True,
              exp_search: str = 'add_sub',
              max_depth: int = 3,
              extra_numbers: List[float] = [],
              question_type: List[str] = None,
              extract_spans: bool = False,
              spans_labels: List[str] = [],
              span_max_length: int = -1):
     super(BertDropReader, self).__init__(lazy)
     self.tokenizer = tokenizer
     self.token_indexers = token_indexers
     self.max_pieces = max_pieces
     self.max_count = max_count
     self.max_spans = max_spans
     self.max_numbers_expression = max_numbers_expression
     self.answer_type = answer_type
     self.use_validated = use_validated
     self.wordpiece_numbers = wordpiece_numbers
     self.number_tokenizer = number_tokenizer or WordTokenizer()
     self.exp_search = exp_search
     self.max_depth = max_depth
     self.extra_numbers = extra_numbers
     self.question_type = question_type
     self.extract_spans = extract_spans
     if self.extract_spans:
         self.span_extractor = SpanExtractor()
         self.spans_labels = spans_labels
         self.span_max_length = span_max_length
     self.op_dict = {
         '+': operator.add,
         '-': operator.sub,
         '*': operator.mul,
         '/': operator.truediv
     }
     self.operations = list(enumerate(self.op_dict.keys()))
     self.templates = [
         lambda x, y, z: (x + y) * z, lambda x, y, z: (x - y) * z,
         lambda x, y, z: (x + y) / z, lambda x, y, z: (x - y) / z,
         lambda x, y, z: x * y / z
     ]
     self.template_strings = [
         '(%s + %s) * %s',
         '(%s - %s) * %s',
         '(%s + %s) / %s',
         '(%s - %s) / %s',
         '%s * %s / %s',
     ]
     if custom_word_to_num:
         self.word_to_num = get_number_from_word
     else:
         self.word_to_num = DropReader.convert_word_to_number
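For reference, `templates` and `template_strings` above are kept index-aligned, so the same index both evaluates a candidate three-number expression and renders it as text. A small standalone sketch of that pairing (with made-up numbers rather than ones extracted from a DROP passage):

# Index-aligned copies of the reader's self.templates / self.template_strings.
templates = [
    lambda x, y, z: (x + y) * z,
    lambda x, y, z: (x - y) * z,
    lambda x, y, z: (x + y) / z,
    lambda x, y, z: (x - y) / z,
    lambda x, y, z: x * y / z,
]
template_strings = [
    '(%s + %s) * %s',
    '(%s - %s) * %s',
    '(%s + %s) / %s',
    '(%s - %s) / %s',
    '%s * %s / %s',
]

numbers = (3.0, 2.0, 5.0)  # hypothetical numbers; in the reader these come from the passage
for template, template_string in zip(templates, template_strings):
    print(template_string % numbers, '=', template(*numbers))
# First line printed: (3.0 + 2.0) * 5.0 = 25.0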
예제 #45
0
    def __init__(self,
                 db: str,
                 sentence_level=False,
                 wiki_tokenizer: Tokenizer = None,
                 claim_tokenizer: Tokenizer = None,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 include_evidence=False,
                 evidence_indices=False,
                 list_field=False,
                 split_evidence_groups=False,
                 include_features=False,
                 include_metadata=False,
                 label_lookup=None,
                 choose_min_evidence=False,
                 lazy: bool = True,
                 batch_size: int = 100,
                 bert_extractor_settings=None,
                 evidence_memory_size=50,
                 max_selected_evidence=5,
                 sentence_ranker_settings=None,
                 prepend_title=True,
                 bert_batch_mode=False,
                 cached_features_size=0,
                 titles_only=False,
                 cuda_device=-1) -> None:

        assert(cached_features_size == 0 or cached_features_size % batch_size == 0)
        
        super().__init__(lazy)

        self._sentence_level = sentence_level
        self._wiki_tokenizer = wiki_tokenizer or WordTokenizer()
        self._claim_tokenizer = claim_tokenizer or WordTokenizer()
        self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}

        self.include_evidence = include_evidence
        self.evidence_indices = evidence_indices        
        self.list_field = list_field
        self.split_evidence_groups = split_evidence_groups

        self.include_features = include_features
        self.include_metadata = include_metadata        
        
        self.label_lookup = label_lookup
        if label_lookup is None:
            self.label_lookup = {'NOT ENOUGH INFO': 0,
                                 'REFUTES': 1,
                                 'SUPPORTS': 2}

        self._choose_min_evidence = choose_min_evidence
        self.db = BatchedDB(db)

        self.sentence_ranker = None
        if sentence_ranker_settings is not None:
            nlp = spacy.load('en')
            self.tokenizer = English().Defaults.create_tokenizer(nlp)
            self.sentence_ranker = SimpleSentenceRanker(**sentence_ranker_settings)
        
        self.bert_feature_extractor = None
        self.bert_batch_mode = False
        if bert_extractor_settings is not None:
            bert_extractor_settings['cuda_device'] = cuda_device
            self.bert_feature_extractor = BertFeatureExtractor(**bert_extractor_settings,
                                                               label_map=self.label_lookup)
            self.bert_batch_mode = bert_batch_mode
            
        self.batch_size = batch_size
        self.evidence_memory_size = evidence_memory_size
        self.max_selected_evidence = max_selected_evidence
        self._prepend_title = prepend_title

        self._read = None
        self._features_cache = collections.defaultdict(dict)
        self._cached_features_size = cached_features_size
        
        self._titles_only = titles_only
class KnowledgeGraphFieldTest(AllenNlpTestCase):
    def setUp(self):
        self.tokenizer = WordTokenizer(SpacyWordSplitter(pos_tags=True))
        self.utterance = self.tokenizer.tokenize("where is mersin?")
        self.token_indexers = {"tokens": SingleIdTokenIndexer("tokens")}

        json = {
            'question': self.utterance,
            'columns': ['Name in English', 'Location in English'],
            'cells': [['Paradeniz', 'Mersin'], ['Lake Gala', 'Edirne']]
        }
        self.graph = TableQuestionKnowledgeGraph.read_from_json(json)
        self.vocab = Vocabulary()
        self.name_index = self.vocab.add_token_to_namespace("name",
                                                            namespace='tokens')
        self.in_index = self.vocab.add_token_to_namespace("in",
                                                          namespace='tokens')
        self.english_index = self.vocab.add_token_to_namespace(
            "english", namespace='tokens')
        self.location_index = self.vocab.add_token_to_namespace(
            "location", namespace='tokens')
        self.paradeniz_index = self.vocab.add_token_to_namespace(
            "paradeniz", namespace='tokens')
        self.mersin_index = self.vocab.add_token_to_namespace(
            "mersin", namespace='tokens')
        self.lake_index = self.vocab.add_token_to_namespace("lake",
                                                            namespace='tokens')
        self.gala_index = self.vocab.add_token_to_namespace("gala",
                                                            namespace='tokens')
        self.negative_one_index = self.vocab.add_token_to_namespace(
            "-1", namespace='tokens')
        self.zero_index = self.vocab.add_token_to_namespace("0",
                                                            namespace='tokens')
        self.one_index = self.vocab.add_token_to_namespace("1",
                                                           namespace='tokens')

        self.oov_index = self.vocab.get_token_index('random OOV string',
                                                    namespace='tokens')
        self.edirne_index = self.oov_index
        self.field = KnowledgeGraphField(self.graph, self.utterance,
                                         self.token_indexers, self.tokenizer)

        super(KnowledgeGraphFieldTest, self).setUp()

    def test_count_vocab_items(self):
        namespace_token_counts = defaultdict(lambda: defaultdict(int))
        self.field.count_vocab_items(namespace_token_counts)

        assert namespace_token_counts["tokens"] == {
            '-1': 1,
            '0': 1,
            '1': 1,
            'name': 1,
            'in': 2,
            'english': 2,
            'location': 1,
            'paradeniz': 1,
            'mersin': 1,
            'lake': 1,
            'gala': 1,
            'edirne': 1,
        }

    def test_index_converts_field_correctly(self):
        # pylint: disable=protected-access
        self.field.index(self.vocab)
        assert self.field._indexed_entity_texts.keys() == {'tokens'}
        # Note that these are sorted by their _identifiers_, not their cell text, so the
        # `fb:row.rows` show up after the `fb:cells`.
        expected_array = [[self.negative_one_index], [self.zero_index],
                          [self.one_index], [self.edirne_index],
                          [self.lake_index, self.gala_index],
                          [self.mersin_index], [self.paradeniz_index],
                          [
                              self.location_index, self.in_index,
                              self.english_index
                          ],
                          [self.name_index, self.in_index, self.english_index]]
        assert self.field._indexed_entity_texts['tokens'] == expected_array

    def test_get_padding_lengths_raises_if_not_indexed(self):
        with pytest.raises(AssertionError):
            self.field.get_padding_lengths()

    def test_padding_lengths_are_computed_correctly(self):
        # pylint: disable=protected-access
        self.field.index(self.vocab)
        assert self.field.get_padding_lengths() == {
            'num_entities': 9,
            'num_entity_tokens': 3,
            'num_utterance_tokens': 4
        }
        self.field._token_indexers[
            'token_characters'] = TokenCharactersIndexer(min_padding_length=1)
        self.field.index(self.vocab)
        assert self.field.get_padding_lengths() == {
            'num_entities': 9,
            'num_entity_tokens': 3,
            'num_utterance_tokens': 4,
            'num_token_characters': 9
        }

    def test_as_tensor_produces_correct_output(self):
        self.field.index(self.vocab)
        padding_lengths = self.field.get_padding_lengths()
        padding_lengths['num_utterance_tokens'] += 1
        padding_lengths['num_entities'] += 1
        tensor_dict = self.field.as_tensor(padding_lengths)
        assert tensor_dict.keys() == {'text', 'linking'}
        expected_text_tensor = [
            [self.negative_one_index, 0, 0], [self.zero_index, 0, 0],
            [self.one_index, 0, 0], [self.edirne_index, 0, 0],
            [self.lake_index, self.gala_index, 0], [self.mersin_index, 0, 0],
            [self.paradeniz_index, 0, 0],
            [self.location_index, self.in_index, self.english_index],
            [self.name_index, self.in_index, self.english_index], [0, 0, 0]
        ]
        assert_almost_equal(
            tensor_dict['text']['tokens'].detach().cpu().numpy(),
            expected_text_tensor)

        linking_tensor = tensor_dict['linking'].detach().cpu().numpy()
        expected_linking_tensor = [
            [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # -1, "where"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # -1, "is"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # -1, "mersin"
             [0, 0, 0, 0, 0, -1, 0, 0, 0, 0]],  # -1, "?"
            [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 0, "where"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 0, "is"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 0, "mersin"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # 0, "?"
            [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 1, "where"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 1, "is"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 1, "mersin"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # 1, "?"
            [[0, 0, 0, 0, 0, .2, 0, 0, 0, 0],  # fb:cell.edirne, "where"
             [0, 0, 0, 0, 0, -1.5, 0, 0, 0, 0],  # fb:cell.edirne, "is"
             [0, 0, 0, 0, 0, .1666, 0, 0, 0, 0],  # fb:cell.edirne, "mersin"
             [0, 0, 0, 0, 0, -5, 0, 0, 0, 0],  # fb:cell.edirne, "?"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:cell.edirne, padding
            [[0, 0, 0, 0, 0, -.6, 0, 0, 0, 0],  # fb:cell.lake_gala, "where"
             [0, 0, 0, 0, 0, -3.5, 0, 0, 0, 0],  # fb:cell.lake_gala, "is"
             [0, 0, 0, 0, 0, -.3333, 0, 0, 0, 0],  # fb:cell.lake_gala, "mersin"
             [0, 0, 0, 0, 0, -8, 0, 0, 0, 0],  # fb:cell.lake_gala, "?"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:cell.lake_gala, padding
            [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # fb:cell.mersin, "where"
             [0, 0, 0, 0, 0, -1.5, 0, 0, 0, 0],  # fb:cell.mersin, "is"
             [0, 1, 1, 1, 1, 1, 0, 0, 1, 1],  # fb:cell.mersin, "mersin"
             [0, 0, 0, 0, 0, -5, 0, 0, 0, 0],  # fb:cell.mersin, "?"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:cell.mersin, padding
            [[0, 0, 0, 0, 0, -.6, 0, 0, 0, 0],  # fb:cell.paradeniz, "where"
             [0, 0, 0, 0, 0, -3, 0, 0, 0, 0],  # fb:cell.paradeniz, "is"
             [0, 0, 0, 0, 0, -.1666, 0, 0, 0, 0],  # fb:cell.paradeniz, "mersin"
             [0, 0, 0, 0, 0, -8, 0, 0, 0, 0],  # fb:cell.paradeniz, "?"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:cell.paradeniz, padding
            [[0, 0, 0, 0, 0, -2.6, 0, 0, 0, 0],  # fb:row.row.name_in_english, "where"
             [0, 0, 0, 0, 0, -7.5, 0, 0, 0, 0],  # fb:row.row.name_in_english, "is"
             [0, 0, 0, 0, 0, -1.8333, 1, 1, 0, 0],  # fb:row.row.name_in_english, "mersin"
             [0, 0, 0, 0, 0, -18, 0, 0, 0, 0],  # fb:row.row.name_in_english, "?"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:row.row.name_in_english, padding
            [[0, 0, 0, 0, 0, -1.6, 0, 0, 0, 0],  # fb:row.row.location_in_english, "where"
             [0, 0, 0, 0, 0, -5.5, 0, 0, 0, 0],  # fb:row.row.location_in_english, "is"
             [0, 0, 0, 0, 0, -1, 0, 0, 0, 0],  # fb:row.row.location_in_english, "mersin"
             [0, 0, 0, 0, 0, -14, 0, 0, 0, 0],  # fb:row.row.location_in_english, "?"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:row.row.location_in_english, padding
            [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "where"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "is"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "mersin"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "?"
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # padding, padding
        ]
        for entity_index, entity_features in enumerate(
                expected_linking_tensor):
            for question_index, feature_vector in enumerate(entity_features):
                assert_almost_equal(linking_tensor[entity_index,
                                                   question_index],
                                    feature_vector,
                                    decimal=4,
                                    err_msg=f"{entity_index} {question_index}")

    def test_lemma_feature_extractor(self):
        # pylint: disable=protected-access
        utterance = self.tokenizer.tokenize("Names in English")
        field = KnowledgeGraphField(self.graph, self.utterance,
                                    self.token_indexers, self.tokenizer)
        entity = 'fb:row.row.name_in_english'
        lemma_feature = field._contains_lemma_match(
            entity, field._entity_text_map[entity], utterance[0], 0, utterance)
        assert lemma_feature == 1

    def test_span_overlap_fraction(self):
        # pylint: disable=protected-access
        utterance = self.tokenizer.tokenize(
            "what is the name in english of mersin?")
        field = KnowledgeGraphField(self.graph, self.utterance,
                                    self.token_indexers, self.tokenizer)
        entity = 'fb:row.row.name_in_english'
        entity_text = field._entity_text_map[entity]
        feature_values = [
            field._span_overlap_fraction(entity, entity_text, token, i,
                                         utterance)
            for i, token in enumerate(utterance)
        ]
        assert feature_values == [0, 0, 0, 1, 1, 1, 0, 0, 0]

    def test_batch_tensors(self):
        self.field.index(self.vocab)
        padding_lengths = self.field.get_padding_lengths()
        tensor_dict1 = self.field.as_tensor(padding_lengths)
        tensor_dict2 = self.field.as_tensor(padding_lengths)
        batched_tensor_dict = self.field.batch_tensors(
            [tensor_dict1, tensor_dict2])
        assert batched_tensor_dict.keys() == {'text', 'linking'}
        expected_single_tensor = [
            [self.negative_one_index, 0, 0], [self.zero_index, 0, 0],
            [self.one_index, 0, 0], [self.edirne_index, 0, 0],
            [self.lake_index, self.gala_index, 0], [self.mersin_index, 0, 0],
            [self.paradeniz_index, 0, 0],
            [self.location_index, self.in_index, self.english_index],
            [self.name_index, self.in_index, self.english_index]
        ]
        expected_batched_tensor = [
            expected_single_tensor, expected_single_tensor
        ]
        assert_almost_equal(
            batched_tensor_dict['text']['tokens'].detach().cpu().numpy(),
            expected_batched_tensor)
        expected_linking_tensor = torch.stack(
            [tensor_dict1['linking'], tensor_dict2['linking']])
        assert_almost_equal(
            batched_tensor_dict['linking'].detach().cpu().numpy(),
            expected_linking_tensor.detach().cpu().numpy())
 def setUp(self):
     super().setUp()
     self.tokenizer = WordTokenizer(SpacyWordSplitter(pos_tags=True))
예제 #48
0
File: trainer.py  Project: BrianTin/knu_ci
    def __init__(self, training=False):
        self.training = training
        config = conf['seq2seq_allen']
        prefix = config['processed_data_prefix']
        train_file = config['train_data']
        valid_file = config['valid_data']
        src_embedding_dim = config['src_embedding_dim']
        hidden_dim = config['hidden_dim']
        batch_size = config['batch_size']
        epoch = config['epoch']
        self.model_path = config['model']

        if torch.cuda.is_available():
            cuda_device = 0
        else:
            cuda_device = -1

        # Define the dataset reader. WordTokenizer handles word-level splitting, and the
        # separate target namespace keeps the output-layer vocab from being mixed with the source vocab.
        self.reader = MySeqDatasetReader(
            source_tokenizer=WordTokenizer(),
            target_tokenizer=WordTokenizer(),
            source_token_indexers={'tokens': SingleIdTokenIndexer()},
            target_token_indexers={
                'tokens': SingleIdTokenIndexer(namespace='target_tokens')
            })

        if training and self.model_path is not None:
            # Read the data from files
            self.train_dataset = self.reader.read(
                os.path.join(prefix, train_file))
            self.valid_dataset = self.reader.read(
                os.path.join(prefix, valid_file))

            # Build the vocabulary
            self.vocab = Vocabulary.from_instances(self.train_dataset +
                                                   self.valid_dataset,
                                                   min_count={
                                                       'tokens': 3,
                                                       'target_tokens': 3
                                                   })
        elif not training:
            try:
                self.vocab = Vocabulary.from_files(self.model_path)
            except Exception as e:
                logger.exception('vocab file does not exist!')

                # Read the data from files
                self.train_dataset = self.reader.read(
                    os.path.join(prefix, train_file))
                self.valid_dataset = self.reader.read(
                    os.path.join(prefix, valid_file))

                # Build the vocabulary
                self.vocab = Vocabulary.from_instances(self.train_dataset +
                                                       self.valid_dataset,
                                                       min_count={
                                                           'tokens': 3,
                                                           'target_tokens': 3
                                                       })

        # Define the embedding layer
        src_embedding = Embedding(
            num_embeddings=self.vocab.get_vocab_size('tokens'),
            embedding_dim=src_embedding_dim)

        # Define the encoder; a bidirectional GRU (BiGRU) is used here
        encoder = PytorchSeq2SeqWrapper(
            torch.nn.GRU(src_embedding_dim,
                         hidden_dim // 2,
                         batch_first=True,
                         bidirectional=True))

        # Define the decoder; a GRU is used here because its input size must match the encoder output
        decoder = PytorchSeq2SeqWrapper(
            torch.nn.GRU(hidden_dim, hidden_dim, batch_first=True))
        # Map indices to embeddings; "tokens" matches the TokenIndexer key used in the data reader
        source_embedder = BasicTextFieldEmbedder({"tokens": src_embedding})

        # Linear attention layer
        attention = LinearAttention(hidden_dim,
                                    hidden_dim,
                                    activation=Activation.by_name('tanh')())

        # Define the model
        self.model = Seq2SeqKnu(vocab=self.vocab,
                                source_embedder=source_embedder,
                                encoder=encoder,
                                target_namespace='target_tokens',
                                decoder=decoder,
                                attention=attention,
                                max_decoding_steps=20,
                                cuda_device=cuda_device)

        # Decide whether to train or to load a trained model
        if training and self.model_path is not None:
            optimizer = optim.Adam(self.model.parameters())
            # sorting_keys determines how instances are sorted when batching
            iterator = BucketIterator(batch_size=batch_size,
                                      sorting_keys=[("source_tokens",
                                                     "num_tokens")])
            # The iterator needs the vocab so it can index the data during training
            iterator.index_with(self.vocab)

            if cuda_device >= 0:
                self.model.cuda(cuda_device)

            # Define the trainer
            self.trainer = Trainer(model=self.model,
                                   optimizer=optimizer,
                                   iterator=iterator,
                                   patience=10,
                                   validation_metric="+accuracy",
                                   train_dataset=self.train_dataset,
                                   validation_dataset=self.valid_dataset,
                                   serialization_dir=self.model_path,
                                   num_epochs=epoch,
                                   cuda_device=cuda_device)
        elif not training:
            with open(os.path.join(self.model_path, 'best.th'), 'rb') as f:
                self.model.load_state_dict(torch.load(f))
            if cuda_device >= 0:
                self.model.cuda(cuda_device)
            self.predictor = MySeqPredictor(self.model,
                                            dataset_reader=self.reader)
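Assuming `conf['seq2seq_allen']` points at valid data and a model directory, the class above would typically be driven roughly as follows. The class name `Seq2SeqRunner` is a stand-in (the snippet only shows `__init__`); `Trainer.train()` is AllenNLP's standard entry point, and the predictor call depends on how `MySeqPredictor` is defined:

# A rough usage sketch under the assumptions stated above.
if __name__ == '__main__':
    # Training mode: builds the datasets, vocabulary, model, and an AllenNLP Trainer.
    runner = Seq2SeqRunner(training=True)  # hypothetical name for the class above
    runner.trainer.train()
    runner.vocab.save_to_files(runner.model_path)  # persist the vocab for later inference

    # Inference mode: loads best.th and wraps the model in MySeqPredictor.
    # runner = Seq2SeqRunner(training=False)
    # print(runner.predictor.predict_json({"source": "..."}))  # input key is hypothetical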
class TestTableQuestionContext(AllenNlpTestCase):
    def setUp(self):
        super().setUp()
        self.tokenizer = WordTokenizer(SpacyWordSplitter(pos_tags=True))

    def test_table_data(self):
        question = "what was the attendance when usl a league played?"
        question_tokens = self.tokenizer.tokenize(question)
        test_file = f'{self.FIXTURES_ROOT}/data/wikitables/sample_table.tagged'
        table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
        assert table_question_context.table_data == [{'date_column:year': '2001',
                                                      'number_column:division': '2',
                                                      'string_column:league': 'usl_a_league',
                                                      'string_column:regular_season': '4th_western',
                                                      'string_column:playoffs': 'quarterfinals',
                                                      'string_column:open_cup': 'did_not_qualify',
                                                      'number_column:avg_attendance': '7_169'},
                                                     {'date_column:year': '2005',
                                                      'number_column:division': '2',
                                                      'string_column:league': 'usl_first_division',
                                                      'string_column:regular_season': '5th',
                                                      'string_column:playoffs': 'quarterfinals',
                                                      'string_column:open_cup': '4th_round',
                                                      'number_column:avg_attendance': '6_028'}]

    def test_number_extraction(self):
        question = """how many players on the 191617 illinois fighting illini men's basketball team
                      had more than 100 points scored?"""
        question_tokens = self.tokenizer.tokenize(question)
        test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-7.table'
        table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
        _, number_entities = table_question_context.get_entities_from_question()
        assert number_entities == [("191617", 5), ("100", 16)]

    def test_date_extraction(self):
        question = "how many laps did matt kenset complete on february 26, 2006."
        question_tokens = self.tokenizer.tokenize(question)
        test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-8.table'
        table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
        _, number_entities = table_question_context.get_entities_from_question()
        assert number_entities == [("2", 8), ("26", 9), ("2006", 11)]

    def test_date_extraction_2(self):
        question = """how many different players scored for the san jose earthquakes during their
                      1979 home opener against the timbers?"""
        question_tokens = self.tokenizer.tokenize(question)
        test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-6.table'
        table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
        _, number_entities = table_question_context.get_entities_from_question()
        assert number_entities == [("1979", 12)]

    def test_multiword_entity_extraction(self):
        question = "was the positioning better the year of the france venue or the year of the south korea venue?"
        question_tokens = self.tokenizer.tokenize(question)
        test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-3.table'
        table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
        entities, _ = table_question_context.get_entities_from_question()
        assert entities == [("string:france", "string_column:venue"),
                            ("string:south_korea", "string_column:venue")]

    def test_rank_number_extraction(self):
        question = "what was the first tamil-language film in 1943?"
        question_tokens = self.tokenizer.tokenize(question)
        test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-1.table'
        table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
        _, numbers = table_question_context.get_entities_from_question()
        assert numbers == [("1", 3), ('1943', 9)]

    def test_null_extraction(self):
        question = "on what date did the eagles score the least points?"
        question_tokens = self.tokenizer.tokenize(question)
        test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-2.table'
        table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
        entities, numbers = table_question_context.get_entities_from_question()
        # "Eagles" does not appear in the table.
        assert entities == []
        assert numbers == []

    def test_numerical_column_type_extraction(self):
        question = """how many players on the 191617 illinois fighting illini men's basketball team
                      had more than 100 points scored?"""
        question_tokens = self.tokenizer.tokenize(question)
        test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-7.table'
        table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
        predicted_types = table_question_context.column_types
        assert predicted_types["games_played"] == "number"
        assert predicted_types["field_goals"] == "number"
        assert predicted_types["free_throws"] == "number"
        assert predicted_types["points"] == "number"

    def test_date_column_type_extraction_1(self):
        question = "how many were elected?"
        question_tokens = self.tokenizer.tokenize(question)
        test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-5.table'
        table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
        predicted_types = table_question_context.column_types
        assert predicted_types["first_elected"] == "date"

    def test_date_column_type_extraction_2(self):
        question = "how many were elected?"
        question_tokens = self.tokenizer.tokenize(question)
        test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-9.table'
        table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
        predicted_types = table_question_context.column_types
        assert predicted_types["date_of_appointment"] == "date"
        assert predicted_types["date_of_election"] == "date"

    def test_string_column_types_extraction(self):
        question = "how many were elected?"
        question_tokens = self.tokenizer.tokenize(question)
        test_file = f'{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-10.table'
        table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
        predicted_types = table_question_context.column_types
        assert predicted_types["birthplace"] == "string"
        assert predicted_types["advocate"] == "string"
        assert predicted_types["notability"] == "string"
        assert predicted_types["name"] == "string"

    def test_number_and_entity_extraction(self):
        question = "other than m1 how many notations have 1 in them?"
        question_tokens = self.tokenizer.tokenize(question)
        test_file = f"{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-11.table"
        table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
        string_entities, number_entities = table_question_context.get_entities_from_question()
        assert string_entities == [("string:m1", "string_column:notation")]
        assert number_entities == [("1", 2), ("1", 7)]

    def test_get_knowledge_graph(self):
        question = "other than m1 how many notations have 1 in them?"
        question_tokens = self.tokenizer.tokenize(question)
        test_file = f"{self.FIXTURES_ROOT}/data/corenlp_processed_tables/TEST-11.table"
        table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
        knowledge_graph = table_question_context.get_table_knowledge_graph()
        entities = knowledge_graph.entities
        # -1 is not in entities because there are no date columns in the table.
        assert sorted(entities) == ['1', 'number_column:position', 'string:m1',
                                    'string_column:mnemonic', 'string_column:notation',
                                    'string_column:short_name', 'string_column:swara']
        neighbors = knowledge_graph.neighbors
        # Each number extracted from the question will have all number and date columns as
        # neighbors. Each string entity extracted from the question will only have the corresponding
        # column as the neighbor.
        assert neighbors == {'1': ['number_column:position'],
                             'string_column:mnemonic': [],
                             'string_column:short_name': [],
                             'string_column:swara': [],
                             'number_column:position': ['1'],
                             'string:m1': ['string_column:notation'],
                             'string_column:notation': ['string:m1']}
        entity_text = knowledge_graph.entity_text
        assert entity_text == {'1': '1',
                               'string:m1': 'm1',
                               'string_column:notation': 'notation',
                               'string_column:mnemonic': 'mnemonic',
                               'string_column:short_name': 'short name',
                               'string_column:swara': 'swara',
                               'number_column:position': 'position'}


    def test_knowledge_graph_has_correct_neighbors(self):
        question = "when was the attendance greater than 5000?"
        question_tokens = self.tokenizer.tokenize(question)
        test_file = f'{self.FIXTURES_ROOT}/data/wikitables/sample_table.tagged'
        table_question_context = TableQuestionContext.read_from_file(test_file, question_tokens)
        knowledge_graph = table_question_context.get_table_knowledge_graph()
        neighbors = knowledge_graph.neighbors
        # '5000' is neighbors with number and date columns. '-1' is in entities because there is a
        # date column, which is its only neighbor.
        assert set(neighbors.keys()) == {'date_column:year', 'number_column:division',
                                         'string_column:league', 'string_column:regular_season',
                                         'string_column:playoffs', 'string_column:open_cup',
                                         'number_column:avg_attendance', '5000', '-1'}
        assert set(neighbors['date_column:year']) == {'5000', '-1'}
        assert neighbors['number_column:division'] == ['5000']
        assert neighbors['string_column:league'] == []
        assert neighbors['string_column:regular_season'] == []
        assert neighbors['string_column:playoffs'] == []
        assert neighbors['string_column:open_cup'] == []
        assert neighbors['number_column:avg_attendance'] == ['5000']
        assert set(neighbors['5000']) == {'date_column:year', 'number_column:division',
                                          'number_column:avg_attendance'}
        assert neighbors['-1'] == ['date_column:year']
class KnowledgeGraphFieldTest(AllenNlpTestCase):
    def setUp(self):
        self.tokenizer = WordTokenizer(SpacyWordSplitter(pos_tags=True))
        self.utterance = self.tokenizer.tokenize("where is mersin?")
        self.token_indexers = {"tokens": SingleIdTokenIndexer("tokens")}

        json = {
                'question': self.utterance,
                'columns': ['Name in English', 'Location in English'],
                'cells': [['Paradeniz', 'Mersin'],
                          ['Lake Gala', 'Edirne']]
                }
        self.graph = TableQuestionKnowledgeGraph.read_from_json(json)
        self.vocab = Vocabulary()
        self.name_index = self.vocab.add_token_to_namespace("name", namespace='tokens')
        self.in_index = self.vocab.add_token_to_namespace("in", namespace='tokens')
        self.english_index = self.vocab.add_token_to_namespace("english", namespace='tokens')
        self.location_index = self.vocab.add_token_to_namespace("location", namespace='tokens')
        self.paradeniz_index = self.vocab.add_token_to_namespace("paradeniz", namespace='tokens')
        self.mersin_index = self.vocab.add_token_to_namespace("mersin", namespace='tokens')
        self.lake_index = self.vocab.add_token_to_namespace("lake", namespace='tokens')
        self.gala_index = self.vocab.add_token_to_namespace("gala", namespace='tokens')
        self.negative_one_index = self.vocab.add_token_to_namespace("-1", namespace='tokens')
        self.zero_index = self.vocab.add_token_to_namespace("0", namespace='tokens')
        self.one_index = self.vocab.add_token_to_namespace("1", namespace='tokens')

        self.oov_index = self.vocab.get_token_index('random OOV string', namespace='tokens')
        self.edirne_index = self.oov_index
        self.field = KnowledgeGraphField(self.graph, self.utterance, self.token_indexers, self.tokenizer)

        super(KnowledgeGraphFieldTest, self).setUp()

    def test_count_vocab_items(self):
        namespace_token_counts = defaultdict(lambda: defaultdict(int))
        self.field.count_vocab_items(namespace_token_counts)

        assert namespace_token_counts["tokens"] == {
                '-1': 1,
                '0': 1,
                '1': 1,
                'name': 1,
                'in': 2,
                'english': 2,
                'location': 1,
                'paradeniz': 1,
                'mersin': 1,
                'lake': 1,
                'gala': 1,
                'edirne': 1,
                }

    def test_index_converts_field_correctly(self):
        # pylint: disable=protected-access
        self.field.index(self.vocab)
        assert self.field._indexed_entity_texts.keys() == {'tokens'}
        # Note that these are sorted by their _identifiers_, not their cell text, so the
        # `fb:row.rows` show up after the `fb:cells`.
        expected_array = [[self.negative_one_index],
                          [self.zero_index],
                          [self.one_index],
                          [self.edirne_index],
                          [self.lake_index, self.gala_index],
                          [self.mersin_index],
                          [self.paradeniz_index],
                          [self.location_index, self.in_index, self.english_index],
                          [self.name_index, self.in_index, self.english_index]]
        assert self.field._indexed_entity_texts['tokens'] == expected_array

    def test_get_padding_lengths_raises_if_not_indexed(self):
        with pytest.raises(AssertionError):
            self.field.get_padding_lengths()

    def test_padding_lengths_are_computed_correctly(self):
        # pylint: disable=protected-access
        self.field.index(self.vocab)
        assert self.field.get_padding_lengths() == {'num_entities': 9, 'num_entity_tokens': 3,
                                                    'num_utterance_tokens': 4}
        self.field._token_indexers['token_characters'] = TokenCharactersIndexer()
        self.field.index(self.vocab)
        assert self.field.get_padding_lengths() == {'num_entities': 9, 'num_entity_tokens': 3,
                                                    'num_utterance_tokens': 4,
                                                    'num_token_characters': 9}

    def test_as_tensor_produces_correct_output(self):
        self.field.index(self.vocab)
        padding_lengths = self.field.get_padding_lengths()
        padding_lengths['num_utterance_tokens'] += 1
        padding_lengths['num_entities'] += 1
        tensor_dict = self.field.as_tensor(padding_lengths)
        assert tensor_dict.keys() == {'text', 'linking'}
        expected_text_tensor = [[self.negative_one_index, 0, 0],
                                [self.zero_index, 0, 0],
                                [self.one_index, 0, 0],
                                [self.edirne_index, 0, 0],
                                [self.lake_index, self.gala_index, 0],
                                [self.mersin_index, 0, 0],
                                [self.paradeniz_index, 0, 0],
                                [self.location_index, self.in_index, self.english_index],
                                [self.name_index, self.in_index, self.english_index],
                                [0, 0, 0]]
        assert_almost_equal(tensor_dict['text']['tokens'].detach().cpu().numpy(), expected_text_tensor)

        linking_tensor = tensor_dict['linking'].detach().cpu().numpy()
        expected_linking_tensor = [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # -1, "where"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # -1, "is"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # -1, "mersin"
                                    [0, 0, 0, 0, 0, -1, 0, 0, 0, 0]],  # -1, "?"
                                   [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 0, "where"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 0, "is"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 0, "mersin"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # 0, "?"
                                   [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 1, "where"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 1, "is"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 1, "mersin"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # 1, "?"
                                   [[0, 0, 0, 0, 0, .2, 0, 0, 0, 0],  # fb:cell.edirne, "where"
                                    [0, 0, 0, 0, 0, -1.5, 0, 0, 0, 0],  # fb:cell.edirne, "is"
                                    [0, 0, 0, 0, 0, .1666, 0, 0, 0, 0],  # fb:cell.edirne, "mersin"
                                    [0, 0, 0, 0, 0, -5, 0, 0, 0, 0],  # fb:cell.edirne, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:cell.edirne, padding
                                   [[0, 0, 0, 0, 0, -.6, 0, 0, 0, 0],  # fb:cell.lake_gala, "where"
                                    [0, 0, 0, 0, 0, -3.5, 0, 0, 0, 0],  # fb:cell.lake_gala, "is"
                                    [0, 0, 0, 0, 0, -.3333, 0, 0, 0, 0],  # fb:cell.lake_gala, "mersin"
                                    [0, 0, 0, 0, 0, -8, 0, 0, 0, 0],  # fb:cell.lake_gala, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:cell.lake_gala, padding
                                   [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # fb:cell.mersin, "where"
                                    [0, 0, 0, 0, 0, -1.5, 0, 0, 0, 0],  # fb:cell.mersin, "is"
                                    [0, 1, 1, 1, 1, 1, 0, 0, 1, 1],  # fb:cell.mersin, "mersin"
                                    [0, 0, 0, 0, 0, -5, 0, 0, 0, 0],  # fb:cell.mersin, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:cell.mersin, padding
                                   [[0, 0, 0, 0, 0, -.6, 0, 0, 0, 0],  # fb:cell.paradeniz, "where"
                                    [0, 0, 0, 0, 0, -3, 0, 0, 0, 0],  # fb:cell.paradeniz, "is"
                                    [0, 0, 0, 0, 0, -.1666, 0, 0, 0, 0],  # fb:cell.paradeniz, "mersin"
                                    [0, 0, 0, 0, 0, -8, 0, 0, 0, 0],  # fb:cell.paradeniz, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:cell.paradeniz, padding
                                   [[0, 0, 0, 0, 0, -2.6, 0, 0, 0, 0],  # fb:row.row.name_in_english, "where"
                                    [0, 0, 0, 0, 0, -7.5, 0, 0, 0, 0],  # fb:row.row.name_in_english, "is"
                                    [0, 0, 0, 0, 0, -1.8333, 1, 1, 0, 0],  # fb:row.row.name_in_english, "mersin"
                                    [0, 0, 0, 0, 0, -18, 0, 0, 0, 0],  # fb:row.row.name_in_english, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:row.row.name_in_english, padding
                                   [[0, 0, 0, 0, 0, -1.6, 0, 0, 0, 0],  # fb:row.row.location_in_english, "where"
                                    [0, 0, 0, 0, 0, -5.5, 0, 0, 0, 0],  # fb:row.row.location_in_english, "is"
                                    [0, 0, 0, 0, 0, -1, 0, 0, 0, 0],  # fb:row.row.location_in_english, "mersin"
                                    [0, 0, 0, 0, 0, -14, 0, 0, 0, 0],  # fb:row.row.location_in_english, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:row.row.location_in_english, padding
                                   [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "where"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "is"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "mersin"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]  # padding, padding
        for entity_index, entity_features in enumerate(expected_linking_tensor):
            for question_index, feature_vector in enumerate(entity_features):
                assert_almost_equal(linking_tensor[entity_index, question_index],
                                    feature_vector,
                                    decimal=4,
                                    err_msg=f"{entity_index} {question_index}")

    def test_lemma_feature_extractor(self):
        # pylint: disable=protected-access
        utterance = self.tokenizer.tokenize("Names in English")
        field = KnowledgeGraphField(self.graph, self.utterance, self.token_indexers, self.tokenizer)
        entity = 'fb:row.row.name_in_english'
        lemma_feature = field._contains_lemma_match(entity,
                                                    field._entity_text_map[entity],
                                                    utterance[0],
                                                    0,
                                                    utterance)
        assert lemma_feature == 1

    def test_span_overlap_fraction(self):
        # pylint: disable=protected-access
        utterance = self.tokenizer.tokenize("what is the name in english of mersin?")
        field = KnowledgeGraphField(self.graph, self.utterance, self.token_indexers, self.tokenizer)
        entity = 'fb:row.row.name_in_english'
        entity_text = field._entity_text_map[entity]
        feature_values = [field._span_overlap_fraction(entity, entity_text, token, i, utterance)
                          for i, token in enumerate(utterance)]
        assert feature_values == [0, 0, 0, 1, 1, 1, 0, 0, 0]

    def test_batch_tensors(self):
        self.field.index(self.vocab)
        padding_lengths = self.field.get_padding_lengths()
        tensor_dict1 = self.field.as_tensor(padding_lengths)
        tensor_dict2 = self.field.as_tensor(padding_lengths)
        batched_tensor_dict = self.field.batch_tensors([tensor_dict1, tensor_dict2])
        assert batched_tensor_dict.keys() == {'text', 'linking'}
        expected_single_tensor = [[self.negative_one_index, 0, 0],
                                  [self.zero_index, 0, 0],
                                  [self.one_index, 0, 0],
                                  [self.edirne_index, 0, 0],
                                  [self.lake_index, self.gala_index, 0],
                                  [self.mersin_index, 0, 0],
                                  [self.paradeniz_index, 0, 0],
                                  [self.location_index, self.in_index, self.english_index],
                                  [self.name_index, self.in_index, self.english_index]]
        expected_batched_tensor = [expected_single_tensor, expected_single_tensor]
        assert_almost_equal(batched_tensor_dict['text']['tokens'].detach().cpu().numpy(),
                            expected_batched_tensor)
        expected_linking_tensor = torch.stack([tensor_dict1['linking'], tensor_dict2['linking']])
        assert_almost_equal(batched_tensor_dict['linking'].detach().cpu().numpy(),
                            expected_linking_tensor.detach().cpu().numpy())
class OpenIePredictor(Predictor):
    """
    Predictor for the :class:`models.SemanticRoleLabeler` model (in its Open Information variant).
    Used by the online demo and for prediction on an input file via the command line.
    """
    def __init__(self,
                 model: Model,
                 dataset_reader: DatasetReader) -> None:
        super().__init__(model, dataset_reader)
        self._tokenizer = WordTokenizer(word_splitter=SpacyWordSplitter(pos_tags=True))


    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """
        Expects JSON that looks like ``{"sentence": "...", "predicate_index": "..."}``.
        Assumes the sentence is already tokenized and that predicate_index points to a specific
        predicate (word index) within the sentence, for which to produce Open IE extractions.
        """
        tokens = json_dict["sentence"]
        predicate_index = int(json_dict["predicate_index"])
        verb_labels = [0 for _ in tokens]
        verb_labels[predicate_index] = 1
        return self._dataset_reader.text_to_instance(tokens, verb_labels)

    @overrides
    def predict_json(self, inputs: JsonDict) -> JsonDict:
        """
        Create one instance per detected predicate. One sentence containing multiple verbs
        will lead to multiple instances.

        Expects JSON that looks like ``{"sentence": "..."}``

        Returns a JSON that looks like

        .. code-block:: js

            {"tokens": [...],
             "tag_spans": [{"ARG0": "...",
                            "V": "...",
                            "ARG1": "...",
                             ...}]}
        """
        sent_tokens = self._tokenizer.tokenize(inputs["sentence"])

        # Find all verbs in the input sentence
        pred_ids = [i for (i, t)
                    in enumerate(sent_tokens)
                    if t.pos_ == "VERB"]

        # Create instances
        instances = [self._json_to_instance({"sentence": sent_tokens,
                                             "predicate_index": pred_id})
                     for pred_id in pred_ids]

        # Run model
        outputs = [[sanitize_label(label) for label in self._model.forward_on_instance(instance)["tags"]]
                   for instance in instances]

        # Consolidate predictions
        pred_dict = consolidate_predictions(outputs, sent_tokens)

        # Build and return output dictionary
        results = {"verbs": [], "words": sent_tokens}

        for tags in pred_dict.values():
            # Join multi-word predicates
            tags = join_mwp(tags)

            # Create description text
            description = make_oie_string(sent_tokens, tags)

            # Add a predicate prediction to the return dictionary.
            results["verbs"].append({
                    "verb": get_predicate_text(sent_tokens, tags),
                    "description": description,
                    "tags": tags,
            })

        return sanitize(results)
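Given the docstring above, a minimal driver for this predictor might look like the following. The archive path is a placeholder, and the registered name "open-information-extraction" is assumed to be how this predictor is exposed through `Predictor.from_path`:

from allennlp.predictors.predictor import Predictor

# A usage sketch; the archive path is a placeholder and the predictor name is an assumption.
predictor = Predictor.from_path("/path/to/openie-model.tar.gz",
                                predictor_name="open-information-extraction")

result = predictor.predict_json({"sentence": "John gave the book to Mary."})
for verb in result["verbs"]:
    print(verb["verb"], "->", verb["description"])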
 def __init__(self,
              model: Model,
              dataset_reader: DatasetReader) -> None:
     super().__init__(model, dataset_reader)
     self._tokenizer = WordTokenizer(word_splitter=SpacyWordSplitter(pos_tags=True))