Example #1
    def get_line_ents(self, lines: 'List[str]'):
        """ Extracts list of entities from lines retrieved from database
        """
        self.create_lines_file(lines)  # To accomodate dataset generation
        dataset = NERDataset('./lines.txt', self.char_vocab, self.tag_vocab)
        data_iterator = DataIterator(dataset)
        line_tokens, decoded_tags = [], []

        for batch in data_iterator:
            tag_ids = self.ner_model(batch)
            batch_decoded_tags = []
            for line_tag_ids in tag_ids:
                line_decoded_tags = []
                for decoded_id in line_tag_ids:
                    line_decoded_tags.append(
                        self.tag_vocab.id_to_token_map_py[decoded_id + 3])
                batch_decoded_tags.append(line_decoded_tags)
            line_tokens += batch['text']
            decoded_tags += batch_decoded_tags

        # Ensure decoded tags are of same sequence length as tokens
        tagged_lines = zip(line_tokens, decoded_tags)
        tagged_lines = [(tokens, tags[:len(tokens)])
                        for tokens, tags in tagged_lines]

        entities = []
        for tagged_line in tagged_lines:
            entities += get_entities(tagged_line)
        return entities
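
The `get_entities` helper used above is external to this snippet, so its actual implementation is not shown. As a hedged sketch of what it might do for BIO-style tags (the tag scheme and span format here are assumptions), one plausible version:

def get_entities(tagged_line):
    """Hypothetical sketch: collect (entity_text, entity_type) spans
    from a (tokens, BIO-tags) pair."""
    tokens, tags = tagged_line
    entities, span, span_type = [], [], None
    for token, tag in zip(tokens, tags):
        if tag.startswith("B-"):
            if span:
                entities.append((" ".join(span), span_type))
            span, span_type = [token], tag[2:]
        elif tag.startswith("I-") and span and tag[2:] == span_type:
            span.append(token)
        else:  # "O" or an inconsistent tag closes any open span
            if span:
                entities.append((" ".join(span), span_type))
            span, span_type = [], None
    if span:
        entities.append((" ".join(span), span_type))
    return entities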
Example #2
    def _run_and_test(self, hparams, test_transform=False):
        # Construct dataset
        scalar_data = ScalarData(hparams)

        self.assertEqual(scalar_data.list_items()[0],
                         hparams["dataset"]["data_name"])

        iterator = DataIterator(scalar_data)

        i = 0
        for batch in iterator:
            self.assertEqual(set(batch.keys()),
                             set(scalar_data.list_items()))
            value = batch[scalar_data.data_name][0]
            if test_transform:
                self.assertEqual(2 * i, value)
            else:
                self.assertEqual(i, value)
            i += 1
            data_type = hparams["dataset"]["data_type"]
            if data_type == "int":
                self.assertEqual(value.dtype, torch.int32)
            elif data_type == "float":
                self.assertEqual(value.dtype, torch.float32)
            elif data_type == "bool":
                self.assertEqual(value.dtype, torch.bool)
            self.assertIsInstance(value, torch.Tensor)
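
For reference, the hparams handed to this test presumably follow ScalarData's layout, with the scalar-specific options nested under a "dataset" key. A minimal sketch of such a configuration, assuming the usual files/data_type/data_name fields and a placeholder file path (data_name "label" matches the access in Example #3):

int_hparams = {
    "batch_size": 1,
    "shuffle": False,
    "dataset": {
        "files": "./scalar_int.txt",  # placeholder path; one scalar per line
        "data_type": "int",
        "data_name": "label",
    },
}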
Example #3
    def test_shuffle(self):
        """Tests results of toggling shuffle.
        """
        hparams = copy.copy(self._int_hparams)
        hparams["batch_size"] = 10
        scalar_data = ScalarData(hparams)
        iterator = DataIterator(scalar_data)

        hparams_sfl = copy.copy(hparams)
        hparams_sfl["shuffle"] = True
        scalar_data_sfl = ScalarData(hparams_sfl)
        iterator_sfl = DataIterator(scalar_data_sfl)

        vals = []
        vals_sfl = []

        for batch, batch_sfl in zip(iterator, iterator_sfl):
            vals += batch["label"].tolist()
            vals_sfl += batch_sfl["label"].tolist()

        self.assertEqual(len(vals), len(vals_sfl))
        self.assertSetEqual(set(vals), set(vals_sfl))
Example #4
    def _build_dataset_iterator(self) \
            -> DataIterator:
        scope: Type[EntryType] = self._request["scope"]  # type: ignore
        schemes: Dict[str, Dict[str, Any]] = self._request["schemes"]

        data_source = DataPackIterator(pack_iterator=iter(self._cached_packs),
                                       context_type=scope,
                                       request={scope: []})

        dataset = DataPackDataset(data_source, schemes, self._config.dataset,
                                  self.device)
        iterator = DataIterator(dataset)

        return iterator
Example #5
    def predict_lines(self, cur, lines: 'List[Tuple]'):
        """ Updates dl_prediction for lines retrieved from database
        - Processes into batches given lines
        """
        self.create_lines_file(lines)

        data = Dataset('./lines.txt',
                       self.vocab,
                       self.label_vocab,
                       self.tag_vocab,
                       hparams={'batch_size': 32})
        data_iterator = DataIterator(data)
        for batch in data_iterator:  # Batch predictions
            lines_logits = self.line_classifier(batch)
            for i, logits in enumerate(lines_logits):
                page_id = batch.pg_ids[i]
                prediction = self.get_label(logits)
                cur.execute("UPDATE PageLines SET dl_prediction=? WHERE id=?",
                            (prediction, page_id))
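
`get_label` is defined elsewhere in the class; it presumably maps the classifier logits for one line to a single label. A hedged sketch of such a mapping, written as a free function for self-containedness and assuming the label vocabulary exposes the same `id_to_token_map_py` attribute seen in Example #1:

import torch

def get_label(logits, label_vocab):
    """Hypothetical sketch: pick the highest-scoring class and map its
    id back to a label string via the vocabulary."""
    label_id = int(torch.argmax(logits))
    return label_vocab.id_to_token_map_py[label_id]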
Example #6
def main():

    with open("config.yml", "r") as config_file:
        config = yaml.safe_load(config_file)
    config = HParams(config, default_hparams=None)

    if not os.path.exists(config.indexer.model_dir):
        print(f"Creating a new index...")
        encoder = BERTEncoder(pretrained_model_name="bert-base-uncased")
        encoder.to(device)

        feature_original_types = {
            "id": ["int64", "FixedLenFeature"],
            "input_ids": ["int64", "FixedLenFeature",
                          config.indexer.max_seq_length],
            "segment_ids": ["int64", "FixedLenFeature",
                            config.indexer.max_seq_length],
            "text": ["str", "FixedLenFeature"]
        }

        hparam = {
            "allow_smaller_final_batch": True,
            "batch_size": config.indexer.batch_size,
            "dataset": {
                "data_name": "data",
                "feature_original_types": feature_original_types,
                "files": config.indexer.pickle_data_dir
            },
            "shuffle": False
        }

        print(f"Embedding the text using BERTEncoder...")
        record_data = RecordData(hparams=hparam, device=device)
        data_iterator = DataIterator(record_data)
        index = EmbeddingBasedIndexer(hparams={
            "index_type": "GpuIndexFlatIP",
            "dim": 768,
            "device": "gpu0"
        })

        for idx, batch in enumerate(data_iterator):
            ids = batch["id"]
            input_ids = batch["input_ids"]
            segment_ids = batch["segment_ids"]
            text = batch["text"]
            _, pooled_output = get_embeddings(encoder, input_ids, segment_ids)
            index.add(vectors=pooled_output,
                      meta_data={k.item(): v for k, v in zip(ids, text)})

            if (idx + 1) % 50 == 0:
                print(f"Completed {idx+1} batches of size "
                      f"{config.indexer.batch_size}")

        index.save(path=config.indexer.model_dir)

    resource = Resources()
    query_pipeline = Pipeline(resource=resource)
    query_pipeline.set_reader(MultiPackTerminalReader())

    query_pipeline.add_processor(
        processor=MachineTranslationProcessor(), config=config.translator)
    query_pipeline.add_processor(
        processor=QueryCreator(), config=config.query_creator)
    query_pipeline.add_processor(
        processor=SearchProcessor(), config=config.indexer)
    query_pipeline.add_processor(
        processor=NLTKSentenceSegmenter(),
        selector=NameMatchSelector(select_name="doc_0"))
    query_pipeline.add_processor(
        processor=NLTKWordTokenizer(),
        selector=NameMatchSelector(select_name="doc_0"))
    query_pipeline.add_processor(
        processor=NLTKPOSTagger(),
        selector=NameMatchSelector(select_name="doc_0"))
    query_pipeline.add_processor(
        processor=SRLPredictor(), config=config.SRL,
        selector=NameMatchSelector(select_name="doc_0"))
    # query_pipeline.add_processor(
    #    processor=CoNLLNERPredictor(), config=config.NER,
    #    selector=NameMatchSelector(select_name="doc_0"))
    query_pipeline.add_processor(
        processor=MachineTranslationProcessor(), config=config.back_translator)

    query_pipeline.initialize()

    for m_pack in query_pipeline.process_dataset():

        # update resource to be used in the next conversation
        query_pack = m_pack.get_pack("query")
        if resource.get("user_utterance"):
            resource.get("user_utterance").append(query_pack)
        else:
            resource.update(user_utterance=[query_pack])

        response_pack = m_pack.get_pack("response")

        if resource.get("bot_utterance"):
            resource.get("bot_utterance").append(response_pack)
        else:
            resource.update(bot_utterance=[response_pack])

        english_pack = m_pack.get_pack("pack")
        print(colored("English Translation of the query: ", "green"),
              english_pack.text, "\n")
        pack = m_pack.get_pack("doc_0")
        print(colored("Retrieved Document", "green"), pack.text, "\n")
        print(colored("German Translation", "green"),
              m_pack.get_pack("response").text, "\n")
        for sentence in pack.get(Sentence):
            sent_text = sentence.text
            print(colored("Sentence:", 'red'), sent_text, "\n")

            print(colored("Semantic role labels:", 'red'))
            for link in pack.get(PredicateLink, sentence):
                parent = link.get_parent()
                child = link.get_child()
                print(f"  - \"{child.text}\" is role {link.arg_type} of "
                      f"predicate \"{parent.text}\"")
            print()

            input(colored("Press ENTER to continue...\n", 'green'))
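
The `get_embeddings` helper called in the indexing loop of Example #6 is defined elsewhere. As a hedged sketch of what it might do, assuming the texar-pytorch BERTEncoder call convention of accepting `inputs` and `segment_ids` and returning a (sequence_output, pooled_output) pair:

import torch

def get_embeddings(encoder, input_ids, segment_ids):
    """Hypothetical sketch: run the BERT encoder over a batch and return
    its per-token outputs and pooled sentence embedding."""
    with torch.no_grad():
        output, pooled_output = encoder(inputs=input_ids,
                                        segment_ids=segment_ids)
    return output, pooled_output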