Code Example #1
    def pack(self, data_pack: MultiPack, output_dict):
        """
        Write the prediction results back to datapack. If :attr:`_overwrite`
        is `True`, write the predicted ner to the original tokens.
        Otherwise, create a new set of tokens and write the predicted ner
        to the new tokens (usually use this configuration for evaluation.)
        """
        assert output_dict is not None
        output_pack = data_pack.packs[self.output_pack_name]

        input_sent_tids = output_dict["input_sents_tids"]
        output_sentences = output_dict["output_sents"]

        text = output_pack.text
        input_pack = data_pack.packs[self.input_pack_name]
        for input_id, output_sentence in zip(input_sent_tids, output_sentences):
            # Use the accumulated local text for the offset; the pack's text is
            # only set once after the loop.
            offset = len(text)
            sent = Sentence(output_pack, offset, offset + len(output_sentence))
            output_pack.add_entry(sent)
            text += output_sentence + "\n"

            input_sent = input_pack.get_entry(input_id)
            cross_link = MultiPackLink(
                data_pack,
                data_pack.subentry(self.input_pack_name, input_sent),
                data_pack.subentry(self.output_pack_name, sent),
            )
            data_pack.add_entry(cross_link)
            # We may also consider adding two links with opposite directions.
            # Here the unidirectional link indicates the generation dependency.
        output_pack.set_text(text)
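
The `MultiPackLink` added above records which input sentence each generated sentence came from. As a rough sketch (assuming `MultiPack` exposes the same `get_entries` accessor used on the data packs in the other examples, and that the link's ends can be resolved with `get_parent`/`get_child` as in Code Example #6), a downstream step might walk those links like this:

    # Sketch only: iterate the cross-pack links and resolve both ends.
    # `multi_pack` stands in for the MultiPack produced by the pack() call above.
    for link in multi_pack.get_entries(MultiPackLink):
        input_sent = link.get_parent()    # Sentence in the input pack
        output_sent = link.get_child()    # generated Sentence in the output pack
        print(input_sent.text, "->", output_sent.text)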
Code Example #2
    def _process(self, input_pack: MultiPack):
        r"""Searches ElasticSearch indexer to fetch documents for a query. This
        query should be contained in the input multipack with name
        `self.config.query_pack_name`.

        This method adds new packs to `input_pack` containing the retrieved
        results. Each result is added as a `ft.onto.base_ontology.Document`.

        Args:
             input_pack: A multipack containing query as a pack.
        """
        query_pack = input_pack.get_pack(self.config.query_pack_name)

        # ElasticSearchQueryCreator adds a Query entry to query pack. We now
        # fetch it as the first element.
        first_query = list(query_pack.get_entries(Query))[0]
        results = self.index.search(first_query.value)
        hits = results["hits"]["hits"]
        packs = {}
        for idx, hit in enumerate(hits):
            source = hit["_source"]
            first_query.update_results({source["doc_id"]: hit["_score"]})
            pack = DataPack(doc_id=source["doc_id"])
            content = source[self.config.field]
            document = Document(pack=pack, begin=0, end=len(content))
            pack.add_entry(document)
            pack.set_text(content)
            packs[f"{self.config.response_pack_name_prefix}_{idx}"] = pack

        input_pack.update_pack(packs)
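
After `update_pack`, each retrieved document sits in the multipack under a name of the form `{response_pack_name_prefix}_{idx}`. A minimal sketch of how a later step might read those packs back, reusing only the accessors shown above (the `"passage"` prefix and `num_hits` count are illustrative assumptions, not values from the source):

    # Sketch only: read back the packs added by the retrieval step.
    num_hits = 5
    for idx in range(num_hits):
        doc_pack = input_pack.get_pack(f"passage_{idx}")
        for document in doc_pack.get_entries(Document):
            print(document.text)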
Code Example #3
    def _process(self, input_pack: MultiPack):
        query = input_pack.get_pack(self.in_pack_name).text
        params = '?' + urlencode(
            {
                'api-version': '3.0',
                'from': self.src_language,
                'to': [self.target_language]
            },
            doseq=True)
        microsoft_constructed_url = self.microsoft_translate_url + params

        response = requests.post(microsoft_constructed_url,
                                 headers=self.microsoft_headers,
                                 json=[{
                                     "text": query
                                 }])

        if response.status_code != 200:
            raise RuntimeError(response.json()['error']['message'])

        text = response.json()[0]["translations"][0]["text"]
        pack = DataPack()

        document = Document(pack, 0, len(text))
        utterance = Utterance(pack, 0, len(text))
        pack.add_entry(document)
        pack.add_entry(utterance)

        pack.set_text(text=text)
        input_pack.update_pack({self.out_pack_name: pack})
Code Example #4
    def _process(self, input_pack: MultiPack):
        input_ids = []
        segment_ids = []

        query_pack = input_pack.get_pack("pack")
        context = [query_pack.text]

        # use context to build the query
        if "user_utterance" in input_pack.pack_names:
            user_pack = input_pack.get_pack("user_utterance")
            context.append(user_pack.text)

        if "bot_utterance" in input_pack.pack_names:
            bot_pack = input_pack.get_pack("bot_utterance")
            context.append(bot_pack.text)

        for text in context:
            t = self.tokenizer.encode_text(text)
            input_ids.append(t[0])
            segment_ids.append(t[1])

        input_ids = torch.LongTensor(input_ids).to(self.device)
        segment_ids = torch.LongTensor(segment_ids).to(self.device)
        _, query_vector = self.get_embeddings(input_ids, segment_ids)
        query_vector = torch.mean(query_vector, dim=0, keepdim=True)
        query_vector = query_vector.cpu().numpy()
        query = Query(pack=query_pack, value=query_vector)
        query_pack.add_or_get_entry(query)
Code Example #5
File: selector_test.py  Project: mgupta1410/forte-1
    def setUp(self) -> None:
        self.data_pack1 = DataPack(doc_id="1")
        self.data_pack2 = DataPack(doc_id="2")
        self.data_pack3 = DataPack(doc_id="Three")
        self.multi_pack = MultiPack()
        self.multi_pack.add_pack(self.data_pack1, pack_name="pack1")
        self.multi_pack.add_pack(self.data_pack2, pack_name="pack2")
        self.multi_pack.add_pack(self.data_pack3, pack_name="pack_three")
Code Example #6
    def _process(self, input_pack: MultiPack):
        src_pack = input_pack.get_pack(self.configs.source_pack_name)

        instance: NLIPair
        for instance in src_pack.get(NLIPair):
            premise = instance.get_parent()
            hypo = instance.get_child()

            for i, (new_prem,
                    new_hypo) in enumerate(self.tweak_nli_text(premise, hypo)):
                pack = input_pack.add_pack(f"generated_{i}",
                                           input_pack.pack_name + f"_{i}")
                create_nli(pack, new_prem, new_hypo)
Code Example #7
    def _process(self, input_pack: MultiPack):
        query_pack = input_pack.get_pack(self.config.query_pack_name)
        first_query = list(query_pack.get_entries(Query))[0]
        results = self.index.search(first_query.value, self.k)
        documents = [r[1] for result in results for r in result]

        packs = {}
        for i, doc in enumerate(documents):
            pack = DataPack()
            document = Document(pack=pack, begin=0, end=len(doc))
            pack.add_entry(document)
            pack.set_text(doc)
            packs[self.config.response_pack_name[i]] = pack

        input_pack.update_pack(packs)
Code Example #8
File: batchers.py  Project: mgupta1410/forte-1
    def _get_data_batch(
            self,
            data_pack: MultiPack,
            context_type: Type[Annotation],
            requests: Optional[Dict[Type[Entry], Union[Dict, List]]] = None,
            offset: int = 0) -> Iterable[Tuple[Dict, int]]:
        r"""Try to get batches of size ``batch_size``. If the tail instances
        cannot make up a full batch, will generate a small batch with the tail
        instances.

        Returns:
            An iterator of tuples ``(batch, cnt)``, ``batch`` is a dict
            containing the required annotations and context, and ``cnt`` is
            the number of instances in the batch.
        """
        input_pack = data_pack.get_pack(self.input_pack_name)

        instances: List[Dict] = []
        for data in input_pack.get_data(context_type, requests, offset):
            instances.append(data)
            if len(instances) == self.batch_size:
                batch = batch_instances(instances)
                self.batch_is_full = True
                yield (batch, len(instances))
                instances = []
                self.batch_is_full = False

        if len(instances):
            batch = batch_instances(instances)
            self.last_batch = True
            yield (batch, len(instances))
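
`batch_instances` is not shown in this snippet; conceptually it collates a list of per-instance dicts into a single batch dict, which is what lets the method yield `(batch, cnt)` tuples. A rough, self-contained sketch of such a collation helper (an illustration of the idea, not the library's actual implementation):

    from collections import defaultdict
    from typing import Dict, List

    def collate_instances(instances: List[Dict]) -> Dict[str, List]:
        # Group per-instance values under their shared keys, turning a
        # list of dicts into a dict of lists.
        batch: Dict[str, List] = defaultdict(list)
        for instance in instances:
            for key, value in instance.items():
                batch[key].append(value)
        return dict(batch)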
Code Example #9
    def _process(self, input_pack: MultiPack):
        query_pack = input_pack.get_pack("pack")
        first_query = list(query_pack.get_entries(Query))[0]
        results = self.index.search(first_query.value, self.k)
        documents = [r[1] for result in results for r in result]

        packs = {}
        for counter, doc in enumerate(documents):
            pack = DataPack()
            document = Document(pack=pack, begin=0, end=len(doc))
            pack.add_entry(document)
            pack.set_text(doc)
            packs[f"doc_{counter}"] = pack

        input_pack.update_pack(packs)
Code Example #10
File: selector.py  Project: mgupta1410/forte-1
    def select(self, m_pack: MultiPack) -> Iterator[DataPack]:
        if len(m_pack.packs) == 0:
            raise ValueError("Multi-pack is empty")

        else:
            for name, pack in m_pack.iter_packs():
                if re.match(self.select_name, name):
                    yield pack
Code Example #11
    def _process_query(self, input_pack: MultiPack):
        query_pack = input_pack.get_pack(self.config.query_pack_name)
        context = [query_pack.text]

        # use context to build the query
        if "user_utterance" in input_pack.pack_names:
            user_pack = input_pack.get_pack("user_utterance")
            context.append(user_pack.text)

        if "bot_utterance" in input_pack.pack_names:
            bot_pack = input_pack.get_pack("bot_utterance")
            context.append(bot_pack.text)

        text = ' '.join(context)

        query_vector = self._build_query(text=text)

        return query_pack, query_vector
Code Example #12
File: selector.py  Project: mgupta1410/forte-1
    def select(self, m_pack: MultiPack) -> Iterator[DataPack]:
        matches = 0
        for name, pack in m_pack.iter_packs():
            if name == self.select_name:
                matches += 1
                yield pack

        if matches == 0:
            raise ValueError(f"pack name {self.select_name} "
                             f"not in the MultiPack")
Code Example #13
File: selector_test.py  Project: mgupta1410/forte-1
class SelectorTest(unittest.TestCase):

    def setUp(self) -> None:
        self.data_pack1 = DataPack(doc_id="1")
        self.data_pack2 = DataPack(doc_id="2")
        self.data_pack3 = DataPack(doc_id="Three")
        self.multi_pack = MultiPack()
        self.multi_pack.add_pack(self.data_pack1, pack_name="pack1")
        self.multi_pack.add_pack(self.data_pack2, pack_name="pack2")
        self.multi_pack.add_pack(self.data_pack3, pack_name="pack_three")

    def test_name_match_selector(self) -> None:
        selector = NameMatchSelector(select_name="pack1")
        packs = selector.select(self.multi_pack)
        doc_ids = ["1"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.meta.doc_id)

    def test_regex_name_match_selector(self) -> None:
        selector = RegexNameMatchSelector(select_name="^.*\\d$")
        packs = selector.select(self.multi_pack)
        doc_ids = ["1", "2"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.meta.doc_id)

    def test_first_pack_selector(self) -> None:
        selector = FirstPackSelector()
        packs = list(selector.select(self.multi_pack))
        self.assertEqual(len(packs), 1)
        self.assertEqual(packs[0].meta.doc_id, "1")

    def test_all_pack_selector(self) -> None:
        selector = AllPackSelector()
        packs = selector.select(self.multi_pack)
        doc_ids = ["1", "2", "Three"]
        for doc_id, pack in zip(doc_ids, packs):
            self.assertEqual(doc_id, pack.meta.doc_id)
Code Example #14
    def _process_query(self, input_pack: MultiPack) -> \
            Tuple[DataPack, Dict[str, Any]]:
        query_pack = input_pack.get_pack(self.config.query_pack_name)
        query = self._build_query(text=query_pack.text)
        return query_pack, query