Example #1
    def validate(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        if self.args.acc_reporter == 1:
            stats = acc_reporter.Statistics()
        else:
            stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                src = batch.src
                labels = batch.src_sent_labels
                segs = batch.segs
                clss = batch.clss
                mask = batch.mask_src
                mask_cls = batch.mask_cls

                if self.args.ext_sum_dec:
                    sent_scores, mask = self.model(src, segs, clss, mask, mask_cls, labels)  # (B, tgt_len, custom_num)
                    tgt_len = 3
                    _, labels_id = torch.topk(labels, k=tgt_len)  # B, tgt_len
                    labels_id, _ = torch.sort(labels_id)
                    # e.g. max_src_nsents=100, weight_up=20: weight ramps linearly from 1 to weight_up
                    weight = torch.linspace(start=1, end=self.args.weight_up, steps=self.args.max_src_nsents).type_as(sent_scores)
                    # self.max_class = max(self.max_class,torch.max(labels_id+1).item())
                    # weight = weight[:self.max_class]
                    weight = weight[:sent_scores.size(-1)]
                    # weight = torch.ones(self.args.max_src_nsents)
                    loss = F.nll_loss(
                        F.log_softmax(
                            sent_scores.view(-1, sent_scores.size(-1)),
                            dim=-1,
                            dtype=torch.float32,
                        ),
                        labels_id.view(-1),  # bsz sent
                        weight=weight,
                        reduction='sum',
                        ignore_index=-1,
                    )
                    prediction = torch.argmax(sent_scores, dim=-1)
                    if (self.optim._step + 1) % self.args.print_every == 0:
                        logger.info(
                            'valid prediction: %s |label %s ' % (str(prediction), str(labels_id)))
                    accuracy = torch.div(torch.sum(torch.eq(prediction, labels_id).float()), tgt_len)
                else:
                    sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)  # B, custom_N
                    loss = self.loss(sent_scores, labels.float())
                    loss = (loss * mask.float()).sum()
                    tgt_len = 3
                    _, labels_id = torch.topk(labels, k=tgt_len)  # B, tgt_len
                    labels_id, _ = torch.sort(labels_id)
                    _, prediction = torch.topk(sent_scores, k=tgt_len)
                    prediction, _ = torch.sort(prediction)
                    if (self.optim._step + 1) % self.args.print_every == 0:
                        logger.info(
                            'valid prediction: %s |label %s ' % (str(prediction), str(labels_id)))
                    accuracy = torch.div(torch.sum(torch.eq(prediction, labels_id).float()), tgt_len)
                if self.args.acc_reporter == 1:
                    batch_stats = acc_reporter.Statistics(float(loss.cpu().data.numpy()), accuracy, len(labels))
                else:
                    batch_stats = Statistics(float(loss.cpu().data.numpy()), len(labels))
                stats.update(batch_stats)
            self._report_step(0, step, valid_stats=stats)
            return stats
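
The ext_sum_dec branch above supervises only the top-k gold sentence positions and weights later positions more heavily via torch.linspace. Below is a minimal, self-contained sketch of that weighted NLL loss on dummy tensors; the batch size, sentence count, and the value standing in for weight_up are illustrative assumptions, not values from this code.

import torch
import torch.nn.functional as F

# Assumed toy dimensions: 2 documents, 10 candidate sentences, 3 target sentences.
bsz, nsent, tgt_len = 2, 10, 3
sent_scores = torch.randn(bsz, tgt_len, nsent)   # model scores, one distribution per target step
labels = torch.rand(bsz, nsent)                  # per-sentence relevance labels

# Gold indices: the top-k labelled sentences, sorted, as in validate() above.
_, labels_id = torch.topk(labels, k=tgt_len)
labels_id, _ = torch.sort(labels_id)             # (bsz, tgt_len)

# Linearly increasing class weight, standing in for weight_up / max_src_nsents.
weight = torch.linspace(start=1.0, end=20.0, steps=nsent)

loss = F.nll_loss(
    F.log_softmax(sent_scores.view(-1, nsent), dim=-1, dtype=torch.float32),
    labels_id.view(-1),
    weight=weight,
    reduction='sum',
    ignore_index=-1,
)
print(loss.item())
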
Example #2
    def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                               report_stats):
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            if self.grad_accum_count == 1:
                self.model.zero_grad()
            # src = torch.tensor(self._pad(pre_src, 0))
            # segs = torch.tensor(self._pad(pre_segs, 0))
            # mask_src = torch.logical_not(src == 0)
            # clss = torch.tensor(self._pad(pre_clss, -1))
            # src_sent_labels = torch.tensor(self._pad(pre_src_sent_labels, 0))
            # mask_cls = torch.logical_not(clss == -1)
            # clss[clss == -1] = 0
            # setattr(self, 'clss' + postfix, clss.to(device))
            # setattr(self, 'mask_cls' + postfix, mask_cls.to(device))
            # setattr(self, 'src_sent_labels' + postfix, src_sent_labels.to(device))
            # setattr(self, 'src' + postfix, src.to(device))
            # setattr(self, 'segs' + postfix, segs.to(device))
            # setattr(self, 'mask_src' + postfix, mask_src.to(device))
            # # everything below is a prediction target, so it is padded with -1: loss computation stops at -1, and no mask is needed (masks are only required for inputs)
            # org_sent_labels = torch.tensor(self._pad(org_sent_labels, -1))
            # setattr(self, 'org_sent_labels' + postfix, org_sent_labels.to(device))
            # poss = torch.tensor(self._pad(poss, -1))
            # setattr(self, 'poss' + postfix, poss.to(device))

            if self.args.jigsaw == 'jigsaw_lab':  # jigsaw_lab: each position predicted independently (an unsuccessful attempt)
                logits = self.model(batch.src_s, batch.segs_s, batch.clss_s, batch.mask_src_s, batch.mask_cls_s)  # (bsz, nsent, max_sent_num)
                # mask = batch.mask_cls_s[:, :, None].float()
                # loss = self.loss(sent_scores, batch.poss_s.float())
                loss = F.nll_loss(
                    F.log_softmax(
                        logits.view(-1, logits.size(-1)),
                        dim=-1,
                        dtype=torch.float32,
                    ),
                    batch.poss_s.view(-1), # bsz sent
                    reduction='sum',
                    ignore_index=-1,
                )
                prediction = torch.argmax(logits, dim=-1)
                if (self.optim._step + 1) % self.args.print_every == 0:
                    logger.info(
                        'train prediction: %s |label %s ' % (str(prediction), str(batch.poss_s)))
                accuracy = torch.div(torch.sum(torch.eq(prediction, batch.poss_s).float() * batch.mask_cls_s.float()),
                                     torch.sum(batch.mask_cls_s)) * len(logits)

                # loss = (loss * batch.mask_cls_s.float()).sum()
                # print('train prediction: %s |label %s ' % (str(torch.argmax(logits, dim=-1)[0]), str(batch.poss_s[0])))
                # logger.info('train prediction: %s |label %s ' % (str(torch.argmax(logits, dim=-1)[0]), str(batch.poss_s[0])))
                # (loss / loss.numel()).backward()
            else:  # self.args.jigsaw == 'jigsaw_dec': jigsaw decoder
                poss_s = batch.poss_s
                mask_poss = torch.eq(poss_s, -1)
                poss_s.masked_fill_(mask_poss, 1e4)
                # poss_s[i] [5,1,4,0,2,3,-1,-1]->[5,1,4,0,2,3,1e4,1e4] dec_labels[i] [3,1,xxx,6,7]
                dec_labels = torch.argsort(poss_s, dim=1)
                logits,_ = self.model(batch.src_s, batch.segs_s, batch.clss_s, batch.mask_src_s, batch.mask_cls_s, dec_labels)
                final_dec_labels = dec_labels.masked_fill(mask_poss, -1)
                loss = F.nll_loss(
                    F.log_softmax(
                        logits.view(-1, logits.size(-1)),
                        dim=-1,
                        dtype=torch.float32,
                    ),
                    final_dec_labels.view(-1),  # bsz sent
                    reduction='sum',
                    ignore_index=-1,
                )
                # loss = (loss * batch.mask_cls_s.float()).sum()
                # (loss / loss.numel()).backward()
                prediction = torch.argmax(logits, dim=-1)
                if (self.optim._step + 1) % self.args.print_every == 0:
                    logger.info(
                        'train prediction: %s |label %s ' % (str(prediction), str(batch.poss_s)))
                accuracy = torch.div(torch.sum(torch.eq(prediction, final_dec_labels).float() * batch.mask_cls_s.float()),
                                     torch.sum(batch.mask_cls_s)) * len(logits)
            with amp.scale_loss((loss / loss.numel()), self.optim.optimizer) as scaled_loss:
                scaled_loss.backward()
            # loss.div(float(normalization)).backward()
            if self.args.acc_reporter:
                batch_stats = acc_reporter.Statistics(float(loss.cpu().data.numpy()), accuracy, normalization)
            else:
                batch_stats = Statistics(float(loss.cpu().data.numpy()), normalization)

            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

            # 4. Update the parameters and statistics.
            if self.grad_accum_count == 1:
                # Multi GPU gradient gather
                if self.n_gpu > 1:
                    grads = [p.grad.data for p in self.model.parameters()
                             if p.requires_grad
                             and p.grad is not None]
                    distributed.all_reduce_and_rescale_tensors(
                        grads, float(1))
                self.optim.step()

        # in case of multi step gradient accumulation,
        # update only after accum batches
        if self.grad_accum_count > 1:
            if self.n_gpu > 1:
                grads = [p.grad.data for p in self.model.parameters()
                         if p.requires_grad
                         and p.grad is not None]
                distributed.all_reduce_and_rescale_tensors(
                    grads, float(1))
            self.optim.step()
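
The jigsaw_dec branch builds its decoder targets by temporarily replacing the -1 padding in poss_s with a large value so that argsort pushes the padded slots to the end, then restoring -1 so nll_loss(ignore_index=-1) skips them. A small standalone sketch of that transformation, using the example from the comment above:

import torch

# poss_s: original position of each shuffled sentence, padded with -1
poss_s = torch.tensor([[5, 1, 4, 0, 2, 3, -1, -1]])

mask_poss = torch.eq(poss_s, -1)                # marks the padded slots
poss_s = poss_s.masked_fill(mask_poss, 10000)   # push padding past every real position

# argsort recovers, for each output step, which input sentence comes next
dec_labels = torch.argsort(poss_s, dim=1)       # tensor([[3, 1, 4, 5, 2, 0, 6, 7]])

# padded steps go back to -1 so the loss ignores them
final_dec_labels = dec_labels.masked_fill(mask_poss, -1)
print(dec_labels)
print(final_dec_labels)
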
Example #3
    def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1):
        """
        The main training loops.
        by iterating over training data (i.e. `train_iter_fct`)
        and running validation (i.e. iterating over `valid_iter_fct`

        Args:
            train_iter_fct(function): a function that returns the train
                iterator. e.g. something like
                train_iter_fct = lambda: generator(*args, **kwargs)
            valid_iter_fct(function): same as train_iter_fct, for valid data
            train_steps(int):
            valid_steps(int):
            save_checkpoint_steps(int):

        Return:
            None
        """
        logger.info('Start training...')

        step = self.optim._step + 1
        true_batchs = []
        accum = 0
        normalization = 0
        train_iter = train_iter_fct()
        if self.args.acc_reporter == 1:
            total_stats = acc_reporter.Statistics()
            report_stats = acc_reporter.Statistics()
        else:
            total_stats = Statistics()
            report_stats = Statistics()
        self._start_report_manager(start_time=total_stats.start_time)

        while step <= train_steps:

            reduce_counter = 0
            for i, batch in enumerate(train_iter):
                if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):

                    true_batchs.append(batch)
                    normalization += batch.batch_size
                    accum += 1
                    if accum == self.grad_accum_count:  # 2020-03-18 17:03: step appears to be the number of optimizer updates
                        reduce_counter += 1
                        if self.n_gpu > 1:
                            normalization = sum(distributed
                                                .all_gather_list
                                                (normalization))

                        self._gradient_accumulation(
                            true_batchs, normalization, total_stats,
                            report_stats)

                        report_stats = self._maybe_report_training(
                            step, train_steps,
                            self.optim.learning_rate,
                            report_stats)

                        true_batchs = []
                        accum = 0
                        normalization = 0
                        if (step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0):
                            self._save(step)

                        step += 1
                        if step > train_steps:
                            break
            train_iter = train_iter_fct()

        return total_stats
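
As the docstring notes, train_iter_fct must be a zero-argument callable so the loop can rebuild the iterator once it is exhausted (train_iter = train_iter_fct() at the end of each pass). A tiny runnable sketch of that contract; make_batches is a stand-in for a real data pipeline, not part of this code:

def make_batches():
    # stand-in generator: a real implementation would yield Batch objects
    for i in range(4):
        yield {'batch_id': i}

train_iter_fct = lambda: make_batches()   # called again whenever the data is exhausted

for epoch in range(2):
    for batch in train_iter_fct():        # a fresh generator every pass
        print(epoch, batch)
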
Example #4
    def validate(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        if self.args.acc_reporter:
            stats = acc_reporter.Statistics()
        else:
            stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                # src = batch.src
                # labels = batch.src_sent_labels
                # segs = batch.segs
                # clss = batch.clss
                # mask = batch.mask_src
                # mask_cls = batch.mask_cls

                # sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)
                if self.args.jigsaw == 'jigsaw_lab':  # jigsaw_lab 3/31 23:38: noticed validate had not been updated earlier; rerun in the morning to check
                    logits = self.model(batch.src_s, batch.segs_s, batch.clss_s, batch.mask_src_s, batch.mask_cls_s)
                    # bsz, sent, max-sent_num
                    # mask = batch.mask_cls_s[:, :, None].float()
                    # loss = self.loss(sent_scores, batch.poss_s.float())
                    loss = F.nll_loss(
                        F.log_softmax(
                            logits.view(-1, logits.size(-1)),
                            dim=-1,
                            dtype=torch.float32,
                        ),
                        batch.poss_s.view(-1),  # bsz sent
                        reduction='sum',
                        ignore_index=-1,
                    )
                    prediction = torch.argmax(logits, dim=-1)
                    if (self.optim._step + 1) % self.args.print_every == 0:
                        logger.info(
                            'valid prediction: %s |label %s ' % (str(prediction), str(batch.poss_s)))
                    accuracy = torch.div(torch.sum(torch.eq(prediction, batch.poss_s).float() * batch.mask_cls_s.float()),
                                         torch.sum(batch.mask_cls_s)) * len(logits)
                elif self.args.jigsaw == 'jigsaw_dec':  # jigsaw decoder
                    poss_s = batch.poss_s
                    mask_poss = torch.eq(poss_s, -1)
                    poss_s.masked_fill_(mask_poss, 1e4)
                    # poss_s[i] [5,1,4,0,2,3,-1,-1]->[5,1,4,0,2,3,1e4,1e4]
                    dec_labels = torch.argsort(poss_s, dim=1)  # dec_labels[i] [3,1,xxx,6,7]
                    logits = self.model(batch.src_s, batch.segs_s, batch.clss_s, batch.mask_src_s, batch.mask_cls_s,
                                        dec_labels)
                    final_dec_labels = dec_labels.masked_fill(mask_poss, -1)  # final_dec_labels[i] [3,1,xxx,-1,-1]
                    loss = F.nll_loss(
                        F.log_softmax(
                            logits.view(-1, logits.size(-1)),
                            dim=-1,
                            dtype=torch.float32,
                        ),
                        final_dec_labels.view(-1),  # bsz sent
                        reduction='sum',
                        ignore_index=-1,
                    )

                    # loss = (loss * batch.mask_cls_s.float()).sum()
                    prediction = torch.argmax(logits, dim=-1)
                    if (self.optim._step + 1) % self.args.print_every == 0:
                        logger.info(
                            'valid prediction: %s |label %s ' % (str(prediction), str(batch.poss_s)))
                    accuracy = torch.div(torch.sum(torch.eq(prediction, final_dec_labels).float() * batch.mask_cls_s.float()),
                                         torch.sum(batch.mask_cls_s)) * len(logits)


                # loss = self.loss(sent_scores, labels.float())
                # loss = (loss * mask.float()).sum()
                if self.args.acc_reporter:
                    batch_stats = acc_reporter.Statistics(float(loss.cpu().data.numpy()), accuracy, len(batch.poss_s))
                else:
                    batch_stats = Statistics(float(loss.cpu().data.numpy()), len(batch.poss_s))
                stats.update(batch_stats)

            self._report_step(0, step, valid_stats=stats)
            return stats
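
Both jigsaw branches measure accuracy as an element-wise match between predicted and gold sentence positions, masked by mask_cls_s so padded slots do not count. A minimal sketch of that masked accuracy on toy tensors (shapes and values are assumed for illustration; the trainer additionally scales the result by len(logits) before handing it to the reporter):

import torch

prediction = torch.tensor([[3, 1, 4, 0, 0, 0],
                           [2, 0, 1, 3, 0, 0]])
gold       = torch.tensor([[3, 1, 4, 5, -1, -1],
                           [2, 0, 1, 3, -1, -1]])
mask_cls   = torch.tensor([[1, 1, 1, 1, 0, 0],
                           [1, 1, 1, 1, 0, 0]], dtype=torch.float)

# fraction of real (non-padded) positions predicted correctly
correct = torch.eq(prediction, gold).float() * mask_cls
accuracy = correct.sum() / mask_cls.sum()
print(accuracy.item())   # 0.875 for this toy example
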