Example #1
    def _calc_loss(self, input: ReaderBatch) -> torch.Tensor:
        args = self.args
        input = ReaderBatch(**move_to_device(input._asdict(), args.device))  # move the whole batch to the target device
        attn_mask = self.tensorizer.get_attn_mask(input.input_ids)
        # input_ids shape: (questions_num, passages_per_question, seq_len)
        questions_num, passages_per_question, _ = input.input_ids.size()

        if self.reader.training:
            # in training mode the reader computes the loss internally
            # from the gold start/end positions and the answers mask
            # start_logits, end_logits, rank_logits = self.reader(input.input_ids, attn_mask)
            loss = self.reader(input.input_ids, attn_mask,
                               input.start_positions, input.end_positions,
                               input.answers_mask)

        else:
            # TODO: remove?
            with torch.no_grad():  # no gradient tracking needed at eval time
                start_logits, end_logits, rank_logits = self.reader(
                    input.input_ids, attn_mask)

            loss = compute_loss(input.start_positions, input.end_positions,
                                input.answers_mask, start_logits, end_logits,
                                rank_logits, questions_num,
                                passages_per_question)
        if args.n_gpu > 1:
            loss = loss.mean()  # DataParallel returns a per-GPU loss vector; reduce to a scalar
        if args.gradient_accumulation_steps > 1:
            # scale so that summed micro-batch gradients match one full batch
            loss = loss / args.gradient_accumulation_steps

        return loss
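
Both branches above compute the same quantity: in training mode the reader returns the loss directly, while in eval mode the raw logits are collected under torch.no_grad() and fed to compute_loss. As a rough illustration of what such a span-extraction loss looks like, here is a minimal, self-contained sketch (a hypothetical re-implementation, not the actual compute_loss from the snippet): cross-entropy over token positions for the start and end of the answer, masked so that only passages with an annotated answer contribute.

import torch
import torch.nn.functional as F

def span_loss_sketch(start_positions, end_positions, answers_mask,
                     start_logits, end_logits):
    # Shapes assumed (hypothetical, mirroring the snippet's batch layout):
    #   *_logits:    (questions_num, passages_per_question, seq_len)
    #   *_positions: (questions_num, passages_per_question, max_answers)
    q, p, s = start_logits.size()
    start_logits = start_logits.view(q * p, s)
    end_logits = end_logits.view(q * p, s)
    # keep only the first annotated answer per passage for brevity
    start = start_positions[:, :, 0].reshape(q * p)
    end = end_positions[:, :, 0].reshape(q * p)
    mask = answers_mask[:, :, 0].reshape(q * p).float()
    # cross-entropy over token positions, one classification per passage
    loss_start = F.cross_entropy(start_logits, start.clamp(min=0), reduction="none")
    loss_end = F.cross_entropy(end_logits, end.clamp(min=0), reduction="none")
    # zero out passages without an answer, then average over the rest
    per_passage = (loss_start + loss_end) * mask
    return per_passage.sum() / mask.sum().clamp(min=1)

# quick smoke test with random tensors
q, p, s = 2, 3, 16
logits = torch.randn(q, p, s)
pos = torch.randint(0, s, (q, p, 1))
print(span_loss_sketch(pos, pos, torch.ones(q, p, 1), logits, logits.clone()))
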
Example #2
    def _calc_loss(self, input: ReaderBatch) -> torch.Tensor:
        cfg = self.cfg
        input = ReaderBatch(**move_to_device(input._asdict(), cfg.device))
        attn_mask = self.tensorizer.get_attn_mask(input.input_ids)
        questions_num, passages_per_question, _ = input.input_ids.size()

        if self.reader.training:
            # start_logits, end_logits, rank_logits = self.reader(input.input_ids, attn_mask)
            loss = self.reader(
                input.input_ids,
                attn_mask,
                input.start_positions,
                input.end_positions,
                input.answers_mask,
                # optional flags; getattr supplies a default so older configs
                # without these keys still work
                use_simple_loss=getattr(cfg.train, "use_simple_loss", False),
                average_loss=getattr(cfg.train, "average_loss", False),
            )

        else:
            # TODO: remove?
            with torch.no_grad():
                start_logits, end_logits, rank_logits = self.reader(
                    input.input_ids, attn_mask)

            loss = compute_loss(
                input.start_positions,
                input.end_positions,
                input.answers_mask,
                start_logits,
                end_logits,
                rank_logits,
                questions_num,
                passages_per_question,
                use_simple_loss=getattr(cfg.train, "use_simple_loss", False),
                average=getattr(cfg.train, "average_loss", False),  # note: compute_loss takes `average`, not `average_loss`
            )
        if cfg.n_gpu > 1:
            loss = loss.mean()
        if cfg.train.gradient_accumulation_steps > 1:
            loss = loss / cfg.train.gradient_accumulation_steps

        return loss
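
The two listings implement the same routine. Example #2 differs in reading its options from a structured config object (cfg, with a nested cfg.train section, in the style of Hydra) rather than a flat args namespace, and in threading two optional flags, use_simple_loss and average_loss, through to the reader and to compute_loss. The getattr(..., False) calls supply a default so configs that predate these keys keep working. Below is a minimal sketch of the config shape Example #2 expects (hypothetical values; in a Hydra setup this would come from YAML):

from types import SimpleNamespace

# hypothetical stand-in for the structured cfg object
cfg = SimpleNamespace(
    device="cuda:0",
    n_gpu=1,
    train=SimpleNamespace(
        gradient_accumulation_steps=2,
        # use_simple_loss / average_loss are intentionally omitted:
        # the getattr(..., False) calls above default them to False
    ),
)

print(getattr(cfg.train, "use_simple_loss", False))  # -> False
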