def _process(self, input_pack: MultiPack):
    max_len = self.config.max_seq_length
    query_pack_name = self.config.query_pack_name
    query_pack = input_pack.get_pack(self.config.query_pack_name)
    query_entry = list(query_pack.get(Query))[0]
    query_text = query_pack.text

    packs = {}
    for doc_id in input_pack.pack_names:
        if doc_id == query_pack_name:
            continue

        pack = input_pack.get_pack(doc_id)
        document_text = pack.text

        # BERT Inference
        input_ids, segment_ids, input_mask = [
            torch.LongTensor(item).unsqueeze(0).to(self.device)
            for item in self.tokenizer.encode_text(
                query_text, document_text, max_len)
        ]
        seq_length = (input_mask == 1).sum(dim=-1)
        logits, _ = self.model(input_ids, seq_length, segment_ids)
        preds = torch.nn.functional.softmax(torch.Tensor(logits), dim=1)
        score = preds.detach().tolist()[0][1]

        query_entry.update_results({doc_id: score})
        packs[doc_id] = pack
def _process(self, input_pack: MultiPack):
    max_len = self.config.max_seq_length
    query_pack_name = self.config.query_pack_name
    query_pack = input_pack.get_pack(self.config.query_pack_name)
    query_entry = list(query_pack.get(Query))[0]
    query_text = query_pack.text

    packs = {}
    for doc_id in input_pack.pack_names:
        if doc_id == query_pack_name:
            continue

        pack = input_pack.get_pack(doc_id)
        document_text = pack.text
        doc_id_final = pack.pack_name

        # BERT Inference: encode the (query, document) pair, truncating to
        # the configured maximum sequence length.
        encodings = self.tokenizer(
            query_text, document_text,
            padding=True, truncation=True, max_length=max_len,
            return_tensors='pt').to(self.device)

        self.model.eval()
        with torch.no_grad():
            logits = self.model(**encodings)
        pt_predictions = torch.nn.functional.softmax(logits[0], dim=1)
        score = pt_predictions.tolist()[0][1]

        query_entry.update_results({doc_id_final: score})
        packs[doc_id] = pack
def _process(self, input_pack: MultiPack):
    context_list = []
    doc_id_list = []
    for doc_id in input_pack.pack_names:
        if doc_id == self.configs.question_pack_name:
            continue

        pack = input_pack.get_pack(doc_id)
        context_list.append(pack.get_single(self.configs.entry_type).text)
        doc_id_list.append(doc_id)

    question_pack = input_pack.get_pack(self.configs.question_pack_name)
    first_question = question_pack.get_single(Sentence)
    question_list = [question_pack.text for _ in range(len(context_list))]
    result_collection = self.extractor(
        context=context_list,
        question=question_list,
        max_answer_len=self.configs.max_answer_len,
        handle_impossible_answer=self.configs.handle_impossible_answer,
    )
    for i, result in enumerate(result_collection):
        start = result["start"]
        end = result["end"]
        doc_pack = input_pack.get_pack(doc_id_list[i])
        ans_phrase = Phrase(pack=doc_pack, begin=start, end=end)
        input_pack.add_entry(
            MultiPackLink(input_pack, first_question, ans_phrase))
def pack(self, data_pack: MultiPack, output_dict):
    """Write the prediction results back to the datapack.

    Each generated sentence is appended to the output pack's text, covered
    by a new `Sentence` annotation, and linked to the input sentence it was
    generated from via a `MultiPackLink`.
    """
    assert output_dict is not None
    output_pack = data_pack.get_pack(self.output_pack_name)

    input_sent_tids = output_dict["input_sents_tids"]
    output_sentences = output_dict["output_sents"]

    text = output_pack.text
    input_pack = data_pack.get_pack(self.input_pack_name)
    for input_id, output_sentence in zip(input_sent_tids, output_sentences):
        # Offset into the accumulated output text (the pack text is only
        # updated once at the end).
        offset = len(text)
        sent = Sentence(output_pack, offset, offset + len(output_sentence))
        text += output_sentence + "\n"

        input_sent = input_pack.get_entry(input_id)
        cross_link = MultiPackLink(data_pack, input_sent, sent)
        data_pack.add_entry(cross_link)
        # We may also consider adding two links with opposite directions.
        # Here the unidirectional link indicates the generation dependency.
    output_pack.set_text(text)
def _process(self, input_pack: MultiPack):
    pack_i = input_pack.get_pack('default')
    pack_j = input_pack.get_pack('duplicate')
    for ent_i, ent_j in zip(pack_i.get(EntityMention),
                            pack_j.get(EntityMention)):
        link = CrossDocEntityRelation(input_pack, ent_i, ent_j)
        link.rel_type = 'coreference'
        input_pack.add_entry(link)
def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]): r""" This function calls the data augmentation ops and modifies the input in-place. This function only applies for replacement-based methods. The subclasses should override this function to implement other data augmentation methods, such as Easy Data Augmentation. Args: input_pack: The input MultiPack. aug_pack_names: The packs names for DataPacks to be augmented. """ replacement_op = create_class_with_kwargs( self.configs["data_aug_op"], class_args={ "configs": self.configs["data_aug_op_config"]["kwargs"] }, ) augment_entry = get_class(self.configs["augment_entry"]) for pack_name in aug_pack_names: data_pack: DataPack = input_pack.get_pack(pack_name) anno: Annotation for anno in data_pack.get(augment_entry): self._replace(replacement_op, anno)
def _process(self, input_pack: MultiPack): r"""Searches ElasticSearch indexer to fetch documents for a query. This query should be contained in the input multipack with name `self.config.query_pack_name`. This method adds new packs to `input_pack` containing the retrieved results. Each result is added as a `ft.onto.base_ontology.Document`. Args: input_pack: A multipack containing query as a pack. """ query_pack = input_pack.get_pack(self.config.query_pack_name) # ElasticSearchQueryCreator adds a Query entry to query pack. We now # fetch it as the first element. first_query: Query = query_pack.get_single(Query) results = self.index.search(first_query.value) hits = results["hits"]["hits"] for idx, hit in enumerate(hits): document = hit["_source"] first_query.add_result(document["doc_id"], hit["_score"]) pack: DataPack = input_pack.add_pack( f"{self.config.response_pack_name_prefix}_{idx}") pack.doc_id = document["doc_id"] content = document[self.config.field] pack.set_text(content) Document(pack=pack, begin=0, end=len(content))
def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]):
    replacement_op = create_class_with_kwargs(
        self.configs["data_aug_op"],
        class_args={
            "configs": self.configs["data_aug_op_config"]["kwargs"]
        })
    augment_entry = get_class(self.configs["augment_entry"])

    for pack_name in aug_pack_names:
        data_pack: DataPack = input_pack.get_pack(pack_name)
        annotations = []
        pos = [0]
        for anno in data_pack.get(augment_entry):
            if anno.text not in self.stopwords:
                annotations.append(anno)
                pos.append(anno.end)
        if len(annotations) > 0:
            for _ in range(ceil(self.configs['alpha'] * len(annotations))):
                src_anno = random.choice(annotations)
                _, replaced_text = replacement_op.replace(src_anno)
                insert_pos = random.choice(pos)
                if insert_pos > 0:
                    replaced_text = " " + replaced_text
                else:
                    replaced_text = replaced_text + " "
                self._insert(replaced_text, data_pack, insert_pos)
def _get_data_batch(
        self,
        multi_pack: MultiPack,
        context_type: Type[Annotation],
        requests: Optional[Dict[Type[Entry], Union[Dict, List]]] = None,
        offset: int = 0,
) -> Iterable[Tuple[Dict, int]]:
    r"""Try to get batches of size ``batch_size``. If the tail instances
    cannot make up a full batch, generate a smaller batch with the tail
    instances.

    Returns:
        An iterator of tuples ``(batch, cnt)``, where ``batch`` is a dict
        containing the required annotations and context, and ``cnt`` is the
        number of instances in the batch.
    """
    input_pack = multi_pack.get_pack(self.input_pack_name)

    instances: List[Dict] = []
    current_size = sum(self.current_batch_sources)
    for data in input_pack.get_data(context_type, requests, offset):
        instances.append(data)
        if len(instances) == self.batch_size - current_size:
            batch = batch_instances(instances)
            self.batch_is_full = True
            yield (batch, len(instances))
            instances = []
            self.batch_is_full = False

    if len(instances):
        batch = batch_instances(instances)
        yield (batch, len(instances))
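# Illustrative only: the batching pattern used above, shown on a plain list of
# dicts instead of pack data, including the smaller tail batch. The names here
# are placeholders for the sketch, not defined in this module.
def iter_batches(items, size):
    buf = []
    for item in items:
        buf.append(item)
        if len(buf) == size:
            yield buf, len(buf)
            buf = []
    if buf:  # tail instances form a smaller final batch
        yield buf, len(buf)

example_instances = [{"id": i} for i in range(10)]
for example_batch, cnt in iter_batches(example_instances, 4):
    print(cnt)  # prints 4, 4, 2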
def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]): augment_entry = get_class(self.configs["augment_entry"]) if not issubclass(augment_entry, Annotation): raise ValueError(f"This augmenter only accept data of " f"'forte.data.ontology.top.Annotation' type, " f"but {self.configs['augment_entry']} is not.") for pack_name in aug_pack_names: data_pack: DataPack = input_pack.get_pack(pack_name) annotations: List[Annotation] = list(data_pack.get(augment_entry)) if len(annotations) > 0: replace_map: Dict = {} for _ in range(ceil(self.configs["alpha"] * len(annotations))): swap_idx = random.sample(range(len(annotations)), 2) new_idx_0 = (swap_idx[1] if swap_idx[1] not in replace_map else replace_map[swap_idx[1]]) new_idx_1 = (swap_idx[0] if swap_idx[0] not in replace_map else replace_map[swap_idx[0]]) replace_map[swap_idx[0]] = new_idx_0 replace_map[swap_idx[1]] = new_idx_1 pid: int = data_pack.pack_id for idx in replace_map: self._replaced_annos[pid].add(( annotations[idx].span, annotations[replace_map[idx]].text, ))
def _process(self, input_pack: MultiPack): r"""Search using Twitter API to fetch tweets for a query. This query should be contained in the input multipack with name `self.config.query_pack_name`. Each result is added as a new data pack, and a `ft.onto.base_ontology.Document` annotation is used to cover the whole document. Args: input_pack: A multipack containing query as a pack. """ query_pack = input_pack.get_pack(self.configs.query_pack_name) query = query_pack.text tweets = self._query_tweets(query) for idx, tweet in enumerate(tweets): try: text = tweet.retweeted_status.full_text except AttributeError: # Not a Retweet text = tweet.full_text pack: DataPack = input_pack.add_pack( f"{self.configs.response_pack_name_prefix}_{idx}") pack.pack_name = f"{self.configs.response_pack_name_prefix}_{idx}" pack.set_text(text) Document(pack=pack, begin=0, end=len(text))
def _process(self, input_pack: MultiPack): query = input_pack.get_pack(self.in_pack_name).text params = "?" + urlencode( { "api-version": "3.0", "from": self.src_language, "to": [self.target_language], }, doseq=True, ) microsoft_constructed_url = self.microsoft_translate_url + params response = requests.post( microsoft_constructed_url, headers=self.microsoft_headers, json=[{ "text": query }], ) if response.status_code != 200: raise RuntimeError(response.json()["error"]["message"]) text = response.json()[0]["translations"][0]["text"] pack: DataPack = input_pack.add_pack(self.out_pack_name) pack.set_text(text=text) Document(pack, 0, len(text)) Utterance(pack, 0, len(text))
def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]): augment_entry = get_class(self.configs["augment_entry"]) for pack_name in aug_pack_names: data_pack: DataPack = input_pack.get_pack(pack_name) for anno in data_pack.get(augment_entry): if random.random() < self.configs['alpha']: self._delete(anno)
def _process(self, input_pack: MultiPack):
    max_len = self.config.max_seq_length
    query_pack_name = self.config.query_pack_name
    batch_size = self.config.batch_size
    query_pack = input_pack.get_pack(self.config.query_pack_name)
    query_entry = list(query_pack.get(Query))[0]
    query_text = query_pack.text

    doc_text_list = []
    doc_id_list = []
    for doc_id in input_pack.pack_names:
        if doc_id == query_pack_name:
            continue
        pack = input_pack.get_pack(doc_id)
        document_text = pack.text
        doc_id_final = pack.pack_name
        doc_text_list.append(document_text)
        doc_id_list.append(doc_id_final)

    query_text_list = [query_text] * len(doc_text_list)

    # BERT Inference, batched over the candidate documents.
    num_batches = (len(doc_text_list) // batch_size
                   + (len(doc_text_list) % batch_size > 0))
    score_list = []
    for i in range(num_batches):
        start = i * batch_size
        end = (i + 1) * batch_size
        if end > len(doc_text_list):
            end = len(doc_text_list)

        encodings = self.tokenizer(
            query_text_list[start:end], doc_text_list[start:end],
            padding=True, truncation=True, max_length=max_len,
            return_tensors='pt').to(self.device)

        self.model.eval()
        with torch.no_grad():
            logits = self.model(**encodings)
        pt_predictions = torch.nn.functional.softmax(logits[0], dim=1)
        scores = pt_predictions[:, 1].tolist()
        score_list += scores

    for doc_id_final, score in zip(doc_id_list, score_list):
        query_entry.update_results({doc_id_final: score})
def consume_next(self, pred_pack: MultiPack, _):
    query_pack: DataPack = pred_pack.get_pack(self.configs.pack_name)
    query = list(query_pack.get(Query))[0]
    query_text = query_pack.text
    query_id = query_pack.pack_name

    qa_results_dict = query.results
    for p_name in pred_pack.pack_names:
        if p_name != self.configs.pack_name:
            passage_id = pred_pack.get_pack(p_name).pack_name
            if qa_results_dict[passage_id]:
                passage_text = pred_pack.get_pack(p_name).text
                answer_text = qa_results_dict[passage_id]
                self.predicted_results.append(
                    (query_id, query_text, passage_id,
                     passage_text, answer_text))
def _process_query(self, input_pack: MultiPack) \
        -> Tuple[DataPack, Dict[str, Any]]:
    query_pack: DataPack = input_pack.get_pack(self.config.query_pack_name)

    # use context to build the query
    context = [query_pack.text]
    if "user_utterance" in input_pack.pack_names:
        user_pack = input_pack.get_pack("user_utterance")
        context.append(user_pack.text)

    if "bot_utterance" in input_pack.pack_names:
        bot_pack = input_pack.get_pack("bot_utterance")
        context.append(bot_pack.text)

    text = ' '.join(context)
    query_vector = self._build_query(text=text)

    return query_pack, query_vector
def consume_next(self, pred_pack: MultiPack, _):
    query_pack: DataPack = pred_pack.get_pack(self.configs.pack_name)
    query = list(query_pack.get(Query))[0]

    rank = 1
    for pid, _ in query.results.items():
        doc_id: Optional[str] = query_pack.pack_name
        if doc_id is None:
            raise ProcessExecutionException(
                'Doc ID of the query pack is not set, '
                'please double check the reader.')
        self.predicted_results.append((doc_id, pid, str(rank)))
        rank += 1
def _process(self, input_pack: MultiPack):
    max_len = self.config.max_seq_length
    query_pack_name = self.config.query_pack_name
    query_pack = input_pack.get_pack(self.config.query_pack_name)
    query_entry = list(query_pack.get(Query))[0]
    query_text = query_pack.text

    doc_score_dict = query_entry.results
    best_doc_id = max(doc_score_dict, key=lambda x: doc_score_dict[x])

    packs = {}
    for doc_id in input_pack.pack_names:
        if doc_id == query_pack_name:
            continue

        pack = input_pack.get_pack(doc_id)
        doc_id_final = pack.pack_name
        if doc_id_final != best_doc_id:
            query_entry.update_results({doc_id_final: ""})
            continue

        query_doc_input = {'question': query_text, 'context': pack.text}
        result = self.qa_pipeline(query_doc_input)

        # Expand the answer phrase to the whole sentence containing it.
        answer_phrase = result['answer']
        ans_sents = [sent.text for sent in spacy_nlp(pack.text).sents]
        answer = None
        for sent in ans_sents:
            if answer_phrase in sent:
                answer = sent
                break
        if not answer:
            answer = answer_phrase

        query_entry.update_results({doc_id_final: answer})
        packs[doc_id] = pack
def _process(self, input_pack: MultiPack):
    from_pack: DataPack = input_pack.get_pack(self.configs.copy_from)
    copy_pack: DataPack = input_pack.add_pack(self.configs.copy_to)

    copy_pack.set_text(from_pack.text)
    if from_pack.pack_name is not None:
        copy_pack.pack_name = from_pack.pack_name + '_copy'
    else:
        copy_pack.pack_name = 'copy'

    ent: EntityMention
    for ent in from_pack.get(EntityMention):
        EntityMention(copy_pack, ent.begin, ent.end)
def _process(self, input_pack: MultiPack):
    from_pack: DataPack = input_pack.get_pack(self.configs.copy_from)
    copy_pack: DataPack = input_pack.add_pack(self.configs.copy_to)

    copy_pack.set_text(from_pack.text)
    if from_pack.pack_name is not None:
        copy_pack.pack_name = from_pack.pack_name + '_copy'
    else:
        copy_pack.pack_name = 'copy'

    s: Sentence
    for s in from_pack.get(Sentence):
        Sentence(copy_pack, s.begin, s.end)
def _process(self, input_pack: MultiPack):
    query_pack = input_pack.get_pack(self.configs.query_pack_name)
    first_query = list(query_pack.get(Query))[0]
    results = self.index.search(first_query.value, self.k)
    documents = [r[1] for result in results for r in result]

    packs = {}
    for i, doc in enumerate(documents):
        pack = input_pack.add_pack()
        pack.set_text(doc)
        Document(pack, 0, len(doc))
        packs[self.configs.response_pack_name_prefix + f"_{i}"] = pack

    input_pack.update_pack(packs)
def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]): r""" This function splits a given word at a random position and replaces the original word with 2 split parts of it. """ augment_entry = get_class(self.configs["augment_entry"]) for pack_name in aug_pack_names: data_pack: DataPack = input_pack.get_pack(pack_name) annotations: List[Annotation] = [] indexes: List[int] = [] endings = [] annos: Iterable[Annotation] = data_pack.get(augment_entry) for idx, anno in enumerate(annos): annotations.append(anno) indexes.append(idx) endings.append(anno.end) if len(annotations) > 0: annotation_to_split = random.sample( [(anno, idx) for (anno, idx) in zip(annotations, indexes) if (anno.end - anno.begin) > 1], ceil(self.configs["alpha"] * len(annotations)), ) annotation_to_split = sorted(annotation_to_split, key=lambda x: x[1], reverse=True) for curr_anno in annotation_to_split: src_anno, src_idx = curr_anno splitting_position = random.randrange( 1, (src_anno.end - src_anno.begin)) word_split = [ src_anno.text[:splitting_position], src_anno.text[splitting_position:], ] if src_idx != 0: first_position = endings[src_idx - 1] + 1 second_position = endings[src_idx] word_split[1] = " " + word_split[1] else: first_position = 0 second_position = endings[0] word_split[1] = " " + word_split[1] self._insert(word_split[1], data_pack, second_position) self._delete(src_anno) self._insert(word_split[0], data_pack, first_position)
def _process(self, input_pack: MultiPack): from_pack: DataPack = input_pack.get_pack(self.configs.copy_from) copy_pack: DataPack = input_pack.add_pack(self.configs.copy_to) copy_pack.set_text(from_pack.text) if from_pack.pack_name is not None: copy_pack.pack_name = from_pack.pack_name + "_copy" else: copy_pack.pack_name = "copy" s: Sentence for s in from_pack.get(Sentence): Sentence(copy_pack, s.begin, s.end) e: EntityMention for e in from_pack.get(EntityMention): EntityMention(copy_pack, e.begin, e.end)
def _process(self, input_pack: MultiPack): r"""Searches `Elasticsearch` indexer to fetch documents for a query. This query should be contained in the input multipack with name `self.config.query_pack_name`. This method adds new packs to `input_pack` containing the retrieved results. Each result is added as a `ft.onto.base_ontology.Document`. Args: input_pack: A multipack containing query as a pack. """ query_pack = input_pack.get_pack(self.configs.query_pack_name) # ElasticSearchQueryCreator adds a Query entry to query pack. We now # fetch it as the first element. first_query: Query = query_pack.get_single(Query) # pylint: disable=isinstance-second-argument-not-valid-type # TODO: until fix: https://github.com/PyCQA/pylint/issues/3507 if not isinstance(first_query.value, Dict): raise ValueError( "The query to the elastic indexer need to be a dictionary.") results = self.index.search(first_query.value) hits = results["hits"]["hits"] for idx, hit in enumerate(hits): document = hit["_source"] first_query.add_result(document["doc_id"], hit["_score"]) if self.configs.indexed_text_only: pack: DataPack = input_pack.add_pack( f"{self.configs.response_pack_name_prefix}_{idx}") pack.pack_name = document["doc_id"] content = document[self.configs.field] pack.set_text(content) Document(pack=pack, begin=0, end=len(content)) else: pack = DataPack.deserialize(document["pack_info"]) input_pack.add_pack_( pack, f"{self.configs.response_pack_name_prefix}_{idx}") pack.pack_name = document["doc_id"]
def consume_next(self, pred_pack: MultiPack, _):
    query_pack: DataPack = pred_pack.get_pack(self.configs.pack_name)
    query = list(query_pack.get(Query))[0]
    query_text = query_pack.text

    sorted_query_results = sorted(list(query.results.items()),
                                  key=lambda x: x[1],
                                  reverse=True)
    rank = 1
    for pid, _ in sorted_query_results:
        doc_id: Optional[str] = query_pack.pack_name
        if doc_id is None:
            raise ProcessExecutionException(
                'Doc ID of the query pack is not set, '
                'please double check the reader.')
        self.predicted_results.append((doc_id, pid, str(rank)))
        rank += 1
def _augment(self, input_pack: MultiPack, aug_pack_names: List[str]): augment_entry = get_class(self.configs["augment_entry"]) for pack_name in aug_pack_names: data_pack: DataPack = input_pack.get_pack(pack_name) annotations = list(data_pack.get(augment_entry)) if len(annotations) > 0: replace_map: Dict = {} for _ in range(ceil(self.configs['alpha'] * len(annotations))): swap_idx = random.sample(range(len(annotations)), 2) new_idx_0 = swap_idx[1] if swap_idx[1] not in replace_map \ else replace_map[swap_idx[1]] new_idx_1 = swap_idx[0] if swap_idx[0] not in replace_map \ else replace_map[swap_idx[0]] replace_map[swap_idx[0]] = new_idx_0 replace_map[swap_idx[1]] = new_idx_1 pid: int = data_pack.pack_id for idx in replace_map: self._replaced_annos[pid]\ .add((annotations[idx].span, annotations[replace_map[idx]].text))
def _process(self, input_pack: MultiPack):
    query = input_pack.get_pack(self.in_pack_name).text
    params = '?' + urlencode(
        {'api-version': '3.0',
         'from': self.src_language,
         'to': [self.target_language]},
        doseq=True)
    microsoft_constructed_url = self.microsoft_translate_url + params

    response = requests.post(
        microsoft_constructed_url,
        headers=self.microsoft_headers,
        json=[{"text": query}])

    if response.status_code != 200:
        raise RuntimeError(response.json()['error']['message'])

    text = response.json()[0]["translations"][0]["text"]
    pack: DataPack = input_pack.add_pack(self.out_pack_name)
    pack.set_text(text=text)

    Document(pack, 0, len(text))
    Utterance(pack, 0, len(text))
def _process(self, input_pack: MultiPack):
    # Get the pack names for augmentation.
    aug_pack_names: List[str] = []

    # Check if the DataPack exists.
    for pack_name in self.configs["augment_pack_names"]["kwargs"].keys():
        if pack_name in input_pack.pack_names:
            aug_pack_names.append(pack_name)

    if len(self.configs["augment_pack_names"]["kwargs"].keys()) == 0:
        # Augment all the DataPacks if not specified.
        aug_pack_names = list(input_pack.pack_names)

    self._augment(input_pack, aug_pack_names)

    new_packs: List[Tuple[str, DataPack]] = []
    for aug_pack_name in aug_pack_names:
        new_pack_name: str = \
            self.configs["augment_pack_names"]["kwargs"].get(
                aug_pack_name, "augmented_" + aug_pack_name
            )
        data_pack = input_pack.get_pack(aug_pack_name)
        new_pack = self._auto_align_annotations(
            data_pack=data_pack,
            replaced_annotations=self._replaced_annos[
                data_pack.meta.pack_id])
        new_packs.append((new_pack_name, new_pack))

    for new_pack_name, new_pack in new_packs:
        input_pack.add_pack_(new_pack, new_pack_name)

    # Copy the MultiPackLinks/MultiPackGroups
    for mpl in input_pack.get(MultiPackLink):
        self._copy_multi_pack_link_or_group(mpl, input_pack)
    for mpg in input_pack.get(MultiPackGroup):
        self._copy_multi_pack_link_or_group(mpg, input_pack)

    # Must be called after processing each multipack
    # to reset internal states.
    self._clear_states()
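# Illustrative only: the mapping this processor reads from
# configs["augment_pack_names"]["kwargs"]. Keys are existing pack names in the
# MultiPack, values are the names for the augmented copies; an empty mapping
# augments every pack under an "augmented_" prefix. The pack names here are
# placeholders.
example_augment_pack_names = {
    "kwargs": {
        "input_src": "augmented_input_src",
        "output_tgt": "augmented_output_tgt",
    }
}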
def _process_query(
        self, input_pack: MultiPack) -> Tuple[DataPack, Dict[str, Any]]:
    query_pack = input_pack.get_pack(self.configs.query_pack_name)
    query = self._build_query(text=query_pack.text)
    return query_pack, query