コード例 #1
0
    def test_argsort(self):
        """Check that argsort reorders parallel lists and tensors by a key list."""
        sort_keys = [5, 4, 3, 2, 1]
        words = ["five", "four", "three", "two", "one"]
        letters = ["e", "d", "c", "b", "a"]
        key_tensor = torch.LongTensor(sort_keys)

        # The keys are strictly decreasing, so an ascending sort reverses
        # every companion list.
        expected_ascending = [words[::-1], letters[::-1]]
        assert argsort(sort_keys, words, letters) == expected_ascending

        # A descending sort leaves the companion lists in their given order.
        expected_descending = [words, letters]
        assert argsort(sort_keys, words, letters, descending=True) == expected_descending

        # Tensors are reordered as well: sorting the keys by themselves
        # ascending yields 1..5.
        sorted_keys = argsort(key_tensor, key_tensor)[0]
        assert np.all(sorted_keys.numpy() == np.arange(1, 6))
コード例 #2
0
ファイル: mac_net.py プロジェクト: james-assiene/MacNetText
    def batchify(self, obs_batch, sort=False):
        """
        Extend the parent batch with padded support and query tensors.

        Each observation accepted by ``self.is_valid`` contributes its
        ``support_vec`` and ``query_vec``. Observations missing one of these
        fields are padded from ``self.EMPTY``. A field absent from every valid
        observation is stored as ``None``.

        :param obs_batch:
            list of observations to batch.
        :param sort:
            forwarded to the parent ``batchify``. If True, the supports are
            also sorted here by length, descending. If no observation has
            supports, the queries are sorted instead. The other field follows
            the same order because both are built from the reordered examples.

        :return:
            the parent's batch with ``supports_vec``, ``supports_lengths``,
            ``query_vec`` and ``query_lengths`` set.
        """
        batch = super().batchify(obs_batch, sort)

        valid_obs = [(i, ex) for i, ex in enumerate(obs_batch)
                     if self.is_valid(ex)]

        if not valid_obs:
            # zip(*[]) produces nothing, so unpacking it into two names would
            # raise ValueError. Return the batch with the extra fields empty.
            batch.query_vec = None
            batch.query_lengths = None
            batch.supports_vec = None
            batch.supports_lengths = None
            return batch

        valid_inds, exs = zip(*valid_obs)
        # NOTE(review): the sorts below reorder valid_inds/exs locally, but the
        # reordered valid_inds is never stored on ``batch``. The parent's own
        # fields keep whatever order the parent chose. Confirm callers do not
        # need supports/queries to line up with the parent's ordering when
        # sort=True.

        # TEXT (supports)
        xs, x_lens = None, None
        if any('support_vec' in ex for ex in exs):
            _xs = [ex.get('support_vec', self.EMPTY) for ex in exs]
            xs, x_lens = padded_tensor(_xs,
                                       self.NULL_IDX,
                                       self.use_cuda,
                                       fp16friendly=self.opt.get('fp16'))
            if sort:
                sort = False  # only sort once; queries follow this order
                xs, x_lens, valid_inds, exs = argsort(x_lens,
                                                      xs,
                                                      x_lens,
                                                      valid_inds,
                                                      exs,
                                                      descending=True)

        # QUERIES (built from exs, so they follow any support sort above)
        qs, q_lens = None, None
        if any('query_vec' in ex for ex in exs):
            _qs = [ex.get('query_vec', self.EMPTY) for ex in exs]
            qs, q_lens = padded_tensor(_qs,
                                       self.NULL_IDX,
                                       self.use_cuda,
                                       fp16friendly=self.opt.get('fp16'))
            if sort:
                sort = False
                qs, q_lens, valid_inds, exs = argsort(q_lens,
                                                      qs,
                                                      q_lens,
                                                      valid_inds,
                                                      exs,
                                                      descending=True)

        batch.query_vec = qs
        batch.query_lengths = q_lens
        batch.supports_vec = xs
        batch.supports_lengths = x_lens

        return batch
コード例 #3
0
    def _rerank_candidates(
        self,
        reranker_outputs: List[Message],
        response_cands: List[str],
        response_cand_scores: torch.Tensor,
        rerank_for_class: str,
    ) -> Tuple[List[str], List[int]]:
        """
        Re-rank the response candidates given reranker outputs and a strategy.

        Compute reranking differently according to ``self.reranker_strategy``:

        - ``'hard_choice'``: candidates whose reranker ``'text'`` equals
          ``rerank_for_class`` come first, then the rest. Each group keeps its
          original relative order.
        - ``'sum_scores'``: sort descending by the reranker's score for
          ``rerank_for_class`` plus the initial model's score.
        - ``'reranker_score'``: sort descending by the reranker's score for
          ``rerank_for_class`` alone.
        - ``'none'``: keep the original order.

        :param reranker_outputs:
            outputs from the reranker, one per response candidate
        :param response_cands:
            list of response candidates
        :param response_cand_scores:
            tensor with scored response candidates from the initial model
        :param rerank_for_class:
            the class (in the ML sense) we want to select for

        :return (candidates, indices):
            candidates: reranked list of response candidates.
            indices: list of indices into response_cands corresponding to re-rank order

        :raises ValueError:
            if ``self.reranker_strategy`` is not one of the strategies above.
        """

        def class_score(output):
            # Score the reranker assigned to ``rerank_for_class`` for one candidate.
            return output['sorted_scores'][
                output['text_candidates'].index(rerank_for_class)
            ]

        strategy = self.reranker_strategy
        if strategy == 'hard_choice':
            # Single stable partition pass. The in-class candidates keep their
            # original order and are followed by all remaining candidates.
            in_class, not_in_class = [], []
            for i, cand in enumerate(response_cands):
                if reranker_outputs[i]['text'] == rerank_for_class:
                    in_class.append((i, cand))
                else:
                    not_in_class.append((i, cand))
            ordered = in_class + not_in_class
            candidates = [cand for _, cand in ordered]
            indices = [i for i, _ in ordered]
        elif strategy == 'sum_scores':
            scores = [
                class_score(o).item() + response_cand_scores[i].item()
                for i, o in enumerate(reranker_outputs)
            ]
            # argsort (not sorted(reverse=True)) is kept deliberately, so the
            # existing tie ordering is preserved.
            candidates, indices = argsort(
                scores,
                response_cands,  # type: ignore
                list(range(len(response_cands))),  # type: ignore
                descending=True,
            )[:2]
        elif strategy == 'reranker_score':
            rerank_scores = [class_score(o).item() for o in reranker_outputs]
            candidates, indices = argsort(
                rerank_scores,
                response_cands,  # type: ignore
                list(range(len(response_cands))),  # type: ignore
                descending=True,
            )[:2]
        elif strategy == 'none':
            candidates = response_cands
            indices = list(range(len(response_cands)))
        else:
            # Previously this fell through to an UnboundLocalError on return.
            raise ValueError(f"Unknown reranker_strategy: {strategy!r}")

        return candidates, indices  # type: ignore