def test_argsort(self):
    """
    Exercise argsort with list keys in both directions and with a tensor key.
    """
    ranks = [5, 4, 3, 2, 1]
    words = ["five", "four", "three", "two", "one"]
    letters = ["e", "d", "c", "b", "a"]
    rank_tensor = torch.LongTensor(ranks)

    # The keys are strictly decreasing, so the default call is expected to
    # hand back every companion list reversed.
    default_order = argsort(ranks, words, letters)
    assert default_order == [words[::-1], letters[::-1]]

    # With descending=True the inputs are already in the expected order.
    descending_order = argsort(ranks, words, letters, descending=True)
    assert descending_order == [words, letters]

    # A tensor used as both key and payload should come back as 1..5.
    sorted_tensor = argsort(rank_tensor, rank_tensor)[0]
    assert np.all(sorted_tensor.numpy() == np.arange(1, 6))
def batchify(self, obs_batch, sort=False):
    """
    Build the parent batch and attach padded support and query tensors.

    The parent ``batchify`` result is extended with four fields:
    ``supports_vec`` / ``supports_lengths`` (from each example's
    ``'support_vec'``) and ``query_vec`` / ``query_lengths`` (from each
    example's ``'query_vec'``). A field pair is left as ``None`` when no
    valid example carries the corresponding key.

    :param obs_batch:
        observations to batch; only those accepted by ``self.is_valid``
        are used here.
    :param sort:
        if true, the examples are reordered by descending length of the
        first field that is present (supports first, then queries). Only
        one such reordering happens.

    :return:
        the batch returned by the parent class with the extra fields set.
    """
    batch = super().batchify(obs_batch, sort)

    valid_obs = [(i, ex) for i, ex in enumerate(obs_batch) if self.is_valid(ex)]
    if valid_obs:
        valid_inds, exs = zip(*valid_obs)
    else:
        # zip(*[]) yields nothing, so the two-name unpack above would raise
        # ValueError. With no valid examples both fields simply stay None.
        valid_inds, exs = (), ()

    # TEXT
    xs, x_lens = None, None
    if any('support_vec' in ex for ex in exs):
        # examples lacking the key are filled in with self.EMPTY
        _xs = [ex.get('support_vec', self.EMPTY) for ex in exs]
        xs, x_lens = padded_tensor(
            _xs, self.NULL_IDX, self.use_cuda, fp16friendly=self.opt.get('fp16')
        )
        if sort:
            sort = False  # now we won't sort on labels
            xs, x_lens, valid_inds, exs = argsort(
                x_lens, xs, x_lens, valid_inds, exs, descending=True
            )

    qs, q_lens = None, None
    if any('query_vec' in ex for ex in exs):
        _qs = [ex.get('query_vec', self.EMPTY) for ex in exs]
        qs, q_lens = padded_tensor(
            _qs, self.NULL_IDX, self.use_cuda, fp16friendly=self.opt.get('fp16')
        )
        if sort:
            # only reached when the supports branch did not already sort
            sort = False  # now we won't sort on labels
            qs, q_lens, valid_inds, exs = argsort(
                q_lens, qs, q_lens, valid_inds, exs, descending=True
            )

    batch.query_vec = qs
    batch.query_lengths = q_lens
    batch.supports_vec = xs
    batch.supports_lengths = x_lens

    return batch
def _rerank_candidates(
    self,
    reranker_outputs: List[Message],
    response_cands: List[str],
    response_cand_scores: torch.Tensor,
    rerank_for_class: str,
) -> Tuple[List[str], List[int]]:
    """
    Re-rank the response candidates given reranker outputs and a strategy.

    Compute reranking differently according to ``self.reranker_strategy``:

    - ``'hard_choice'``: candidates whose reranker ``'text'`` equals
      ``rerank_for_class`` come first, the rest follow; the original
      relative order is kept inside each group.
    - ``'sum_scores'``: sort descending by the reranker's score for
      ``rerank_for_class`` plus the initial model's score.
    - ``'reranker_score'``: sort descending by the reranker's score for
      ``rerank_for_class`` alone.
    - ``'none'``: leave the candidates untouched.

    :param reranker_outputs:
        outputs from reranker
    :param response_cands:
        list of response candidates
    :param response_cand_scores:
        tensor with scored response candidates from initial model
    :param rerank_for_class:
        The class (in the ML sense) we want to select for

    :return (candidates, indices):
        candidates: reranked list of response candidates.
        indices: list of indices into response_cands corresponding to re-rank order

    :raises ValueError:
        if ``self.reranker_strategy`` is not one of the strategies above, or
        (score-based strategies) if ``rerank_for_class`` is missing from a
        reranker output's ``'text_candidates'``.
    """
    strategy = self.reranker_strategy
    original_order = list(range(len(response_cands)))

    if strategy == 'none':
        return response_cands, original_order

    if strategy == 'hard_choice':
        # Stable partition in a single pass: matches first, then the rest.
        matched: List[Tuple[int, str]] = []
        unmatched: List[Tuple[int, str]] = []
        for i, cand in enumerate(response_cands):
            is_match = reranker_outputs[i]['text'] == rerank_for_class
            (matched if is_match else unmatched).append((i, cand))
        ordered = matched + unmatched
        return [cand for _, cand in ordered], [i for i, _ in ordered]

    if strategy in ('sum_scores', 'reranker_score'):
        # Score the reranker assigned to the requested class, per candidate.
        scores = [
            o['sorted_scores'][o['text_candidates'].index(rerank_for_class)].item()
            for o in reranker_outputs
        ]
        if strategy == 'sum_scores':
            scores = [
                score + response_cand_scores[i].item()
                for i, score in enumerate(scores)
            ]
        candidates, indices = argsort(
            scores,
            response_cands,  # type: ignore
            original_order,  # type: ignore
            descending=True,
        )[:2]
        return candidates, indices  # type: ignore

    # Previously an unknown strategy surfaced as an UnboundLocalError at the
    # return statement; fail with a message that names the bad value instead.
    raise ValueError('Unknown reranker strategy: {!r}'.format(strategy))