Code Example #1
File: agents.py Project: simplecoka/cortx
    def custom_evaluation(self, teacher_action: Message, labels,
                          model_response: Message):
        resp = model_response.get('text')
        if not resp:
            return

        if teacher_action['type'] == 'apicall' and resp.startswith(
                'apicall: '):
            gold = teacher_action['slots']
            slot_strs = resp[9:].split(' ; ')
            parsed = {}
            for slot_str in slot_strs:
                if ' = ' not in slot_str:
                    if slot_str != '':
                        # syntactically invalid generations should count against us
                        self.metrics.add('slot_p', AverageMetric(0))
                    continue
                name, value = slot_str.split(' = ')
                parsed[name] = value

            # slot precision
            for k, v in parsed.items():
                self.metrics.add('slot_p', AverageMetric(v == gold.get(k)))
            # slot recall
            for k, v in gold.items():
                self.metrics.add('slot_r', AverageMetric(v == parsed.get(k)))
        elif teacher_action['type'] == 'apiresp':
            delex_resp = self._delex(resp, teacher_action['slots'])
            delex_label = self._delex(labels[0], teacher_action['slots'])
            self.metrics.add('delex_bleu',
                             BleuMetric.compute(delex_resp, [delex_label]))
Code Example #2
    def test_uneven_macro_aggrevation(self):
        report1 = {
            'avg': AverageMetric(1, 1),
        }
        report2 = {
            'avg': AverageMetric(0, 1),
        }
        report3 = {
            'avg': AverageMetric(0, 1),
        }
        agg1 = aggregate_named_reports({
            'a': report1,
            'b': report2
        },
                                       micro_average=False)
        agg2 = aggregate_named_reports({
            'a': {},
            'c': report3
        },
                                       micro_average=False)

        agg = aggregate_unnamed_reports([agg1, agg2])
        assert agg1['avg'] == 0.5
        assert agg2['avg'] == 0.0
        assert agg['a/avg'] == 1.0
        assert agg['b/avg'] == 0.0
        assert agg['c/avg'] == 0.0
        assert agg['avg'] == 1.0 / 3
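The macro figures above follow from how AverageMetric pools counts. A minimal sketch, assuming ParlAI's parlai.core.metrics.AverageMetric is available:

from parlai.core.metrics import AverageMetric

# AverageMetric stores a numerator and a denominator; value() divides them.
m = AverageMetric(1, 4)
print(m.value())  # 0.25
# Addition pools numerators and denominators (micro-style pooling).
pooled = m + AverageMetric(3, 4)
print(pooled.value())  # (1 + 3) / (4 + 4) = 0.5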
Code Example #3
    def compute_loss(self, batch, return_output=False):
        """
        Compute and return the loss for the given batch.

        Easily overridable for customized loss functions.

        If return_output is True, the full output from the call to self.model()
        is also returned, via a (loss, model_output) pair.
        """
        if batch.label_vec is None:
            raise ValueError('Cannot compute loss without a label.')
        model_output = self.model(*self._model_input(batch), ys=batch.label_vec)
        scores, preds, *_ = model_output
        score_view = scores.view(-1, scores.size(-1))
        loss = self.criterion(score_view, batch.label_vec.view(-1))
        loss = loss.view(scores.shape[:-1]).sum(dim=1)
        # save loss to metrics
        notnull = batch.label_vec.ne(self.NULL_IDX)
        target_tokens = notnull.long().sum(dim=-1)
        correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)

        self.record_local_metric('loss', AverageMetric.many(loss, target_tokens))
        self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
        self.record_local_metric(
            'token_acc', AverageMetric.many(correct, target_tokens)
        )
        # actually do backwards loss
        loss = loss.sum()
        loss /= target_tokens.sum()  # average loss per token
        if return_output:
            return (loss, model_output)
        else:
            return loss
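Here AverageMetric.many pairs the per-example loss and token-count tensors to produce one metric per batch item. A rough illustration with made-up numbers, assuming plain Python lists are accepted (the code above passes tensors):

from parlai.core.metrics import AverageMetric

per_example_loss = [2.0, 6.0]   # hypothetical summed loss per example
per_example_tokens = [4, 3]     # hypothetical non-null token counts
metrics = AverageMetric.many(per_example_loss, per_example_tokens)
print([m.value() for m in metrics])  # [0.5, 2.0]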
Code Example #4
 def _get_hidden_losses(
         self,
         fwd_pass: ForwardPassOutputs) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Return the encoder and decoder hidden losses.
     """
     assert isinstance(self, TorchGeneratorAgent)
     # Code relies on TorchGeneratorAgent methods
     enc_hidden_loss, enc_hidden_loss_per_example = self._get_component_hidden_loss(
         student_hidden_states=fwd_pass.student_hidden_states['encoder'],
         teacher_hidden_states=fwd_pass.teacher_hidden_states['encoder'],
         mask=fwd_pass.context_mask,
         num_tokens=fwd_pass.num_context_tokens,
         mapped_layers=self.mapped_enc_layers,
     )
     self.record_local_metric(
         'enc_hid_loss',
         AverageMetric.many(enc_hidden_loss_per_example,
                            fwd_pass.context_tokens_per_example),
     )
     dec_hidden_loss, dec_hidden_loss_per_example = self._get_component_hidden_loss(
         student_hidden_states=fwd_pass.student_hidden_states['decoder'],
         teacher_hidden_states=fwd_pass.teacher_hidden_states['decoder'],
         mask=fwd_pass.decoder_mask,
         num_tokens=fwd_pass.num_tokens,
         mapped_layers=self.mapped_dec_layers,
     )
     self.record_local_metric(
         'dec_hid_loss',
         AverageMetric.many(dec_hidden_loss_per_example,
                            fwd_pass.tokens_per_example),
     )
     return enc_hidden_loss, dec_hidden_loss
Code Example #5
 def __init__(self, slot_p: AverageType = None, slot_r: AverageType = None):
     if not isinstance(slot_p, AverageMetric) and slot_p is not None:
         slot_p = AverageMetric(slot_p)
     if not isinstance(slot_r, AverageMetric) and slot_r is not None:
         slot_r = AverageMetric(slot_r)
     self._slot_p = slot_p
     self._slot_r = slot_r
Code Example #6
    def __init__(
        self,
        teacher_slots: Dict[str, str],
        predicted_slots: Dict[str, str],
        prefixes: Optional[List] = None,
        shared: Dict[str, Any] = None,
    ) -> None:
        super().__init__(shared=shared)
        self.prefixes = prefixes if prefixes else []
        self.add_with_prefixes("jga", AverageMetric(teacher_slots == predicted_slots))
        if len(teacher_slots) > 0:
            self.add_with_prefixes(
                "jga_noempty", AverageMetric(teacher_slots == predicted_slots)
            )
        else:
            self.add_with_prefixes(
                "jga_empty", AverageMetric(teacher_slots == predicted_slots)
            )

        # precision
        for pred_slot_name, pred_value in predicted_slots.items():
            slot_p = AverageMetric(teacher_slots.get(pred_slot_name) == pred_value)
            self.add_with_prefixes("slot_p", slot_p)
            self.add_with_prefixes("slot_f1", SlotF1Metric(slot_p=slot_p))
        # recall
        for teacher_slot_name, teacher_value in teacher_slots.items():
            slot_r = AverageMetric(
                predicted_slots.get(teacher_slot_name) == teacher_value
            )
            self.add_with_prefixes("slot_r", slot_r)
            self.add_with_prefixes("slot_f1", SlotF1Metric(slot_r=slot_r))
Code Example #7
def goals_slots_helper(
        goals: List[Dict],
        turnDict: List[Dict]) -> Tuple[AverageMetric, AverageMetric]:
    """
    Helper function to see how well the slot keys + slot values match between attempted
    API calls and goals.

    Output is precision, recall.
    """
    all_call_slots = {k: v for call in turnDict for k, v in call.items()}
    all_goal_slots = {k: v for goal in goals for k, v in goal.items()}
    goal_in_call = {
        k: v
        for k, v in all_call_slots.items()
        if all_goal_slots.get(k, "definitelyNotInValuexyz") == v
    }
    call_in_goal = {
        k: v
        for k, v in all_goal_slots.items()
        if all_call_slots.get(k, "definitelyNotInValuexyz") == v
    }

    print(goal_in_call, all_call_slots)

    return (
        AverageMetric(len(goal_in_call), len(all_call_slots)),
        AverageMetric(len(call_in_goal), len(all_goal_slots)),
    )
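A hypothetical call, just to show the shape of the inputs and outputs (the slot names and values below are made up):

goals = [{"name": "il forno", "time": "7pm"}]
api_calls = [{"name": "il forno", "time": "8pm"}]
precision, recall = goals_slots_helper(goals, api_calls)
print(precision.value(), recall.value())  # 0.5 0.5: one of the two slots matches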
Code Example #8
    def compute_multiobj_metrics(
        self,
        char_loss: torch.Tensor,
        scores: torch.Tensor,
        label_inds: torch.LongTensor,
        prefix: str = '',
    ):
        """
        Compute multi-objective metrics to track performance.

        :param char_loss:
            character loss (non-averaged) for each batch item
        :param scores:
            scores for character candidates
        :param label_inds:
            indices of correct characters
        """
        prefix = f'{prefix}_' if prefix else ''
        batchsize = scores.size(0)
        _, ranks = scores.topk(1, 1, largest=True)
        ranks_m = []
        mrrs_m = []
        hits_m = []
        for b in range(batchsize):
            rank = (ranks[b] == label_inds[b]).nonzero()
            rank = rank.item() if len(rank) == 1 else (scores.size(1) - 1)
            ranks_m.append(1 + rank)
            mrrs_m.append(1.0 / (1 + rank))
            hits_m.append(int(rank == 0))
        self.record_local_metric(f'{prefix}rank', AverageMetric.many(ranks_m))
        self.record_local_metric(f'{prefix}hits@1', AverageMetric.many(hits_m))
        self.record_local_metric(f'{prefix}mrr', AverageMetric.many(mrrs_m))
        self.record_local_metric(
            f'{prefix}mean_character_loss', AverageMetric.many(char_loss)
        )
Code Example #9
File: test_metrics.py Project: zhangnn520/ParlAI
 def test_micro_aggregation(self):
     report1 = {
         'avg': AverageMetric(3, 4),
         'sum': SumMetric(3),
         'fixed': FixedMetric(4),
         'global_avg': GlobalAverageMetric(3, 4),
     }
     report2 = {
         'avg': AverageMetric(1, 3),
         'sum': SumMetric(4),
         'fixed': FixedMetric(4),
         'global_avg': GlobalAverageMetric(1, 3),
     }
     agg = aggregate_named_reports({'a': report1, 'b': report2}, micro_average=True)
     assert agg['avg'] == 4.0 / 7
     assert agg['sum'] == 7
     assert agg['fixed'] == 4
     assert agg['global_avg'] in (report1['global_avg'], report2['global_avg'])
     # task level metrics
     assert agg['a/avg'] == 3.0 / 4
     assert agg['a/sum'] == 3
     assert agg['a/fixed'] == 4
     assert 'a/global_avg' not in agg
     assert agg['b/avg'] == 1.0 / 3
     assert agg['b/sum'] == 4
     assert agg['b/fixed'] == 4
     assert 'b/global_avg' not in agg
Code Example #10
    def compute_loss(self, batch, return_output=False):
        """
        Override TGA.compute_loss to ignore start token.
        """
        if batch.label_vec is None:
            raise ValueError('Cannot compute loss without a label.')
        model_output = self.model(*self._model_input(batch),
                                  ys=batch.label_vec)
        scores, preds, *_ = model_output

        if scores.size(1) != batch.label_vec.size(1):
            # ignore start
            scores = scores[:, 1:, :]
            preds = preds[:, 1:]

        score_view = scores.reshape(-1, scores.size(-1))
        loss = self.criterion(score_view, batch.label_vec.view(-1))
        loss = loss.view(scores.shape[:-1]).sum(dim=1)
        # save loss to metrics
        notnull = batch.label_vec.ne(self.NULL_IDX)
        target_tokens = notnull.long().sum(dim=-1)
        correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)

        self.record_local_metric('loss',
                                 AverageMetric.many(loss, target_tokens))
        self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
        self.record_local_metric('token_acc',
                                 AverageMetric.many(correct, target_tokens))
        # actually do backwards loss
        loss = loss.sum()
        loss /= target_tokens.sum()  # average loss per token
        if return_output:
            return (loss, model_output)
        else:
            return loss
Code Example #11
    def handle_user_utt(
            self, message: Message,
            prefix_stripped_text: str) -> Optional[Dict[str, Metric]]:
        """
        Grab slots out of the user utterance based on an exact match.
        """
        utterance = prefix_stripped_text

        def get_slots(utt, options):
            results = set()
            for option in options:
                if option in utt:
                    results.add(option)
            return results

        all_slot_values_here = get_slots(utterance, self.all_goal_slot_values)
        req_slot_values_here = get_slots(utterance,
                                         self.all_req_goal_slot_values)

        self.mentioned_all_slot_values |= all_slot_values_here
        self.mentioned_req_slot_values |= req_slot_values_here

        metrics = {}
        metrics["user_utt_avg_any_slot"] = AverageMetric(
            len(all_slot_values_here))
        metrics["user_utt_avg_req_slot"] = AverageMetric(
            len(req_slot_values_here))
        return metrics
Code Example #12
File: hred.py Project: aaronmueller/for-submission
    def compute_loss(self, batch, return_output=False):
        """
        Compute and return the loss for the given batch.

        Easily overridable for customized loss functions.

        If return_output is True, the full output from the call to self.model()
        is also returned, via a (loss, model_output) pair.
        """
        if batch['label_1_vec'] is None or batch['label_2_vec'] is None:
            raise ValueError('Cannot compute loss without a label.')

        # TODO HRED: change model forward function to implement HRED
        # model output: (output_1, output_2)
        model_output = self.model(
            *self._model_input(batch), ys1=batch.label_1_vec, ys2=batch.label_2_vec
        )
        scores_1, preds_1, scores_2, preds_2, *_ = model_output
        score_1_view = scores_1.view(-1, scores_1.size(-1))
        loss_1 = self.criterion(score_1_view, batch.label_1_vec.view(-1))
        loss_1 = loss_1.view(scores_1.shape[:-1]).sum(dim=1)

        score_2_view = scores_2.view(-1, scores_2.size(-1))
        loss_2 = self.criterion(score_2_view, batch.label_2_vec.view(-1))
        loss_2 = loss_2.view(scores_2.shape[:-1]).sum(dim=1)

        # save loss to metrics
        notnull_1 = batch.label_1_vec.ne(self.NULL_IDX)
        target_tokens_1 = notnull_1.long().sum(dim=-1)
        # target_tokens_1 = notnull_1.long().sum().item()
        correct_1 = ((batch.label_1_vec == preds_1) * notnull_1).sum(dim=-1)
        notnull_2 = batch.label_2_vec.ne(self.NULL_IDX)
        # target_tokens_2 = notnull_2.long().sum().item()
        target_tokens_2 = notnull_2.long().sum(dim=-1)
        correct_2 = ((batch.label_2_vec == preds_2) * notnull_2).sum(dim=-1)

        target_tokens = torch.cat((target_tokens_1, target_tokens_2))
        correct = torch.cat((correct_1, correct_2))
        total_losses = torch.cat((loss_1, loss_2))

        num_tokens = target_tokens.sum().item()

        self.record_local_metric('loss_1', AverageMetric.many(loss_1, target_tokens_1))
        self.record_local_metric('ppl_1', PPLMetric.many(loss_1, target_tokens_1))
        self.record_local_metric(
            'token_acc_1', AverageMetric.many(correct_1, target_tokens_1)
        )

        self.record_local_metric('loss_2', AverageMetric.many(loss_2, target_tokens_2))
        self.record_local_metric('ppl_2', PPLMetric.many(loss_2, target_tokens_2))
        self.record_local_metric(
            'token_acc_2', AverageMetric.many(correct_2, target_tokens_2)
        )
        loss = total_losses.sum()
        loss /= num_tokens  # average loss per token

        if return_output:
            return (loss, model_output)
        else:
            return loss
Code Example #13
File: test_metrics.py Project: swycha/ParlAI
    def test_macroaverage_additions(self):
        m1 = AverageMetric(1, 3)
        m2 = AverageMetric(3, 4)

        assert (m1 + m2) == AverageMetric(4, 7)
        assert MacroAverageMetric({
            'a': m1,
            'b': m2
        }) == 0.5 * (1.0 / 3 + 3.0 / 4)
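The contrast being tested: adding AverageMetrics pools their counts (micro), while MacroAverageMetric averages their already-computed values. A minimal sketch, assuming both classes come from parlai.core.metrics:

from parlai.core.metrics import AverageMetric, MacroAverageMetric

m1, m2 = AverageMetric(1, 3), AverageMetric(3, 4)
micro = (m1 + m2).value()                               # 4 / 7, counts pooled
macro = MacroAverageMetric({'a': m1, 'b': m2}).value()  # 0.5 * (1/3 + 3/4)
print(micro, macro)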
Code Example #14
File: test_train_model.py Project: magicye/ParlAI
 def report(self):
     if self.count == 0:
         # initial score
         return {'loss': AverageMetric(3)}
     elif self.count == 1:
         # don't save the second validation
         return {'loss': AverageMetric(4)}
     else:
         # do save the third validation
         return {'loss': AverageMetric(2)}
Code Example #15
 def handle_sys_utt(
         self, message: Message,
         prefix_stripped_text: str) -> Optional[Dict[str, Metric]]:
     count = 0
     for val in self.api_resp_slots.values():
         if val in prefix_stripped_text:
             count += 1
     result = {"pseudo_inform_allSysTurns": AverageMetric(count)}
     if len(self.api_resp_slots) > 0:
         result["pseudo_inform_postApiRespSysTurns"] = AverageMetric(count)
     return result
Code Example #16
    def get_episode_metrics(self) -> Optional[Dict[str, Metric]]:
        result = {
            "user_any_goal_slots_recall":
            AverageMetric(len(self.mentioned_all_slot_values),
                          len(self.all_goal_slot_values)),
            "user_req_goal_slots_recall":
            AverageMetric(len(self.mentioned_req_slot_values),
                          len(self.all_req_goal_slot_values)),
        }

        self.mentioned_all_slot_values = set()
        self.mentioned_req_slot_values = set()
        return result
Code Example #17
File: blenderbot2.py Project: skywalker023/ParlAI
 def compute_loss(
     self,
     batch: Batch,
     return_output: bool = False
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, Any]]:
     """
     Override Rag.compute_loss to add some additional metrics.
     """
     loss, output = super().compute_loss(batch, return_output=True)
     assert isinstance(self.model, BlenderBot2RagModel)
     if (KnowledgeAccessMethod(self.opt['knowledge_access_method']) is
             KnowledgeAccessMethod.CLASSIFY
             and self.model.has_query_generator()):
         _scores, _preds, enc_state, *_ = output
         _, _, input_turns_cnt, _, _ = enc_state
         retrieval_type = self.model.get_retrieval_type()
         assert isinstance(retrieval_type, torch.Tensor)
         if input_turns_cnt is not None:
             new_ret_type = torch.zeros(input_turns_cnt.size(0))
             offset = 0
             for i in range(input_turns_cnt.size(0)):
                 new_ret_type[i] = retrieval_type[offset]
                 offset += input_turns_cnt[i]
             retrieval_type = new_ret_type
         self.record_local_metric(
             'search_class',
             AverageMetric.many(
                 retrieval_type.eq(
                     RetrievalType.SEARCH.value).int().tolist(),
                 [1] * retrieval_type.size(0),
             ),
         )
         self.record_local_metric(
             'memory_class',
             AverageMetric.many(
                 retrieval_type.eq(
                     RetrievalType.MEMORY.value).int().tolist(),
                 [1] * retrieval_type.size(0),
             ),
         )
         self.record_local_metric(
             'none_class',
             AverageMetric.many(
                 retrieval_type.eq(RetrievalType.NONE.value).int().tolist(),
                 [1] * retrieval_type.size(0),
             ),
         )
     if return_output:
         return loss, output
     else:
         return loss
Code Example #18
    def compute_loss(self, batch, return_output=False):
        """
        Override from TorchGeneratorAgent
        Compute and return the loss for the given batch.

        Easily overridable for customized loss functions.

        If return_output is True, the full output from the call to self.model()
        is also returned, via a (loss, model_output) pair.
        """
        if batch.label_vec is None:
            raise ValueError('Cannot compute loss without a label.')

        bsz = batch.text_vec.size(0)
        world_cardinality = self.world_cardinality
        embedding_size = self.opt.get('embedding_size')
        encoder_states = self.model.encoder(*self._encoder_input(batch))

        enc_output = encoder_states[0].view(bsz, world_cardinality, -1,
                                            embedding_size).contiguous()
        enc_output_mask = encoder_states[1].view(bsz, world_cardinality,
                                                 -1).contiguous()
        encoder_states = (enc_output, enc_output_mask)

        scores, preds = self.model.selfconscious_decode_forced(
            encoder_states, batch.label_vec)
        model_output = (scores, preds, encoder_states)

        score_view = scores.view(-1, scores.size(-1))
        loss = self.criterion(score_view, batch.label_vec.view(-1))
        loss = loss.view(scores.shape[:-1]).sum(dim=1)
        # save loss to metrics
        notnull = batch.label_vec.ne(self.NULL_IDX)
        target_tokens = notnull.long().sum(dim=-1)
        correct = ((batch.label_vec == preds) * notnull).sum(dim=-1)

        self.record_local_metric('loss',
                                 AverageMetric.many(loss, target_tokens))
        self.record_local_metric('ppl', PPLMetric.many(loss, target_tokens))
        self.record_local_metric('token_acc',
                                 AverageMetric.many(correct, target_tokens))

        # actually do backwards loss
        loss = loss.sum()
        loss /= target_tokens.sum()  # average loss per token

        if return_output:
            return (loss, model_output)
        else:
            return loss
Code Example #19
File: test_metrics.py Project: swycha/ParlAI
    def test_average_metric_additions(self):

        input_pairs_and_outputs = [
            ((2, 4), (1.5, 1), 0.7),
            (
                (torch.LongTensor([[[2]]]), torch.Tensor([4])),
                (torch.FloatTensor([1.5]), 1),
                0.7,
            ),
        ]
        for input1, input2, output in input_pairs_and_outputs:
            actual_output = (AverageMetric(input1[0], input1[1]) +
                             AverageMetric(input2[0], input2[1])).value()
            self.assertAlmostEqual(actual_output, output, places=6)
            self.assertIsInstance(actual_output, float)
Code Example #20
    def handle_api_call(self, message: Message,
                        api_call: Dict) -> Optional[Dict[str, Metric]]:
        if STANDARD_API_SCHEMAS in message["text"]:
            return  # Happens for API call grounding, so it's fine
        if len(api_call) == 0:
            return
        if STANDARD_API_NAME_SLOT not in api_call:
            return {
                "apiCall_wellFormed": AverageMetric(0),
                "apiCall_hasSlotsButNoApiNameSlot_count": SumMetric(1),
            }
        method = api_call[STANDARD_API_NAME_SLOT]

        method_found = False
        if len(self.api_schemas) > 0:
            for schema in self.api_schemas:
                if method == schema.get(STANDARD_API_NAME_SLOT, ""):
                    method_found = True
                    check = api_call.keys()
                    required = set(schema.get(STANDARD_REQUIRED_KEY, []))
                    required.add(STANDARD_API_NAME_SLOT)
                    for req in required:
                        if req not in check:  # missing required
                            return {
                                "apiCall_wellFormed": AverageMetric(0),
                                "apiCall_missingRequiredSlot_count":
                                SumMetric(1),
                            }
                    opt_count = 0
                    for opt in schema.get(STANDARD_OPTIONAL_KEY, []):
                        if opt in check:
                            opt_count += 1
                    if opt_count + len(required) != len(check):
                        # has extra params that are not in the schema
                        return {
                            "apiCall_wellFormed": AverageMetric(0),
                            "apiCall_hasExtraParams_count": SumMetric(1),
                        }
                    break
        if method_found:
            return {
                "apiCall_wellFormed": AverageMetric(1),
                "apiCall_wellFormed_count": SumMetric(1),
            }
        return {
            "apiCall_wellFormed": AverageMetric(0),
            "apiCall_methodDNE_count": SumMetric(1),
        }
Code Example #21
    def eval_step(self, batch):
        """
        Evaluate a single batch of examples.
        """
        if batch.text_vec is None:
            return

        self.model.eval()
        scores = self.score(batch)
        probs = F.softmax(scores, dim=1)
        _, prediction_id = torch.max(probs.float().cpu(), 1)
        preds = [self.class_list[idx] for idx in prediction_id]

        if batch.labels is None or self.opt['ignore_labels']:
            # interactive mode
            if self.opt.get('print_scores', False):
                preds = self._format_interactive_output(probs, prediction_id)
        else:
            labels = self._get_label_tensor(batch)
            loss = self.criterion(scores, labels)
            self.record_local_metric('loss', AverageMetric.many(loss))

            preds = [self.class_list[idx] for idx in prediction_id]
            labels = batch.labels

            if preds is not None and labels is not None:
                self._update_confusion_matrix(preds, labels)

        if self.opt.get('print_scores', False):
            return Output(preds, probs=probs.cpu())
        else:
            return Output(preds)
Code Example #22
 def get_episode_metrics(self) -> Optional[Dict[str, Metric]]:
     all_goals_hit, _, _ = goals_hit_helper(self.goals, self.api_turns)
     call_attempts = len(self.api_turns)
     return {
         "synthetic_task_success": all_goals_hit,
         "api_call_attempts": AverageMetric(call_attempts),
     }
Code Example #23
 def test_slot_f1_metric_addition(self):
     a = SlotF1Metric(slot_p=1)
     b = SlotF1Metric(slot_r=0)
     c = SlotF1Metric(slot_p=AverageMetric(numer=2, denom=3), slot_r=1)
     d = a + b + c
     # Slot P should be 3/4 = 0.75; slot R should be 1/2 = 0.5
     self.assertEqual(0.6, d.value())
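The expected 0.6 can be checked by hand: after addition, slot precision pools to 3/4 and slot recall to 1/2, and SlotF1Metric appears to report their harmonic mean:

# Hand-check of the value asserted above (no ParlAI imports needed).
p, r = 3.0 / 4, 1.0 / 2
f1 = 2 * p * r / (p + r)
print(f1)  # 0.6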
Code Example #24
    def eval_step(self, batch):
        """
        Evaluate a single batch of examples.
        """
        if batch.text_vec is None:
            return

        self.model.eval()
        scores = self.score(batch)
        probs = F.softmax(scores, dim=1)
        if self.threshold is None:
            _, prediction_id = torch.max(probs.cpu(), 1)
        else:
            ref_prob = probs.cpu()[:, 0]
            # choose ref class if Prob(ref class) > threshold
            prediction_id = (ref_prob <= self.threshold).to(torch.int64)
        preds = [self.class_list[idx] for idx in prediction_id]

        if batch.labels is None or self.opt['ignore_labels']:
            # interactive mode
            if self.opt.get('print_scores', False):
                preds = self._format_interactive_output(probs, prediction_id)
        else:
            labels = self._get_labels(batch)
            loss = self.criterion(scores, labels)
            self.record_local_metric('loss', AverageMetric.many(loss))
            loss = loss.mean()
            self._update_confusion_matrix(batch, preds)

        if self.opt.get('print_scores', False):
            return Output(preds, probs=probs.cpu())
        else:
            return Output(preds)
Code Example #25
File: classifier.py Project: vaibhavsagar9/ParlAI
    def train_step(self, batch):
        """
        Train on a single batch of examples.
        """
        if batch.text_vec is None:
            return Output()
        self.model.train()
        self.zero_grad()

        # Calculate loss
        labels = self._get_label_tensor(batch)
        scores = self.score(batch)
        loss = self.criterion(scores, labels)
        self.record_local_metric('loss', AverageMetric.many(loss))
        loss = loss.mean()
        self.backward(loss)
        self.update_params()

        # Get predictions
        _, prediction_id = torch.max(scores.float().cpu(), 1)
        preds = [self.class_list[idx] for idx in prediction_id]
        labels_field = self.get_labels_field(batch['observations'])
        labels_lst = self._get_labels(batch['observations'], labels_field)
        self._update_confusion_matrix(preds, labels_lst)

        return Output(preds)
Code Example #26
    def _get_prediction_loss(self,
                             fwd_pass: ForwardPassOutputs) -> torch.Tensor:
        """
        Calculate and return the KL loss on the teacher's prediction layer.

        Also record prediction-loss metrics.
        """
        assert isinstance(self, TorchGeneratorAgent)
        # Code relies on TorchGeneratorAgent methods
        pred_loss = F.kl_div(
            F.log_softmax(fwd_pass.student_scores, dim=-1, dtype=torch.float),
            F.softmax(fwd_pass.teacher_scores, dim=-1, dtype=torch.float),
            reduction='none',
        ).type_as(fwd_pass.student_scores)
        pred_loss = pred_loss.sum(dim=-1) * fwd_pass.mask
        # Sum over dictionary
        self.record_local_metric(
            'pred_ppl',
            PPLMetric.many(pred_loss.sum(dim=-1), fwd_pass.tokens_per_example),
        )  # Sum over tokens
        self.record_local_metric(
            'pred_loss',
            AverageMetric.many(pred_loss.sum(dim=-1),
                               fwd_pass.tokens_per_example),
        )  # Sum over tokens
        pred_loss = pred_loss.sum() / fwd_pass.num_tokens
        return pred_loss
Code Example #27
File: test_metrics.py Project: zhangnn520/ParlAI
 def test_unnamed_aggregation(self):
     report1 = {
         'avg': AverageMetric(3, 4),
         'sum': SumMetric(3),
         'fixed': FixedMetric(4),
         'global_avg': GlobalAverageMetric(3, 4),
     }
     report2 = {
         'avg': AverageMetric(1, 3),
         'sum': SumMetric(4),
         'fixed': FixedMetric(4),
         'global_avg': GlobalAverageMetric(1, 3),
     }
     agg = aggregate_unnamed_reports([report1, report2])
     assert agg['avg'] == 4.0 / 7
     assert agg['sum'] == 7
     assert agg['fixed'] == 4
     assert agg['global_avg'] == 4.0 / 7
Code Example #28
    def _get_train_preds(self, scores, label_inds, cands, cand_vecs):
        """
        Return predictions from training.
        """
        # TODO: speed these calculations up
        batchsize = scores.size(0)
        if self.rank_top_k > 0:
            _, ranks = scores.topk(min(self.rank_top_k, scores.size(1)),
                                   1,
                                   largest=True)
        else:
            _, ranks = scores.sort(1, descending=True)
        ranks_m = []
        mrrs_m = []
        for b in range(batchsize):
            rank = (ranks[b] == label_inds[b]).nonzero()
            rank = rank.item() if len(rank) == 1 else scores.size(1)
            ranks_m.append(1 + rank)
            mrrs_m.append(1.0 / (1 + rank))
        self.record_local_metric('rank', AverageMetric.many(ranks_m))
        self.record_local_metric('mrr', AverageMetric.many(mrrs_m))

        ranks = ranks.cpu()
        # Here we get the top prediction for each example, but do not
        # return the full ranked list for the sake of training speed
        preds = []
        for i, ordering in enumerate(ranks):
            if cand_vecs.dim() == 2:  # num cands x max cand length
                cand_list = cands
            elif cand_vecs.dim() == 3:  # batchsize x num cands x max cand length
                cand_list = cands[i]
            if len(ordering) != len(cand_list):
                # We may have added padded cands to fill out the batch;
                # Here we break after finding the first non-pad cand in the
                # ranked list
                for x in ordering:
                    if x < len(cand_list):
                        preds.append(cand_list[x])
                        break
            else:
                preds.append(cand_list[ordering[0]])

        return Output(preds)
Code Example #29
 def test_slot_f1_metric_inputs(self):
     slots_p_r_and_f1 = [
         (None, None, float("nan")),
         (None, AverageMetric(0.0), float("nan")),
         (AverageMetric(0.0), AverageMetric(0.0), float("nan")),
         (AverageMetric(1), AverageMetric(1), 1.0),
         (AverageMetric(1), AverageMetric(0), 0.0),
         (AverageMetric(0.25), AverageMetric(0.75), 0.375),
     ]
     for slot_p, slot_r, slot_f1 in slots_p_r_and_f1:
         actual_slot_f1 = SlotF1Metric(slot_p=slot_p, slot_r=slot_r).value()
         if isnan(slot_f1):
             self.assertTrue(isnan(actual_slot_f1))
         else:
             self.assertEqual(slot_f1, actual_slot_f1)
Code Example #30
    def _get_batch_train_metrics(self, scores):
        """
        Get fast metrics calculations if we train with batch candidates.

        Specifically, calculate accuracy ('train_accuracy'), average rank, and mean
        reciprocal rank.
        """
        batchsize = scores.size(0)
        # get accuracy
        targets = scores.new_empty(batchsize).long()
        targets = torch.arange(batchsize, out=targets)
        nb_ok = (scores.max(dim=1)[1] == targets).float()
        self.record_local_metric('train_accuracy', AverageMetric.many(nb_ok))
        # calculate mean_rank
        above_dot_prods = scores - scores.diag().view(-1, 1)
        ranks = (above_dot_prods > 0).float().sum(dim=1) + 1
        mrr = 1.0 / (ranks + 0.00001)
        self.record_local_metric('rank', AverageMetric.many(ranks))
        self.record_local_metric('mrr', AverageMetric.many(mrr))