def get_line_ents(self, lines: 'List[str]'):
    """Extracts a list of entities from lines retrieved from the database."""
    # Write the lines to a file to accommodate dataset generation.
    self.create_lines_file(lines)
    dataset = NERDataset('./lines.txt', self.char_vocab, self.tag_vocab)
    data_iterator = DataIterator(dataset)

    line_tokens, decoded_tags = [], []
    for batch in data_iterator:
        line_tokens += batch['text']
        tag_ids = self.ner_model(batch)
        batch_decoded_tags = []
        for line_tag_ids in tag_ids:
            line_decoded_tags = []
            for decoded_id in line_tag_ids:
                # The model's tag ids are offset by 3 relative to the
                # vocabulary ids, so shift before mapping back to tokens.
                line_decoded_tags.append(
                    self.tag_vocab.id_to_token_map_py[decoded_id + 3])
            batch_decoded_tags.append(line_decoded_tags)
        decoded_tags += batch_decoded_tags

    # Ensure decoded tags have the same sequence length as their tokens.
    tagged_lines = zip(line_tokens, decoded_tags)
    tagged_lines = [(tokens, tags[:len(tokens)])
                    for tokens, tags in tagged_lines]

    entities = []
    for tagged_line in tagged_lines:
        entities += get_entities(tagged_line)
    return entities
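# A minimal usage sketch for `get_line_ents`, assuming a hypothetical
# `extractor` object that owns `create_lines_file`, `char_vocab`,
# `tag_vocab`, and `ner_model` as used above; the sample sentences are
# illustrative only.
def print_line_entities(extractor, sample_lines=None):
    """Run NER over a few sample lines and print each extracted entity."""
    if sample_lines is None:
        sample_lines = [
            "John Smith visited Paris in 2019.",
            "Acme Corp. announced a merger on Monday.",
        ]
    for entity in extractor.get_line_ents(sample_lines):
        print(entity)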
def _run_and_test(self, hparams, test_transform=False):
    # Construct the dataset and check that it exposes the expected item.
    scalar_data = ScalarData(hparams)
    self.assertEqual(scalar_data.list_items()[0],
                     hparams["dataset"]["data_name"])

    iterator = DataIterator(scalar_data)

    i = 0
    for batch in iterator:
        self.assertEqual(set(batch.keys()),
                         set(scalar_data.list_items()))
        value = batch[scalar_data.data_name][0]
        if test_transform:
            self.assertEqual(2 * i, value)
        else:
            self.assertEqual(i, value)
        i += 1

        data_type = hparams["dataset"]["data_type"]
        if data_type == "int":
            self.assertEqual(value.dtype, torch.int32)
        elif data_type == "float":
            self.assertEqual(value.dtype, torch.float32)
        elif data_type == "bool":
            self.assertEqual(value.dtype, torch.bool)
        self.assertIsInstance(value, torch.Tensor)
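# A sketch of the hparams such a helper might receive, assuming a
# ScalarData configuration in the style of texar-pytorch, where scalar
# values are read one per line from a text file. The file path and values
# below are placeholders, not part of the test above.
_int_hparams_example = {
    "num_epochs": 1,
    "batch_size": 1,
    "shuffle": False,
    "dataset": {
        "files": "./scalar_int.txt",  # e.g. a file containing "0\n1\n2\n..."
        "data_type": "int",
        "data_name": "label",
    },
}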
def test_shuffle(self):
    """Tests results of toggling shuffle."""
    hparams = copy.copy(self._int_hparams)
    hparams["batch_size"] = 10
    scalar_data = ScalarData(hparams)
    iterator = DataIterator(scalar_data)

    hparams_sfl = copy.copy(hparams)
    hparams_sfl["shuffle"] = True
    scalar_data_sfl = ScalarData(hparams_sfl)
    iterator_sfl = DataIterator(scalar_data_sfl)

    vals = []
    vals_sfl = []
    for batch, batch_sfl in zip(iterator, iterator_sfl):
        vals += batch["label"].tolist()
        vals_sfl += batch_sfl["label"].tolist()

    # Shuffling changes the order of values but not the set of values.
    self.assertEqual(len(vals), len(vals_sfl))
    self.assertSetEqual(set(vals), set(vals_sfl))
def _build_dataset_iterator(self) -> DataIterator:
    scope: Type[EntryType] = self._request["scope"]  # type: ignore
    schemes: Dict[str, Dict[str, Any]] = self._request["schemes"]

    data_source = DataPackIterator(
        pack_iterator=iter(self._cached_packs),
        context_type=scope,
        request={scope: []})
    dataset = DataPackDataset(data_source, schemes,
                              self._config.dataset, self.device)
    iterator = DataIterator(dataset)
    return iterator
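# A hedged consumption sketch for the iterator built above, written as a
# method that would live on the same trainer class. The `self.model` and
# `self.optim` attributes and the scalar-loss convention are assumptions
# for illustration, not part of the method above.
def _train_epoch_sketch(self):
    iterator = self._build_dataset_iterator()
    for batch in iterator:
        self.optim.zero_grad()
        loss = self.model(batch)  # assumed to return a scalar loss tensor
        loss.backward()
        self.optim.step()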
def predict_lines(self, cur, lines: 'List[Tuple]'):
    """Updates `dl_prediction` for lines retrieved from the database,
    processing the given lines in batches.
    """
    self.create_lines_file(lines)
    data = Dataset('./lines.txt', self.vocab, self.label_vocab,
                   self.tag_vocab, hparams={'batch_size': 32})
    data_iterator = DataIterator(data)

    for batch in data_iterator:
        # Batch predictions
        lines_logits = self.line_classifier(batch)
        for i, logits in enumerate(lines_logits):
            page_id = batch.pg_ids[i]
            prediction = self.get_label(logits)
            cur.execute(
                "UPDATE PageLines SET dl_prediction=? WHERE id=?",
                (prediction, page_id))
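# A hedged sketch of how `predict_lines` might be driven end to end,
# assuming a SQLite database with a `PageLines` table (only the
# `dl_prediction` and `id` columns are implied by the UPDATE above) and a
# hypothetical `classifier` object exposing the method; the SELECT below
# is illustrative.
import sqlite3

def update_predictions(classifier, db_path: str):
    conn = sqlite3.connect(db_path)
    cur = conn.cursor()
    lines = cur.execute("SELECT * FROM PageLines").fetchall()
    classifier.predict_lines(cur, lines)
    conn.commit()  # persist the dl_prediction updates
    conn.close()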
def main():
    config = yaml.safe_load(open("config.yml", "r"))
    config = HParams(config, default_hparams=None)

    if not os.path.exists(config.indexer.model_dir):
        print("Creating a new index...")
        encoder = BERTEncoder(pretrained_model_name="bert-base-uncased")
        encoder.to(device)

        feature_original_types = {
            "id": ["int64", "FixedLenFeature"],
            "input_ids": ["int64", "FixedLenFeature",
                          config.indexer.max_seq_length],
            "segment_ids": ["int64", "FixedLenFeature",
                            config.indexer.max_seq_length],
            "text": ["str", "FixedLenFeature"]
        }

        hparam = {
            "allow_smaller_final_batch": True,
            "batch_size": config.indexer.batch_size,
            "dataset": {
                "data_name": "data",
                "feature_original_types": feature_original_types,
                "files": config.indexer.pickle_data_dir
            },
            "shuffle": False
        }

        print("Embedding the text using BERTEncoder...")
        record_data = RecordData(hparams=hparam, device=device)
        data_iterator = DataIterator(record_data)

        index = EmbeddingBasedIndexer(hparams={
            "index_type": "GpuIndexFlatIP",
            "dim": 768,
            "device": "gpu0"
        })

        for idx, batch in enumerate(data_iterator):
            ids = batch["id"]
            input_ids = batch["input_ids"]
            segment_ids = batch["segment_ids"]
            text = batch["text"]
            _, pooled_output = get_embeddings(encoder, input_ids, segment_ids)
            index.add(vectors=pooled_output,
                      meta_data={k.item(): v for k, v in zip(ids, text)})

            if (idx + 1) % 50 == 0:
                print(f"Completed {idx + 1} batches of size "
                      f"{config.indexer.batch_size}")

        index.save(path=config.indexer.model_dir)

    resource = Resources()
    query_pipeline = Pipeline(resource=resource)
    query_pipeline.set_reader(MultiPackTerminalReader())

    query_pipeline.add_processor(
        processor=MachineTranslationProcessor(), config=config.translator)
    query_pipeline.add_processor(
        processor=QueryCreator(), config=config.query_creator)
    query_pipeline.add_processor(
        processor=SearchProcessor(), config=config.indexer)
    query_pipeline.add_processor(
        processor=NLTKSentenceSegmenter(),
        selector=NameMatchSelector(select_name="doc_0"))
    query_pipeline.add_processor(
        processor=NLTKWordTokenizer(),
        selector=NameMatchSelector(select_name="doc_0"))
    query_pipeline.add_processor(
        processor=NLTKPOSTagger(),
        selector=NameMatchSelector(select_name="doc_0"))
    query_pipeline.add_processor(
        processor=SRLPredictor(), config=config.SRL,
        selector=NameMatchSelector(select_name="doc_0"))
    # query_pipeline.add_processor(
    #     processor=CoNLLNERPredictor(), config=config.NER,
    #     selector=NameMatchSelector(select_name="doc_0"))
    query_pipeline.add_processor(
        processor=MachineTranslationProcessor(),
        config=config.back_translator)

    query_pipeline.initialize()

    for m_pack in query_pipeline.process_dataset():
        # update resource to be used in the next conversation
        query_pack = m_pack.get_pack("query")
        if resource.get("user_utterance"):
            resource.get("user_utterance").append(query_pack)
        else:
            resource.update(user_utterance=[query_pack])

        response_pack = m_pack.get_pack("response")
        if resource.get("bot_utterance"):
            resource.get("bot_utterance").append(response_pack)
        else:
            resource.update(bot_utterance=[response_pack])

        english_pack = m_pack.get_pack("pack")
        print(colored("English Translation of the query: ", "green"),
              english_pack.text, "\n")
        pack = m_pack.get_pack("doc_0")
        print(colored("Retrieved Document", "green"), pack.text, "\n")
        print(colored("German Translation", "green"),
              m_pack.get_pack("response").text, "\n")

        for sentence in pack.get(Sentence):
            sent_text = sentence.text
            print(colored("Sentence:", 'red'), sent_text, "\n")
            print(colored("Semantic role labels:", 'red'))
            for link in pack.get(PredicateLink, sentence):
                parent = link.get_parent()
                child = link.get_child()
                print(f"  - \"{child.text}\" is role {link.arg_type} of "
                      f"predicate \"{parent.text}\"")
            print()

        input(colored("Press ENTER to continue...\n", 'green'))