Example #1
0
    def validate(self, valid_iter, step=0):
        """Run one validation pass over *valid_iter*.

        Args:
            valid_iter: iterator over validation batches
            step: training step used when reporting

        Returns:
            :obj:`nmt.Statistics`: accumulated validation loss statistics
        """
        # Eval mode: disables dropout etc. for deterministic scoring.
        self.model.eval()
        total = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                gold = batch.src_sent_labels
                sent_scores, out_mask = self.model(
                    batch.src, batch.segs, batch.clss,
                    batch.mask_src, batch.mask_cls)

                raw_loss = self.loss(sent_scores, gold.float())
                # Zero out padded sentence positions, then sum.
                masked_loss = (raw_loss * out_mask.float()).sum()
                total.update(Statistics(
                    float(masked_loss.cpu().data.numpy()), len(gold)))
            self._report_step(0, step, valid_stats=total)
            return total
    def validate(self, valid_iter, step=0):
        """Validate the model on *valid_iter*, reporting loss and accuracy.

        Args:
            valid_iter: iterator over validation batches
            step: training step used when reporting

        Returns:
            :obj:`nmt.Statistics`: accumulated validation statistics
        """
        # Eval mode: disables dropout etc. for deterministic scoring.
        self.model.eval()
        agg = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                gold = batch.src_sent_labels.float()
                sent_scores, out_mask = self.model(
                    batch.src, batch.segs, batch.clss,
                    batch.mask_src, batch.mask_cls)

                raw = self.loss(sent_scores, gold)
                fmask = out_mask.float()
                # Mean loss over the valid (unpadded) sentence positions.
                mean_loss = (raw * fmask).sum() / fmask.sum()

                # baogs: report accuracy over the top-3 scored sentences.
                top_scores, top_ids = torch.topk(sent_scores, 3, dim=1)
                chosen = (top_scores > 0).float()
                picked = chosen.sum().item()
                hits = torch.sum(
                    torch.gather(gold, 1, top_ids) * chosen).item()

                agg.update(Statistics(mean_loss.item() * batch.batch_size,
                                      batch.batch_size, picked, hits))
            self._report_step(0, step, valid_stats=agg)
            return agg
Example #3
0
    def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """Score positive/negative pairs on *test_iter* and write results.

        Writes per-pair raw scores to ``<result_path>_step<step>.gold`` and
        binary predictions (1 when the positive score beats the negative)
        to ``<result_path>_step<step>.candidate``, prints overall accuracy,
        and reports the accumulated statistics.

        Args:
            test_iter: iterator over test batches
            step: training step used in output filenames and reporting
            cal_lead: when True, the model is NOT switched to eval mode
            cal_oracle: when True, the model is NOT switched to eval mode

        Returns:
            :obj:`nmt.Statistics`: accumulated test loss statistics
        """

        # NOTE(review): these two helpers are unused in this variant; they
        # are kept for parity with sibling ``test`` implementations.
        def _get_ngrams(n, text):
            # All word n-grams of *text* (a token list) as a set of tuples.
            ngram_set = set()
            text_length = len(text)
            max_index_ngram_start = text_length - n
            for i in range(max_index_ngram_start + 1):
                ngram_set.add(tuple(text[i:i + n]))
            return ngram_set

        def _block_tri(c, p):
            # True if candidate *c* shares any trigram with a sentence in *p*.
            tri_c = _get_ngrams(3, c.split())
            for s in p:
                tri_s = _get_ngrams(3, s.split())
                if len(tri_c.intersection(tri_s)) > 0:
                    return True
            return False

        if (not cal_lead and not cal_oracle):
            self.model.eval()
        stats = Statistics()

        can_path = '%s_step%d.candidate' % (self.args.result_path, step)
        gold_path = '%s_step%d.gold' % (self.args.result_path, step)
        all_preds = []
        with open(can_path, 'w') as save_pred:
            with open(gold_path, 'w') as save_gold:
                with torch.no_grad():
                    for batch in test_iter:
                        batch_size = batch.p_pair.size(0)
                        # _main runs the forward pass and loss in eval mode.
                        batch_stats, p_scores, n_scores = self._main(
                            batch, batch_size, is_train=False)
                        stats.update(batch_stats)

                        scores = []
                        preds = []
                        for i, idx in enumerate(p_scores):
                            p = p_scores[i].cpu().data.numpy()
                            n = n_scores[i].cpu().data.numpy()
                            scores.append(str(p) + '\t' + str(n))
                            # Predict 1 when the positive score wins.
                            preds.append(int(p > n))
                            all_preds.append(int(p > n))

                        for i in range(len(scores)):
                            save_gold.write(scores[i] + '\n')
                        for i in range(len(preds)):
                            save_pred.write(str(preds[i]) + '\n')
        print("**********************")
        print(sum(all_preds), len(all_preds))
        print("ACC: ", sum(all_preds) / float(len(all_preds)))
        print("**********************")
        self._report_step(0, step, valid_stats=stats)

        return stats
Example #4
0
    def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                               report_stats):
        """Forward/backward each batch in *true_batchs* and step the optimizer.

        Args:
            true_batchs: batches for this training cycle
            normalization: denominator recorded in the reported Statistics
            total_stats: running Statistics, updated in place
            report_stats: per-report Statistics, updated in place

        Returns:
            float: sum of the per-batch scalar losses
        """
        total_loss = 0
        for batch in true_batchs:
            # No cross-batch accumulation here: fresh gradients per batch.
            self.model.zero_grad()

            src = batch.src
            labels = batch.src_sent_labels
            segs = batch.segs
            clss = batch.clss
            mask = batch.mask_src
            mask_cls = batch.mask_cls

            sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)

            loss = self.loss(sent_scores, labels.float())
            # Zero out padded positions, then sum to a scalar.
            loss = (loss * mask.float()).sum()
            # NOTE(review): after .sum() the loss is a scalar, so
            # loss.numel() == 1 and this division is a no-op; the commented
            # line below suggests normalization was once intended here.
            (loss / loss.numel()).backward()
            # loss.div(float(normalization)).backward()

            batch_stats = Statistics(float(loss.cpu().data.numpy()),
                                     normalization)

            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

            # 4. Update the parameters and statistics.
            self.optim.step()
            total_loss += (loss / loss.numel()).item()

        return total_loss
Example #5
0
    def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                               report_stats):
        """One gradient-accumulation cycle over *true_batchs*.

        When ``grad_accum_count > 1`` gradients are accumulated across all
        batches and the optimizer steps once at the end; otherwise every
        batch gets its own zero_grad/backward/step.

        Args:
            true_batchs: batches for this accumulation cycle
            normalization: denominator recorded in the reported Statistics
            total_stats: running Statistics, updated in place
            report_stats: per-report Statistics, updated in place
        """
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            if self.grad_accum_count == 1:
                self.model.zero_grad()

            src = batch.src
            labels = batch.src_sent_labels

            segs = batch.segs
            clss = batch.clss
            mask = batch.mask_src

            mask_cls = batch.mask_cls
            src_txt = batch.src_txt

            # This model variant also receives the raw sentence text.
            sent_scores, mask = self.model(src, segs, clss, mask, mask_cls, src_txt)
            loss = self.loss(sent_scores, labels.float())
            # NOTE(review): masking uses batch.mask_cls here rather than the
            # mask returned by the model — confirm that is intentional.
            loss = (loss * mask_cls.float()).sum()
            # loss is a scalar after .sum(), so numel() == 1 (no-op divide).
            (loss / loss.numel()).backward()
            # loss.div(float(normalization)).backward()

            batch_stats = Statistics(float(loss.cpu().data.numpy()), normalization)

            total_stats.update(batch_stats)
            report_stats.update(batch_stats)


            # 4. Update the parameters and statistics.
            if self.grad_accum_count == 1:
                # Multi GPU gradient gather
                if self.n_gpu > 1:
                    grads = [p.grad.data for p in self.model.parameters()
                             if p.requires_grad
                             and p.grad is not None]
                    distributed.all_reduce_and_rescale_tensors(
                        grads, float(1))
                self.optim.step()

        # in case of multi step gradient accumulation,
        # update only after accum batches
        if self.grad_accum_count > 1:
            if self.n_gpu > 1:
                grads = [p.grad.data for p in self.model.parameters()
                         if p.requires_grad
                         and p.grad is not None]
                distributed.all_reduce_and_rescale_tensors(
                    grads, float(1))
            self.optim.step()
Example #6
0
    def _loss_cal(self, p_pred_score, n_pred_score, p_doc_highlight,
                  n_doc_highlight, normalization, is_train):
        """Combine ranking, top-k and highlight losses for a pos/neg pair.

        Args:
            p_pred_score / n_pred_score: predicted scores for the positive
                and negative documents
            p_doc_highlight / n_doc_highlight: per-sentence highlight scores
            normalization: denominator recorded in the reported Statistics
            is_train: when True, backpropagate the combined loss

        Returns:
            tuple: (Statistics, p_pred_score, n_pred_score)
        """
        # highlight distance loss: target y = -1 asks self.loss (a
        # margin-ranking-style criterion, presumably) to separate the
        # positive and negative highlight distributions.
        y = -1 * torch.ones(p_doc_highlight.size(0)).cuda()
        hl_loss = self.loss(p_doc_highlight, n_doc_highlight, y)
        # topk loss: penalize the magnitude of the 30 largest highlights
        # on each side.
        p_topk = torch.topk(p_doc_highlight, k=30, dim=1)[0].sum()
        n_topk = torch.topk(n_doc_highlight, k=30, dim=1)[0].sum()
        topk_loss = (p_topk + n_topk) * 0.5
        # Predict scores difference
        p_loss = p_pred_score.sum()
        n_loss = n_pred_score.sum()

        pred_distance = p_loss - n_loss
        #dis_loss = torch.log(1 + torch.exp(-1 * (pred_distance-2)))
        # Large negative weight pushes the positive score above the negative.
        loss = (-1000) * pred_distance + topk_loss + hl_loss
        if is_train:
            loss.backward()

        batch_stats = Statistics(float(loss.cpu().data.numpy()), \
                float(p_loss.cpu().data.numpy()), \
                float(n_loss.cpu().data.numpy()), \
                float(hl_loss.cpu().data.numpy()), \
                float(topk_loss.cpu().data.numpy()),\
                normalization)

        return batch_stats, p_pred_score, n_pred_score
Example #7
0
    def validate(self, valid_iter, step=0):
        """Run one validation pass over *valid_iter* via ``_main``.

        Args:
            valid_iter: iterator over validation batches
            step: training step used when reporting

        Returns:
            :obj:`nmt.Statistics`: accumulated validation loss statistics
        """
        # Eval mode: deterministic forward passes.
        self.model.eval()
        agg = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                pair_count = batch.p_pair.size(0)
                stats_for_batch, _, _ = self._main(
                    batch, pair_count, is_train=False)
                agg.update(stats_for_batch)
            self._report_step(0, step, valid_stats=agg)
            return agg
    def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                               report_stats):
        """One gradient-accumulation cycle over *true_batchs*.

        Uses a mean (not summed) masked loss and additionally tracks
        top-3 selection accuracy in the Statistics.

        Args:
            true_batchs: batches for this accumulation cycle
            normalization: accumulation normalizer (unused in the stats here)
            total_stats: running Statistics, updated in place
            report_stats: per-report Statistics, updated in place
        """
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            if self.grad_accum_count == 1:
                self.model.zero_grad()

            src = batch.src
            labels = batch.src_sent_labels.float()
            segs = batch.segs
            clss = batch.clss
            mask = batch.mask_src
            mask_cls = batch.mask_cls

            sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)
            loss = self.loss(sent_scores, labels)
            # Mean loss over the valid (unpadded) sentence positions.
            loss = (loss * mask.float()).sum() / mask.float().sum()
            loss.backward()

            # report accuracy: count how many of the top-3 scored sentences
            # (with positive score) carry a positive gold label.
            abs_scores, abs_ids = torch.topk(sent_scores, 3, dim=1)
            abs_mask = (abs_scores > 0).float()
            n_sents = abs_mask.sum().item()
            n_correct = torch.sum(torch.gather(labels, 1, abs_ids) *
                                  abs_mask).item()
            batch_stats = Statistics(loss.item() * batch.batch_size,
                                     batch.batch_size, n_sents, n_correct)

            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

            # 4. Update the parameters and statistics.
            if self.grad_accum_count == 1:
                # Multi GPU gradient gather
                if self.n_gpu > 1:
                    grads = [
                        p.grad.data for p in self.model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    distributed.all_reduce_and_rescale_tensors(grads, float(1))
                self.optim.step()

        # in case of multi step gradient accumulation,
        # update only after accum batches
        if self.grad_accum_count > 1:
            if self.n_gpu > 1:
                grads = [
                    p.grad.data for p in self.model.parameters()
                    if p.requires_grad and p.grad is not None
                ]
                distributed.all_reduce_and_rescale_tensors(grads, float(1))
            self.optim.step()
Example #9
0
    def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                               report_stats):
        """One gradient-accumulation cycle over *true_batchs*.

        Supports two loss modes: a pairwise criterion (``args.pairwise``)
        that takes the mask itself, or the standard masked pointwise loss.

        Args:
            true_batchs: batches for this accumulation cycle
            normalization: denominator recorded in the reported Statistics
            total_stats: running Statistics, updated in place
            report_stats: per-report Statistics, updated in place
        """
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            if self.grad_accum_count == 1:
                self.model.zero_grad()

            src = batch.src
            labels = batch.src_sent_labels
            segs = batch.segs
            clss = batch.clss
            mask = batch.mask_src
            mask_cls = batch.mask_cls

            sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)

            if self.args.pairwise:
                # Pairwise criterion handles masking internally.
                loss = self.loss(sent_scores, labels.float(), mask)
                loss = loss.sum()
            else:
                loss = self.loss(sent_scores, labels.float())
                loss = (loss * mask.float()).sum()
            # loss is a scalar after .sum(), so numel() == 1 (no-op divide).
            (loss / loss.numel()).backward()

            batch_stats = Statistics(float(loss.cpu().data.numpy()),
                                     normalization)

            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

            # 4. Update the parameters and statistics.
            if self.grad_accum_count == 1:
                # Multi GPU gradient gather
                if self.n_gpu > 1:
                    grads = [
                        p.grad.data for p in self.model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    distributed.all_reduce_and_rescale_tensors(grads, float(1))
                self.optim.step()

        # in case of multi step gradient accumulation,
        # update only after accum batches
        if self.grad_accum_count > 1:
            if self.n_gpu > 1:
                grads = [
                    p.grad.data for p in self.model.parameters()
                    if p.requires_grad and p.grad is not None
                ]
                distributed.all_reduce_and_rescale_tensors(grads, float(1))
            self.optim.step()
Example #10
0
 def _maybe_gather_stats(self, stat):
     """
     Gather statistics in multi-processes cases
     Args:
         stat(:obj:onmt.utils.Statistics): a Statistics object to gather
             or None (it returns None in this case)
     Returns:
         stat: the updated (or unchanged) stat object
     """
     if stat is not None and self.n_gpu > 1:
         return Statistics.all_gather_stats(stat)
     return stat
Example #11
0
    def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """Select summary sentences for each batch and write them to disk.

        Candidate summaries go to ``<result_path>_step<step>.candidate`` and
        gold summaries to ``<result_path>_step<step>.gold``; ROUGE is
        computed over the two files when ``args.report_rouge`` is set.

        Args:
            test_iter: iterator over test batches
            step: training step used in output filenames and reporting
            cal_lead: select the leading sentences instead of model scores
            cal_oracle: select the oracle-labelled sentences instead

        Returns:
            :obj:`nmt.Statistics`: accumulated test loss statistics
        """

        # Set model in validating mode.
        def _get_ngrams(n, text):
            # All word n-grams of *text* (a token list) as a set of tuples.
            ngram_set = set()
            text_length = len(text)
            max_index_ngram_start = text_length - n
            for i in range(max_index_ngram_start + 1):
                ngram_set.add(tuple(text[i:i + n]))
            return ngram_set

        def _block_tri(c, p):
            # True if candidate *c* shares any trigram with a sentence in *p*.
            tri_c = _get_ngrams(3, c.split())
            for s in p:
                tri_s = _get_ngrams(3, s.split())
                if len(tri_c.intersection(tri_s)) > 0:
                    return True
            return False

        if (not cal_lead and not cal_oracle):
            self.model.eval()
        stats = Statistics()

        can_path = '%s_step%d.candidate' % (self.args.result_path, step)
        gold_path = '%s_step%d.gold' % (self.args.result_path, step)
        with open(can_path, 'w') as save_pred:
            with open(gold_path, 'w') as save_gold:
                with torch.no_grad():
                    ct = 0
                    for batch in test_iter:
                        src = batch.src
                        labels = batch.src_sent_labels
                        segs = batch.segs
                        clss = batch.clss
                        mask = batch.mask_src
                        mask_cls = batch.mask_cls

                        gold = []
                        pred = []

                        if (cal_lead):
                            # Lead baseline: take sentences in document order.
                            selected_ids = [list(range(batch.clss.size(1)))
                                            ] * batch.batch_size
                        elif (cal_oracle):
                            # Oracle: take exactly the gold-labelled sentences.
                            selected_ids = [[
                                j for j in range(batch.clss.size(1))
                                if labels[i][j] == 1
                            ] for i in range(batch.batch_size)]
                        else:
                            sent_scores, mask = self.model(
                                src, segs, clss, mask, mask_cls)

                            loss = self.loss(sent_scores, labels.float())
                            loss = (loss * mask.float()).sum()
                            batch_stats = Statistics(
                                float(loss.cpu().data.numpy()), len(labels))
                            stats.update(batch_stats)

                            # Adding the mask pushes padded positions below
                            # real ones before ranking.
                            sent_scores = sent_scores + mask.float()
                            sent_scores = sent_scores.cpu().data.numpy()
                            selected_ids = np.argsort(-sent_scores, 1)
                        # selected_ids = np.sort(selected_ids,1)
                        for i, idx in enumerate(selected_ids):
                            _pred = []
                            if (len(batch.src_str[i]) == 0):
                                continue
                            for j in selected_ids[i][:len(batch.src_str[i])]:
                                if (j >= len(batch.src_str[i])):
                                    continue
                                candidate = batch.src_str[i][j].strip()
                                if (self.args.block_trigram):
                                    # Skip candidates sharing a trigram with
                                    # already-selected sentences.
                                    if (not _block_tri(candidate, _pred)):
                                        _pred.append(candidate)
                                else:
                                    _pred.append(candidate)

                                # Standard extractive setting: stop at 3
                                # selected sentences.
                                if ((not cal_oracle)
                                        and (not self.args.recall_eval)
                                        and len(_pred) == 3):
                                    break

                            _pred = '<q>'.join(_pred)
                            if (self.args.recall_eval):
                                # Truncate to the gold summary's word count.
                                _pred = ' '.join(
                                    _pred.split()
                                    [:len(batch.tgt_str[i].split())])

                            pred.append(_pred)
                            gold.append(batch.tgt_str[i])

                        for i in range(len(gold)):
                            save_gold.write(str(ct) + '\n')
                            save_gold.write(gold[i].strip() + '\n')
                            save_pred.write(str(ct) + '\n')
                            save_pred.write(pred[i].strip() + '\n')
                            ct += 1
        if (step != -1 and self.args.report_rouge):
            rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
            logger.info('Rouges at step %d \n%s' %
                        (step, rouge_results_to_str(rouges)))
        self._report_step(0, step, valid_stats=stats)

        return stats
Example #12
0
    def train(self,
              train_iter_fct,
              train_steps,
              valid_iter_fct=None,
              valid_steps=-1):
        """
        The main training loop: iterates over training data (from
        `train_iter_fct`), accumulating `grad_accum_count` batches per
        optimizer step, checkpointing and reporting along the way.

        Args:
            train_iter_fct(function): a function that returns the train
                iterator. e.g. something like
                train_iter_fct = lambda: generator(*args, **kwargs)
            valid_iter_fct(function): same as train_iter_fct, for valid data
                (unused in this implementation)
            train_steps(int): total number of optimizer steps to run
            valid_steps(int): unused in this implementation

        Return:
            :obj:`nmt.Statistics`: total training statistics
        """
        logger.info('Start training...')

        # Resume step counting from the optimizer's saved state.
        step = self.optim._step + 1
        true_batchs = []
        accum = 0
        normalization = 0
        train_iter = train_iter_fct()

        total_stats = Statistics()
        report_stats = Statistics()
        self._start_report_manager(start_time=total_stats.start_time)

        while step <= train_steps:

            reduce_counter = 0
            for i, batch in enumerate(train_iter):
                # In multi-GPU runs each rank only processes its own slice.
                if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):

                    true_batchs.append(batch)
                    normalization += batch.batch_size
                    accum += 1
                    if accum == self.grad_accum_count:
                        reduce_counter += 1
                        if self.n_gpu > 1:
                            # Sum the normalizer across all ranks.
                            normalization = sum(
                                distributed.all_gather_list(normalization))

                        self._gradient_accumulation(true_batchs, normalization,
                                                    total_stats, report_stats)

                        report_stats = self._maybe_report_training(
                            step, train_steps, self.optim.learning_rate,
                            report_stats)

                        true_batchs = []
                        accum = 0
                        normalization = 0
                        # Only rank 0 writes checkpoints.
                        if (step % self.save_checkpoint_steps == 0
                                and self.gpu_rank == 0):
                            self._save(step)

                        step += 1
                        if step > train_steps:
                            break
            # Exhausted the iterator: start a new epoch.
            train_iter = train_iter_fct()

        return total_stats
Example #13
0
    def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                               report_stats):
        """One gradient-accumulation cycle over *true_batchs*.

        Supports two loss modes: a pairwise criterion (``args.pairwise``)
        that takes the mask itself, or the standard masked pointwise loss.

        Args:
            true_batchs: batches for this accumulation cycle
            normalization: denominator recorded in the reported Statistics
            total_stats: running Statistics, updated in place
            report_stats: per-report Statistics, updated in place
        """
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            if self.grad_accum_count == 1:
                self.model.zero_grad()

            src = batch.src
            labels = batch.src_sent_labels
            segs = batch.segs
            clss = batch.clss
            mask = batch.mask_src
            mask_cls = batch.mask_cls

            sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)



            if self.args.pairwise:
                # Pairwise criterion handles masking internally.
                loss = self.loss(sent_scores, labels.float(), mask)
                loss = loss.sum()
            else:
                loss = self.loss(sent_scores, labels.float())
                loss = (loss * mask.float()).sum()
            # Averaging over elements (numel == number of elements).
            # NOTE(review): loss is a scalar after .sum(), so numel() == 1
            # and this division is a no-op.
            (loss / loss.numel()).backward()
            # loss.div(float(normalization)).backward()

            batch_stats = Statistics(float(loss.cpu().data.numpy()), normalization)

            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

            
            # 4. Update the parameters and statistics.
            if self.grad_accum_count == 1:
                # Multi GPU gradient gather
                if self.n_gpu > 1:
                    grads = [p.grad.data for p in self.model.parameters()
                             if p.requires_grad
                             and p.grad is not None]
                    distributed.all_reduce_and_rescale_tensors(
                        grads, float(1))
                self.optim.step()

        # in case of multi step gradient accumulation,
        # update only after accum batches
        if self.grad_accum_count > 1:
            if self.n_gpu > 1:
                grads = [p.grad.data for p in self.model.parameters()
                         if p.requires_grad
                         and p.grad is not None]
                distributed.all_reduce_and_rescale_tensors(
                    grads, float(1))
            self.optim.step()
Example #14
0
    def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                               report_stats):
        """One gradient-accumulation cycle over *true_batchs*.

        Two loss modes:
          * ``args.ext_sum_dec``: the model output is treated as a
            distribution over sentence positions; a position-weighted NLL
            loss is taken against the (sorted) top-3 label indices.
          * otherwise: the standard masked extractive loss.

        Both modes also compute a top-3 selection accuracy for reporting.

        Args:
            true_batchs: batches for this accumulation cycle
            normalization: denominator recorded in the reported Statistics
            total_stats: running Statistics, updated in place
            report_stats: per-report Statistics, updated in place
        """
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            if self.grad_accum_count == 1:
                self.model.zero_grad()

            src = batch.src
            labels = batch.src_sent_labels
            segs = batch.segs
            clss = batch.clss
            mask = batch.mask_src
            mask_cls = batch.mask_cls

            if self.args.ext_sum_dec:
                sent_scores, mask = self.model(src, segs, clss, mask, mask_cls, labels)
                tgt_len = 3
                _, labels_id = torch.topk(labels, k=tgt_len)  # B, tgt_len
                labels_id, _ = torch.sort(labels_id)
                # Position-dependent class weights ramping from 1 up to
                # args.weight_up across args.max_src_nsents positions.
                weight = torch.linspace(start=1, end=self.args.weight_up, steps=self.args.max_src_nsents).type_as(
                    sent_scores)
                weight = weight[:sent_scores.size(-1)]
                loss = F.nll_loss(
                    F.log_softmax(
                        sent_scores.view(-1, sent_scores.size(-1)),
                        dim=-1,
                        dtype=torch.float32,
                    ),
                    labels_id.view(-1),  # bsz * sent
                    weight=weight,
                    reduction='sum',
                    ignore_index=-1,
                )
                prediction = torch.argmax(sent_scores, dim=-1)
                if (self.optim._step + 1) % self.args.print_every == 0:
                    logger.info(
                        'train prediction: %s |label %s ' % (str(prediction), str(labels_id)))
                # BUGFIX: torch.equal returns a Python bool (no .float());
                # element-wise comparison is what was intended here.
                accuracy = torch.div(
                    torch.sum((prediction == labels_id).float()), tgt_len)
            else:
                sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)
                loss = self.loss(sent_scores, labels.float())
                loss = (loss * mask.float()).sum()
                tgt_len = 3
                _, labels_id = torch.topk(labels, k=tgt_len)  # B, tgt_len
                labels_id, _ = torch.sort(labels_id)
                _, prediction = torch.topk(sent_scores, k=tgt_len)
                # BUGFIX: was torch.sort(labels_id), which overwrote the
                # prediction with the gold labels and made accuracy trivially
                # perfect.
                prediction, _ = torch.sort(prediction)
                if (self.optim._step + 1) % self.args.print_every == 0:
                    logger.info(
                        'train prediction: %s |label %s ' % (str(prediction), str(labels_id)))
                # BUGFIX: element-wise comparison instead of torch.equal
                # (which returns a Python bool with no .float()).
                accuracy = torch.div(
                    torch.sum((prediction == labels_id).float()), tgt_len)
            # loss is a scalar after reduction, so numel() == 1 (no-op divide
            # kept for parity with sibling trainers).
            (loss / loss.numel()).backward()
            # with amp.scale_loss((loss / loss.numel()), self.optim.optimizer) as scaled_loss:
            #     scaled_loss.backward()
            # loss.div(float(normalization)).backward()
            if self.args.acc_reporter:
                batch_stats = acc_reporter(float(loss.cpu().data.numpy()), accuracy, normalization)
            else:
                batch_stats = Statistics(float(loss.cpu().data.numpy()), normalization)

            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

            # 4. Update the parameters and statistics.
            if self.grad_accum_count == 1:
                # Multi GPU gradient gather
                if self.n_gpu > 1:
                    grads = [p.grad.data for p in self.model.parameters()
                             if p.requires_grad
                             and p.grad is not None]
                    distributed.all_reduce_and_rescale_tensors(
                        grads, float(1))
                self.optim.step()

        # in case of multi step gradient accumulation,
        # update only after accum batches
        if self.grad_accum_count > 1:
            if self.n_gpu > 1:
                grads = [p.grad.data for p in self.model.parameters()
                         if p.requires_grad
                         and p.grad is not None]
                distributed.all_reduce_and_rescale_tensors(
                    grads, float(1))
            self.optim.step()
Example #15
0
    def test1(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """Select summary sentences for each batch and return them as text.

        Unlike :meth:`test`, nothing is written to disk; the selected
        sentences are joined into a single string.

        Args:
            test_iter: iterator over test batches
            step: training step (unused here; kept for interface parity)
            cal_lead: select the leading sentences instead of model scores
            cal_oracle: select the oracle-labelled sentences instead

        Returns:
            str: the predicted summary sentences joined by spaces
        """

        def _get_ngrams(n, text):
            # All word n-grams of *text* (a token list) as a set of tuples.
            ngram_set = set()
            text_length = len(text)
            max_index_ngram_start = text_length - n
            for i in range(max_index_ngram_start + 1):
                ngram_set.add(tuple(text[i:i + n]))
            return ngram_set

        def _block_tri(c, p):
            # True if candidate *c* shares any trigram with a sentence in *p*.
            tri_c = _get_ngrams(3, c.split())
            for s in p:
                tri_s = _get_ngrams(3, s.split())
                if len(tri_c.intersection(tri_s)) > 0:
                    return True
            return False

        if (not cal_lead and not cal_oracle):
            self.model.eval()
        stats = Statistics()

        with torch.no_grad():
            for batch in test_iter:
                src = batch.src
                segs = batch.segs
                clss = batch.clss
                mask = batch.mask_src
                mask_cls = batch.mask_cls
                # NOTE(review): gold/pred are reset per batch, so only the
                # final batch's predictions are returned — confirm intent.
                gold = []
                pred = []
                if (cal_lead):
                    selected_ids = [list(range(batch.clss.size(1)))
                                    ] * batch.batch_size
                elif (cal_oracle):
                    labels = batch.src_sent_labels
                    selected_ids = [[
                        j for j in range(batch.clss.size(1))
                        if labels[i][j] == 1
                    ] for i in range(batch.batch_size)]
                else:
                    sent_scores, mask = self.model(src, segs, clss, mask,
                                                   mask_cls)

                    # BUGFIX: compute the loss while sent_scores is still a
                    # tensor (and before adding the mask). The original
                    # converted to numpy first, breaking self.loss and the
                    # mask multiplication.
                    if (hasattr(batch, 'src_sent_labels')):
                        labels = batch.src_sent_labels
                        loss = self.loss(sent_scores, labels.float())
                        loss = (loss * mask.float()).sum()
                        batch_stats = Statistics(
                            float(loss.cpu().data.numpy()), len(labels))
                        stats.update(batch_stats)

                    # Adding the mask pushes padded positions below real
                    # ones before ranking.
                    sent_scores = sent_scores + mask.float()
                    sent_scores = sent_scores.cpu().data.numpy()
                    selected_ids = np.argsort(-sent_scores, 1)

                for i, idx in enumerate(selected_ids):
                    _pred = []
                    if (len(batch.src_str[i]) == 0):
                        continue
                    for j in selected_ids[i][:len(batch.src_str[i])]:
                        if (j >= len(batch.src_str[i])):
                            continue
                        candidate = batch.src_str[i][j].strip()
                        if (self.args.block_trigram):
                            if (not _block_tri(candidate, _pred)):
                                _pred.append(candidate)
                        else:
                            _pred.append(candidate)

                        if ((not cal_oracle) and (not self.args.recall_eval)
                                and len(_pred) == 3):
                            break

                    _pred = ' '.join(_pred)
                    if (self.args.recall_eval):
                        _pred = ' '.join(
                            _pred.split()[:len(batch.tgt_str[i].split())])

                    pred.append(_pred)
                    gold.append(batch.tgt_str[i])

        return ' '.join(pred)
Пример #16
0
    def validate(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        if self.args.acc_reporter:
            stats = acc_reporter.Statistics()
        else:
            stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                if self.args.jigsaw == 'jigsaw_lab':
                    # Per-sentence position classification;
                    # logits: (bsz, n_sent, max_sent_num).
                    logits = self.model(batch.src_s, batch.segs_s, batch.clss_s, batch.mask_src_s, batch.mask_cls_s)
                    loss = F.nll_loss(
                        F.log_softmax(
                            logits.view(-1, logits.size(-1)),
                            dim=-1,
                            dtype=torch.float32,
                        ),
                        batch.poss_s.view(-1),  # bsz * sent
                        reduction='sum',
                        ignore_index=-1,  # padded positions carry label -1
                    )
                    prediction = torch.argmax(logits, dim=-1)
                    if (self.optim._step + 1) % self.args.print_every == 0:
                        logger.info(
                            'train prediction: %s |label %s ' % (str(prediction), str(batch.poss_s)))
                    # torch.eq gives an element-wise match mask; torch.equal
                    # returns a single bool and made the masked sum wrong.
                    accuracy = torch.div(torch.sum(torch.eq(prediction, batch.poss_s) * batch.mask_cls_s),
                                         torch.sum(batch.mask_cls_s)) * len(logits)
                elif self.args.jigsaw == 'jigsaw_dec':  # jigsaw decoder
                    poss_s = batch.poss_s
                    mask_poss = torch.eq(poss_s, -1)
                    # Push padded slots to the end of the sort order:
                    # poss_s[i] [5,1,4,0,2,3,-1,-1] -> [5,1,4,0,2,3,1e4,1e4]
                    poss_s.masked_fill_(mask_poss, 1e4)
                    dec_labels = torch.argsort(poss_s, dim=1)  # dec_labels[i] [3,1,...,6,7]
                    logits = self.model(batch.src_s, batch.segs_s, batch.clss_s, batch.mask_src_s, batch.mask_cls_s,
                                        dec_labels)
                    # Re-mark padded slots with -1 so nll_loss ignores them.
                    final_dec_labels = dec_labels.masked_fill(mask_poss, -1)  # final_dec_labels[i] [3,1,...,-1,-1]
                    loss = F.nll_loss(
                        F.log_softmax(
                            logits.view(-1, logits.size(-1)),
                            dim=-1,
                            dtype=torch.float32,
                        ),
                        final_dec_labels.view(-1),  # bsz * sent
                        reduction='sum',
                        ignore_index=-1,
                    )
                    prediction = torch.argmax(logits, dim=-1)
                    if (self.optim._step + 1) % self.args.print_every == 0:
                        logger.info(
                            'train prediction: %s |label %s ' % (str(prediction), str(batch.poss_s)))
                    # Compare against the local final_dec_labels (it is not a
                    # batch attribute) with an element-wise torch.eq.
                    accuracy = torch.div(torch.sum(torch.eq(prediction, final_dec_labels) * batch.mask_cls_s),
                                         torch.sum(batch.mask_cls_s)) * len(logits)

                if self.args.acc_reporter:
                    batch_stats = acc_reporter.Statistics(float(loss.cpu().data.numpy()), accuracy, len(batch.poss_s))
                else:
                    batch_stats = Statistics(float(loss.cpu().data.numpy()), len(batch.poss_s))
                stats.update(batch_stats)

            self._report_step(0, step, valid_stats=stats)
            return stats
Пример #17
0
    def validate(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        if self.args.acc_reporter == 1:
            stats = acc_reporter.Statistics()
        else:
            stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                src = batch.src
                labels = batch.src_sent_labels
                segs = batch.segs
                clss = batch.clss
                mask = batch.mask_src
                mask_cls = batch.mask_cls

                if self.args.ext_sum_dec:
                    # Decoder predicts sentence indices directly.
                    sent_scores, mask = self.model(src, segs, clss, mask, mask_cls, labels)  # B, tgt_len custom_num
                    tgt_len = 3
                    _, labels_id = torch.topk(labels, k=tgt_len)  # B, tgt_len
                    labels_id, _ = torch.sort(labels_id)
                    # Linearly increasing class weights up to weight_up over
                    # max_src_nsents positions, truncated to the score width.
                    weight = torch.linspace(start=1, end=self.args.weight_up, steps=self.args.max_src_nsents).type_as(sent_scores)
                    weight = weight[:sent_scores.size(-1)]
                    loss = F.nll_loss(
                        F.log_softmax(
                            sent_scores.view(-1, sent_scores.size(-1)),
                            dim=-1,
                            dtype=torch.float32,
                        ),
                        labels_id.view(-1),  # bsz * tgt_len
                        weight=weight,
                        reduction='sum',
                        ignore_index=-1,
                    )
                    prediction = torch.argmax(sent_scores, dim=-1)
                    if (self.optim._step + 1) % self.args.print_every == 0:
                        logger.info(
                            'train prediction: %s |label %s ' % (str(prediction), str(labels_id)))
                    # Element-wise torch.eq; torch.equal returns a bare bool
                    # which has no .float() and would crash here.
                    accuracy = torch.div(torch.sum(torch.eq(prediction, labels_id).float()), tgt_len)
                else:
                    sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)  # B, custom_N
                    loss = self.loss(sent_scores, labels.float())
                    loss = (loss * mask.float()).sum()
                    tgt_len = 3
                    _, labels_id = torch.topk(labels, k=tgt_len)  # B, tgt_len
                    labels_id, _ = torch.sort(labels_id)
                    _, prediction = torch.topk(sent_scores, k=tgt_len)
                    # Sort the predictions (the original re-sorted labels_id
                    # into `prediction`, making accuracy trivially perfect).
                    prediction, _ = torch.sort(prediction)
                    if (self.optim._step + 1) % self.args.print_every == 0:
                        logger.info(
                            'train prediction: %s |label %s ' % (str(prediction), str(labels_id)))
                    accuracy = torch.div(torch.sum(torch.eq(prediction, labels_id).float()), tgt_len)
                if self.args.acc_reporter == 1:
                    # Use the accuracy-aware Statistics, matching the stats
                    # object created above (plain Statistics takes no accuracy).
                    batch_stats = acc_reporter.Statistics(float(loss.cpu().data.numpy()), accuracy, len(labels))
                else:
                    batch_stats = Statistics(float(loss.cpu().data.numpy()), len(labels))
                stats.update(batch_stats)
            self._report_step(0, step, valid_stats=stats)
            return stats
Пример #18
0
    def train(self,
              train_iter_fct,
              train_steps,
              valid_iter_fct=None,
              valid_steps=-1):
        """Run the main training loop.

        Performs one gradient-accumulation step per batch until
        ``train_steps`` steps are reached, making at most a single pass over
        the iterator (the epoch loop breaks unconditionally after one pass).

        Args:
            train_iter_fct(function): zero-arg factory returning the train
                iterator, e.g. ``lambda: generator(*args, **kwargs)``.
            train_steps(int): total number of optimization steps to run.
            valid_iter_fct(function): same as train_iter_fct, for valid data
                (unused here, kept for interface parity).
            valid_steps(int): validation interval (unused here).

        Return:
            :obj:`Statistics`: accumulated training statistics.
        """
        logger.info('Start training...')
        print('Start training...')
        if self.model:
            self.model.train()

        step = self.optim._step + 1
        true_batchs = []
        normalization = 0
        train_iter = train_iter_fct()

        total_stats = Statistics()
        report_stats = Statistics()
        self._start_report_manager(start_time=total_stats.start_time)

        while step <= train_steps:
            reduce_counter = 0
            epoch_loss = 0
            for i, batch in enumerate(train_iter):
                reduce_counter += 1
                true_batchs.append(batch)
                normalization += batch.batch_size

                # One accumulation/optimizer step per batch.
                batch_loss = self._gradient_accumulation(
                    true_batchs, normalization, total_stats, report_stats)
                report_stats = self._maybe_report_training(
                    step, train_steps, self.optim.learning_rate, report_stats)
                true_batchs, normalization = [], 0

                if step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0:
                    self._save(step)
                print(
                    f"batch {i} of epoch {step} is over with loss: {batch_loss}"
                )
                epoch_loss += batch_loss
                step += 1
                if step > train_steps:
                    break
            train_iter = train_iter_fct()
            print(f"Epoch {step-1} with loss: {epoch_loss}")
            break
        return total_stats
Пример #19
0
    def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                               report_stats):
        """Forward/backward over ``true_batchs`` and step the optimizer.

        Computes the jigsaw loss per batch ('jigsaw_lab' position
        classification, otherwise the 'jigsaw_dec' decoder variant),
        backpropagates through AMP loss scaling, and updates
        ``total_stats``/``report_stats`` in place.
        """
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            if self.grad_accum_count == 1:
                self.model.zero_grad()

            if self.args.jigsaw == 'jigsaw_lab':
                # Per-sentence position classification;
                # logits: (bsz, n_sent, max_sent_num).
                logits = self.model(batch.src_s, batch.segs_s, batch.clss_s, batch.mask_src_s, batch.mask_cls_s)
                loss = F.nll_loss(
                    F.log_softmax(
                        logits.view(-1, logits.size(-1)),
                        dim=-1,
                        dtype=torch.float32,
                    ),
                    batch.poss_s.view(-1),  # bsz * sent
                    reduction='sum',
                    ignore_index=-1,  # padded positions carry label -1
                )
                prediction = torch.argmax(logits, dim=-1)
                if (self.optim._step + 1) % self.args.print_every == 0:
                    logger.info(
                        'train prediction: %s |label %s ' % (str(prediction), str(batch.poss_s)))
                # Element-wise torch.eq; torch.equal returns a single bool
                # and made the masked sum meaningless.
                accuracy = torch.div(torch.sum(torch.eq(prediction, batch.poss_s) * batch.mask_cls_s),
                                     torch.sum(batch.mask_cls_s)) * len(logits)
            else:  # self.args.jigsaw == 'jigsaw_dec': jigsaw decoder
                poss_s = batch.poss_s
                mask_poss = torch.eq(poss_s, -1)
                # Push padded slots to the end of the sort order:
                # poss_s[i] [5,1,4,0,2,3,-1,-1] -> [5,1,4,0,2,3,1e4,1e4]
                poss_s.masked_fill_(mask_poss, 1e4)
                dec_labels = torch.argsort(poss_s, dim=1)  # dec_labels[i] [3,1,...,6,7]
                logits, _ = self.model(batch.src_s, batch.segs_s, batch.clss_s, batch.mask_src_s, batch.mask_cls_s, dec_labels)
                # Re-mark padded slots with -1 so nll_loss ignores them.
                final_dec_labels = dec_labels.masked_fill(mask_poss, -1)
                loss = F.nll_loss(
                    F.log_softmax(
                        logits.view(-1, logits.size(-1)),
                        dim=-1,
                        dtype=torch.float32,
                    ),
                    final_dec_labels.view(-1),  # bsz * sent
                    reduction='sum',
                    ignore_index=-1,
                )
                prediction = torch.argmax(logits, dim=-1)
                if (self.optim._step + 1) % self.args.print_every == 0:
                    logger.info(
                        'train prediction: %s |label %s ' % (str(prediction), str(batch.poss_s)))
                # Compare against the decoder targets; batch.poss_s was
                # clobbered by masked_fill_ above and is the wrong label space.
                accuracy = torch.div(torch.sum(torch.eq(prediction, final_dec_labels) * batch.mask_cls_s),
                                     torch.sum(batch.mask_cls_s)) * len(logits)

            # AMP loss scaling; loss is 0-dim so numel() == 1.
            with amp.scale_loss((loss / loss.numel()), self.optim.optimizer) as scaled_loss:
                scaled_loss.backward()

            if self.args.acc_reporter:
                batch_stats = acc_reporter.Statistics(float(loss.cpu().data.numpy()), accuracy, normalization)
            else:
                batch_stats = Statistics(float(loss.cpu().data.numpy()), normalization)

            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

            # 4. Update the parameters and statistics.
            if self.grad_accum_count == 1:
                # Multi GPU gradient gather
                if self.n_gpu > 1:
                    grads = [p.grad.data for p in self.model.parameters()
                             if p.requires_grad
                             and p.grad is not None]
                    distributed.all_reduce_and_rescale_tensors(
                        grads, float(1))
                self.optim.step()

        # in case of multi step gradient accumulation,
        # update only after accum batches
        if self.grad_accum_count > 1:
            if self.n_gpu > 1:
                grads = [p.grad.data for p in self.model.parameters()
                         if p.requires_grad
                         and p.grad is not None]
                distributed.all_reduce_and_rescale_tensors(
                    grads, float(1))
            self.optim.step()
Пример #20
0
    def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1):
        """
        The main training loops.
        by iterating over training data (i.e. `train_iter_fct`)
        and running validation (i.e. iterating over `valid_iter_fct`
        Args:
            train_iter_fct(function): a function that returns the train
                iterator. e.g. something like
                train_iter_fct = lambda: generator(*args, **kwargs)
            valid_iter_fct(function): same as train_iter_fct, for valid data
            train_steps(int): total number of optimization steps to run
            valid_steps(int): unused; validation cadence comes from
                self.args.val_interval and self.save_checkpoint_steps
            save_checkpoint_steps(int):
        Return:
            :obj:`Statistics`: accumulated training statistics
        """
        logger.info('Start training...')

        # Resume step counting from the (possibly restored) optimizer state.
        step = self.optim._step + 1
        true_batchs = []
        accum = 0
        normalization = 0
        sent_num_normalization = 0  # NOTE(review): assigned but never used below
        train_iter = train_iter_fct()

        total_stats = Statistics()
        # Global validation stats also own the on-disk stats file.
        valid_global_stats = Statistics(stat_file_dir=self.args.model_path)
        valid_global_stats.write_stat_header(self.is_joint)
        report_stats = Statistics(print_traj=self.is_joint)
        self._start_report_manager(start_time=total_stats.start_time)

        while step <= train_steps:
            reduce_counter = 0
            for i, batch in enumerate(train_iter):
                # Round-robin batches across GPU ranks by batch index.
                if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):
                    true_batchs.append(batch)
                    normalization += batch.batch_size
                    accum += 1
                    if accum == self.grad_accum_count:
                        reduce_counter += 1
                        if self.n_gpu > 1:
                            # Sum token/batch normalization across workers.
                            normalization = sum(distributed
                                                .all_gather_list
                                                (normalization))

                        self._gradient_accumulation(
                            true_batchs, normalization, total_stats,
                            report_stats)

                        # When jointly trained, also report the two learned
                        # uncertainty-weighting sigmas; otherwise pass 0.
                        report_stats = self._maybe_report_training(
                            step, train_steps,
                            self.optim.learning_rate,
                            self.model.uncertainty_loss._sigmas_sq[0].item() if self.is_joint else 0,
                            self.model.uncertainty_loss._sigmas_sq[1].item() if self.is_joint else 0,
                            report_stats)

                        self._report_step(self.optim.learning_rate, step,
                                          self.model.uncertainty_loss._sigmas_sq[0] if self.is_joint else 0,
                                          self.model.uncertainty_loss._sigmas_sq[1] if self.is_joint else 0,
                                          train_stats=report_stats)

                        true_batchs = []
                        accum = 0
                        normalization = 0

                        # Checkpoint cadence: validate, then save with the
                        # best/recall flags returned by validation.
                        if (step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0):
                            val_stat, best_model_save, best_recall_model_save = self.validate_rouge_baseline(
                                valid_iter_fct, step,
                                valid_gl_stats=valid_global_stats)
                            self._save(step, best=best_model_save, recall_model=best_recall_model_save, valstat=val_stat)

                        # Separate validation cadence (plus a warm-up check at
                        # step 5); may run in addition to the checkpoint branch
                        # above when the cadences coincide.
                        if step == 5 or step % self.args.val_interval == 0:  # Validation
                        # if step % self.args.val_interval == 0:  # Validation
                            logger.info('----------------------------------------')
                            logger.info('Start evaluating on evaluation set... ')
                            self.args.pick_top = True

                            val_stat, best_model_save, best_recall_model_save = self.validate_rouge_baseline(valid_iter_fct, step,
                                                                                     valid_gl_stats=valid_global_stats)


                            if best_model_save:
                                self._save(step, best=True, valstat=val_stat)
                                # NOTE(review): redundant f-prefix (the %-format
                                # does the work) and 'sucessfully' typo — left
                                # unchanged in this doc-only pass.
                                logger.info(f'Best model saved sucessfully at step %d' % step)
                                self.best_val_step = step

                            if best_recall_model_save:
                                self._save(step, best=True, valstat=val_stat, recall_model=True)
                                logger.info(f'Best model saved sucessfully at step %d' % step)
                                self.best_val_step = step

                            self.save_validation_results(step, val_stat)

                            logger.info('----------------------------------------')
                            # Validation switched the model to eval mode;
                            # restore training mode before continuing.
                            self.model.train()

                        step += 1
                        if step > train_steps:
                            break
            # Refresh the iterator for the next pass over the data.
            train_iter = train_iter_fct()

        return total_stats
Пример #21
0
    def validate_rouge_baseline(self, valid_iter_fct, step=0, valid_gl_stats=None, write_scores_to_pickle=False):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """

        preds = {}
        preds_with_idx = {}
        golds = {}
        can_path = '%s_step%d.source' % (self.args.result_path, step)
        gold_path = '%s_step%d.target' % (self.args.result_path, step)

        if step == self.best_val_step:
            can_path = '%s_step%d.source' % (self.args.result_path_test, step)
            gold_path = '%s_step%d.target' % (self.args.result_path_test, step)

        save_pred = open(can_path, 'w')
        save_gold = open(gold_path, 'w')
        sent_scores_whole = {}
        sent_sects_whole_pred = {}
        sent_sects_whole_true = {}
        sent_labels_true = {}
        sent_numbers_whole = {}
        paper_srcs = {}
        paper_tgts = {}
        sent_sect_wise_rg_whole = {}
        sent_sections_txt_whole = {}
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()
        best_model_saved = False
        best_recall_model_saved = False

        valid_iter = valid_iter_fct()

        with torch.no_grad():
            for batch in tqdm(valid_iter):
                src = batch.src
                labels = batch.src_sent_labels
                sent_labels = batch.sent_labels

                if self.rg_predictor:
                    sent_true_rg = batch.src_sent_labels
                else:
                    sent_labels = batch.sent_labels
                segs = batch.segs
                clss = batch.clss
                mask = batch.mask_src
                mask_cls = batch.mask_cls
                p_id = batch.paper_id
                segment_src = batch.src_str
                paper_tgt = batch.tgt_str
                sent_sect_wise_rg = batch.sent_sect_wise_rg
                sent_sections_txt = batch.sent_sections_txt
                sent_numbers = batch.sent_numbers

                sent_sect_labels = batch.sent_sect_labels
                if self.is_joint:
                    if not self.rg_predictor:
                        sent_scores, sent_sect_scores, mask, loss, loss_sent, loss_sect = self.model(src, segs, clss,
                                                                                                     mask, mask_cls,
                                                                                                     sent_labels,
                                                                                                     sent_sect_labels)
                    else:
                        sent_scores, sent_sect_scores, mask, loss, loss_sent, loss_sect = self.model(src, segs, clss,
                                                                                                     mask, mask_cls,
                                                                                                     sent_true_rg,
                                                                                                     sent_sect_labels)
                    acc, _ = self._get_mertrics(sent_sect_scores, sent_sect_labels, mask=mask,
                                                task='sent_sect')

                    batch_stats = Statistics(loss=float(loss.cpu().data.numpy().sum()),
                                             loss_sect=float(loss_sect.cpu().data.numpy().sum()),
                                             loss_sent=float(loss_sent.cpu().data.numpy().sum()),
                                             n_docs=len(labels),
                                             n_acc=batch.batch_size,
                                             RMSE=self._get_mertrics(sent_scores, labels, mask=mask, task='sent'),
                                             accuracy=acc)

                else:
                    if not self.rg_predictor:
                        sent_scores, mask, loss, _, _ = self.model(src, segs, clss, mask, mask_cls, sent_labels,
                                                                   sent_sect_labels=None, is_inference=True)
                    else:
                        sent_scores, mask, loss, _, _ = self.model(src, segs, clss, mask, mask_cls, sent_true_rg,
                                                                   sent_sect_labels=None, is_inference=True)

                    # sent_scores = (section_rg.unsqueeze(1).expand_as(sent_scores).to(device='cuda')*100) * sent_scores

                    batch_stats = Statistics(loss=float(loss.cpu().data.numpy().sum()),
                                             RMSE=self._get_mertrics(sent_scores, labels, mask=mask, task='sent'),
                                             n_acc=batch.batch_size,
                                             n_docs=len(labels))

                stats.update(batch_stats)

                sent_scores = sent_scores + mask.float()
                sent_scores = sent_scores.cpu().data.numpy()

                for idx, p_id in enumerate(p_id):
                    p_id = p_id.split('___')[0]

                    if p_id not in sent_scores_whole.keys():
                        masked_scores = sent_scores[idx] * mask[idx].cpu().data.numpy()
                        masked_scores = masked_scores[np.nonzero(masked_scores)]

                        masked_sent_labels_true = (sent_labels[idx] + 1) * mask[idx].long()

                        masked_sent_labels_true = masked_sent_labels_true[np.nonzero(masked_sent_labels_true)].flatten()
                        masked_sent_labels_true = (masked_sent_labels_true - 1)

                        sent_scores_whole[p_id] = masked_scores
                        sent_labels_true[p_id] = masked_sent_labels_true.cpu()

                        masked_sents_sections_true = (sent_sect_labels[idx] + 1) * mask[idx].long()

                        masked_sents_sections_true = masked_sents_sections_true[
                            np.nonzero(masked_sents_sections_true)].flatten()
                        masked_sents_sections_true = (masked_sents_sections_true - 1)
                        sent_sects_whole_true[p_id] = masked_sents_sections_true.cpu()

                        if self.is_joint:
                            masked_scores_sects = sent_sect_scores[idx] * mask[idx].view(-1, 1).expand_as(
                                sent_sect_scores[idx]).float()
                            masked_scores_sects = masked_scores_sects[torch.abs(masked_scores_sects).sum(dim=1) != 0]
                            sent_sects_whole_pred[p_id] = torch.max(self.softmax(masked_scores_sects), 1)[1].cpu()

                        paper_srcs[p_id] = segment_src[idx]
                        if sent_numbers[0] is not None:
                            sent_numbers_whole[p_id] = sent_numbers[idx]
                            # sent_tokens_count_whole[p_id] = sent_tokens_count[idx]
                        paper_tgts[p_id] = paper_tgt[idx]
                        sent_sect_wise_rg_whole[p_id] = sent_sect_wise_rg[idx]
                        sent_sections_txt_whole[p_id] = sent_sections_txt[idx]


                    else:
                        masked_scores = sent_scores[idx] * mask[idx].cpu().data.numpy()
                        masked_scores = masked_scores[np.nonzero(masked_scores)]

                        masked_sent_labels_true = (sent_labels[idx] + 1) * mask[idx].long()
                        masked_sent_labels_true = masked_sent_labels_true[np.nonzero(masked_sent_labels_true)].flatten()
                        masked_sent_labels_true = (masked_sent_labels_true - 1)

                        sent_scores_whole[p_id] = np.concatenate((sent_scores_whole[p_id], masked_scores), 0)
                        sent_labels_true[p_id] = np.concatenate((sent_labels_true[p_id], masked_sent_labels_true.cpu()),
                                                                0)

                        masked_sents_sections_true = (sent_sect_labels[idx] + 1) * mask[idx].long()
                        masked_sents_sections_true = masked_sents_sections_true[
                            np.nonzero(masked_sents_sections_true)].flatten()
                        masked_sents_sections_true = (masked_sents_sections_true - 1)
                        sent_sects_whole_true[p_id] = np.concatenate(
                            (sent_sects_whole_true[p_id], masked_sents_sections_true.cpu()), 0)

                        if self.is_joint:
                            masked_scores_sects = sent_sect_scores[idx] * mask[idx].view(-1, 1).expand_as(
                                sent_sect_scores[idx]).float()
                            masked_scores_sects = masked_scores_sects[
                                torch.abs(masked_scores_sects).sum(dim=1) != 0]
                            sent_sects_whole_pred[p_id] = np.concatenate(
                                (sent_sects_whole_pred[p_id], torch.max(self.softmax(masked_scores_sects), 1)[1].cpu()),
                                0)

                        paper_srcs[p_id] = np.concatenate((paper_srcs[p_id], segment_src[idx]), 0)
                        if sent_numbers[0] is not None:
                            sent_numbers_whole[p_id] = np.concatenate((sent_numbers_whole[p_id], sent_numbers[idx]), 0)
                            # sent_tokens_count_whole[p_id] = np.concatenate(
                            #     (sent_tokens_count_whole[p_id], sent_tokens_count[idx]), 0)

                        sent_sect_wise_rg_whole[p_id] = np.concatenate(
                            (sent_sect_wise_rg_whole[p_id], sent_sect_wise_rg[idx]), 0)
                        sent_sections_txt_whole[p_id] = np.concatenate(
                            (sent_sections_txt_whole[p_id], sent_sections_txt[idx]), 0)


        PRED_LEN = self.args.val_pred_len
        acum_f_sent_labels = 0
        acum_p_sent_labels = 0
        acum_r_sent_labels = 0
        acc_total = 0
        for p_idx, (p_id, sent_scores) in enumerate(sent_scores_whole.items()):
            # sent_true_labels = pickle.load(open("sent_labels_files/pubmedL/val.labels.p", "rb"))
            # section_textual = np.array(section_textual)
            paper_sent_true_labels = np.array(sent_labels_true[p_id])
            if self.is_joint:
                sent_sects_true = np.array(sent_sects_whole_true[p_id])
                sent_sects_pred = np.array(sent_sects_whole_pred[p_id])

            sent_scores = np.array(sent_scores)
            p_src = np.array(paper_srcs[p_id])

            # selected_ids_unsorted = np.argsort(-sent_scores, 0)
            keep_ids = [idx for idx, s in enumerate(p_src) if
                        len(s.replace('.', '').replace(',', '').replace('(', '').replace(')', '').
                            replace('-', '').replace(':', '').replace(';', '').replace('*', '').split()) > 5 and
                        len(s.replace('.', '').replace(',', '').replace('(', '').replace(')', '').
                            replace('-', '').replace(':', '').replace(';', '').replace('*', '').split()) < 100
                        ]

            keep_ids = sorted(keep_ids)

            # top_sent_indexes = top_sent_indexes[top_sent_indexes]
            p_src = p_src[keep_ids]
            sent_scores = sent_scores[keep_ids]
            paper_sent_true_labels = paper_sent_true_labels[keep_ids]

            sent_scores = np.asarray([s - 1.00 for s in sent_scores])

            selected_ids_unsorted = np.argsort(-sent_scores, 0)

            _pred = []
            for j in selected_ids_unsorted:
                if (j >= len(p_src)):
                    continue
                candidate = p_src[j].strip()
                if True:
                    # if (not _block_tri(candidate, _pred)):
                    _pred.append((candidate, j))

                if (len(_pred) == PRED_LEN):
                    break
            _pred = sorted(_pred, key=lambda x: x[1])
            _pred_final_str = '<q>'.join([x[0] for x in _pred])

            preds[p_id] = _pred_final_str
            golds[p_id] = paper_tgts[p_id]
            preds_with_idx[p_id] = _pred
            if p_idx > 10:
                f, p, r = _get_precision_(paper_sent_true_labels, [p[1] for p in _pred])
                if self.is_joint:
                    acc_whole = _get_accuracy_sections(sent_sects_true, sent_sects_pred, [p[1] for p in _pred])
                    acc_total += acc_whole

            else:
                f, p, r = _get_precision_(paper_sent_true_labels, [p[1] for p in _pred], print_few=True, p_id=p_id)
                if self.is_joint:
                    acc_whole = _get_accuracy_sections(sent_sects_true, sent_sects_pred, [p[1] for p in _pred],
                                                       print_few=True, p_id=p_id)
                    acc_total += acc_whole

            acum_f_sent_labels += f
            acum_p_sent_labels += p
            acum_r_sent_labels += r

        for id, pred in preds.items():
            save_pred.write(pred.strip().replace('<q>', ' ') + '\n')
            save_gold.write(golds[id].replace('<q>', ' ').strip() + '\n')

        print(f'Gold: {gold_path}')
        print(f'Prediction: {can_path}')

        r1, r2, rl = self._report_rouge(preds.values(), golds.values())
        stats.set_rl(r1, r2, rl)
        logger.info("F-score: %4.4f, Prec: %4.4f, Recall: %4.4f" % (
        acum_f_sent_labels / len(sent_scores_whole), acum_p_sent_labels / len(sent_scores_whole),
        acum_r_sent_labels / len(sent_scores_whole)))
        if self.is_joint:
            logger.info("Section Accuracy: %4.4f" % (acc_total / len(sent_scores_whole)))


        stats.set_ir_metrics(acum_f_sent_labels / len(sent_scores_whole),
                             acum_p_sent_labels / len(sent_scores_whole),
                             acum_r_sent_labels / len(sent_scores_whole))
        self.valid_rgls.append((r2 + rl) / 2)
        self._report_step(0, step,
                          self.model.uncertainty_loss._sigmas_sq[0] if self.is_joint else 0,
                          self.model.uncertainty_loss._sigmas_sq[1] if self.is_joint else 0,
                          valid_stats=stats)

        if len(self.valid_rgls) > 0:
            if self.min_rl < self.valid_rgls[-1]:
                self.min_rl = self.valid_rgls[-1]
                best_model_saved = True

        return stats, best_model_saved, best_recall_model_saved
Пример #22
0
    def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                               report_stats):
        """Forward/backward over a group of batches, accumulating gradients.

        Args:
            true_batchs: iterable of batches processed before one optimizer step.
            normalization: document count used as ``n_docs`` in the statistics.
            total_stats: ``Statistics`` accumulating metrics over all training.
            report_stats: ``Statistics`` accumulating metrics for periodic reports.

        Side effects: zeroes model gradients, calls ``loss.backward()`` per
        batch, all-reduces gradients across GPUs, and (for
        ``grad_accum_count > 1``) steps the optimizer once at the end.
        """
        # With multi-batch accumulation, zero the gradients once up front and
        # step the optimizer only after every batch has been processed.
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            if self.grad_accum_count == 1:
                self.model.zero_grad()

            src = batch.src
            sent_rg_scores = batch.src_sent_labels

            sent_sect_labels = batch.sent_sect_labels
            sent_bin_labels = batch.sent_labels
            segs = batch.segs
            clss = batch.clss
            mask = batch.mask_src
            mask_cls = batch.mask_cls

            # The sentence-level target depends on the training mode:
            # regression on ROUGE scores (rg_predictor) vs. binary labels.
            sent_targets = sent_rg_scores if self.rg_predictor else sent_bin_labels

            if self.is_joint:
                # Joint model: sentence scores plus section classification.
                sent_scores, sent_sect_scores, mask, loss, loss_sent, loss_sect = \
                    self.model(src, segs, clss, mask, mask_cls,
                               sent_targets, sent_sect_labels)

                # BUG FIX: `acc` was previously unbound when the metrics call
                # raised (the old bare `except:` only logged and fell through),
                # crashing below with UnboundLocalError. Default it first and
                # narrow the handler to Exception.
                acc = 0
                try:
                    acc, _ = self._get_mertrics(sent_sect_scores, sent_sect_labels,
                                                mask=mask, task='sent_sect')
                except Exception:
                    logger.info("Accuracy cannot be computed due to some errors "
                                "in loading approapriate files...")

                batch_stats = Statistics(loss=float(loss.cpu().data.numpy().sum()),
                                         loss_sect=float(loss_sect.cpu().data.numpy().sum()),
                                         loss_sent=float(loss_sent.cpu().data.numpy().sum()),
                                         n_docs=normalization,
                                         n_acc=batch.batch_size,
                                         RMSE=self._get_mertrics(sent_scores, sent_rg_scores,
                                                                 mask=mask, task='sent'),
                                         accuracy=acc,
                                         a1=self.model.uncertainty_loss._sigmas_sq[0].item(),
                                         a2=self.model.uncertainty_loss._sigmas_sq[1].item()
                                         )

            else:  # simple (single-task) model
                sent_scores, mask, loss, _, _ = self.model(src, segs, clss, mask, mask_cls,
                                                           sent_bin_labels=sent_targets,
                                                           sent_sect_labels=None)

                # a1/a2 were guarded on `self.is_joint`, which is always False
                # in this branch, so they are constant zero here.
                batch_stats = Statistics(loss=float(loss.cpu().data.numpy().sum()),
                                         RMSE=self._get_mertrics(sent_scores, sent_rg_scores,
                                                                 mask=mask,
                                                                 task='sent'),
                                         n_acc=batch.batch_size,
                                         n_docs=normalization,
                                         a1=0,
                                         a2=0)

            loss.backward()
            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

            # 4. Update the parameters and statistics.
            if self.grad_accum_count == 1:
                # Multi GPU gradient gather
                if self.n_gpu > 1:
                    grads = [p.grad.data for p in self.model.parameters()
                             if p.requires_grad
                             and p.grad is not None]
                    distributed.all_reduce_and_rescale_tensors(
                        grads, float(1))
                # NOTE(review): the optimizer step is commented out in the
                # original for the grad_accum_count == 1 path — confirm the
                # step happens elsewhere, otherwise parameters never update
                # in this configuration.
                # self.optim.step(report_stats=report_stats)

        # in case of multi step gradient accumulation,
        # update only after accum batches
        if self.grad_accum_count > 1:
            if self.n_gpu > 1:
                grads = [p.grad.data for p in self.model.parameters()
                         if p.requires_grad
                         and p.grad is not None]
                distributed.all_reduce_and_rescale_tensors(
                    grads, float(1))
            self.optim.step(report_stats)