Example #1
 def sharded_compute_loss(self, batch, output, shard_size):
     batch_stats = Statistics()
     shard_state = self._make_shard_state(batch, output)
     normalization = batch.tgt[:, 1:].ne(self.padding_idx).sum().item()
     for shard in shards(shard_state, shard_size):
         loss, stats = self._compute_loss(batch, **shard)
         loss.div(float(normalization)).backward()
         batch_stats.update(stats)
     return batch_stats
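The `shards` helper used above comes from the surrounding codebase and is not shown on this page. A minimal sketch of the idea, assuming every tensor in `shard_state` can be sliced along its first dimension (the real OpenNMT-style helper additionally detaches the shards and wires gradients back to the full tensors), might look like this; `simple_shards` is a hypothetical name:

def simple_shards(shard_state, shard_size):
    """Yield sub-dicts of shard_state, each covering shard_size rows.

    Simplified sketch: assumes every value is a tensor with the same
    first dimension and omits the gradient plumbing of the original helper.
    """
    keys = list(shard_state.keys())
    total = shard_state[keys[0]].size(0)
    for start in range(0, total, shard_size):
        yield {k: v[start:start + shard_size] for k, v in shard_state.items()}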
Example #2
    def validate_rouge(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()

        preds, golds = self.val_abs(self.args, valid_iter, step)
        best_model_saved = False
        preds_sorted = [s[1] for s in sorted(preds.items())]
        golds_sorted = [g[1] for g in sorted(golds.items())]
        logger.info('Some samples...')
        for _ in range(3):
            logger.info('[' + random.choice(preds_sorted) + ']')
        r1, r2, rl = self._report_rouge(preds_sorted, golds_sorted)

        stats.set_rl(r1, r2, rl)
        self.valid_rgls.append(rl)
        # self._report_step(0, step, valid_stats=stats)

        if len(self.valid_rgls) > 0:
            if self.max_rl < self.valid_rgls[-1]:
                self.max_rl = self.valid_rgls[-1]
                best_model_saved = True

        # with torch.no_grad():
        #     for batch in valid_iter:
        #         src = batch.src
        #         tgt = batch.tgt
        #         segs = batch.segs
        #         clss = batch.clss
        #         mask_src = batch.mask_src
        #         mask_tgt = batch.mask_tgt
        #         mask_cls = batch.mask_cls
        #
        #         outputs, _ = self.model(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)
        #
        #         batch_stats = self.loss.monolithic_compute_loss(batch, outputs)
        #         stats.update(batch_stats)
        #     self._report_step(0, step, valid_stats=stats)
        #
        #     # if len(self.valid_accuracies) > 0:
        #     if self.best_acc < stats.accuracy():
        #         is_best = True
        #         self.best_acc = stats.accuracy()
        #         self.last_best_step.append(step)
            # self.valid_accuracies.append(stats.accuracy())
        return stats, best_model_saved
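`self._report_rouge` in Example #2 is project-specific and not shown. One way to obtain averaged ROUGE-1/2/L F-scores for the sorted predictions and references, assuming the `rouge-score` package (the project's own helper may use a different backend such as pyrouge), could look like this hypothetical sketch:

from rouge_score import rouge_scorer

def report_rouge_sketch(preds, golds):
    """Average ROUGE-1/2/L F-scores over prediction/reference pairs."""
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'],
                                      use_stemmer=True)
    r1 = r2 = rl = 0.0
    for pred, gold in zip(preds, golds):
        scores = scorer.score(gold, pred)   # (reference, prediction)
        r1 += scores['rouge1'].fmeasure
        r2 += scores['rouge2'].fmeasure
        rl += scores['rougeL'].fmeasure
    n = len(preds)
    return r1 / n, r2 / n, rl / n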
Example #3
    def _gradient_calculation(self, true_batchs, examples, total_stats,
                              report_stats, step):
        self.model.zero_grad()

        for batch in true_batchs:
            loss = self.model(batch)

            # Topic Model loss
            topic_stats = Statistics(topic_loss=loss.clone().item() /
                                     float(examples))
            loss.div(float(examples)).backward(retain_graph=False)
            total_stats.update(topic_stats)
            report_stats.update(topic_stats)

        if step % 1000 == 0:
            for k in range(self.args.topic_num):
                logger.info(','.join([
                    self.model.voc_id_wrapper.i2w(i)
                    for i in self.model.topic_model.tm1.beta.topk(20, dim=-1)
                    [1][k].tolist()
                ]))
        # multi-GPU training: all-reduce gradients across workers
        # before the optimizer step
        if self.n_gpu > 1:
            grads = [
                p.grad.data for p in self.model.parameters()
                if p.requires_grad and p.grad is not None
            ]
            distributed.all_reduce_and_rescale_tensors(grads, float(1))
        for o in self.optims:
            o.step()
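`distributed.all_reduce_and_rescale_tensors` is a project helper; conceptually it sums each gradient tensor across all workers and rescales it by the given denominator. A rough sketch of that behaviour with plain `torch.distributed` (assuming a process group is already initialised; the bucketing done by the original helper is omitted):

import torch.distributed as dist

def all_reduce_and_rescale_simple(tensors, rescale_denom):
    """Sum each tensor across all workers, then divide by rescale_denom."""
    for t in tensors:
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        t.div_(rescale_denom)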
Example #4
 def mono_compute_loss(self, batch, output, states, ex_idx, tgt_idx,
                       mask_tgt, init_logps, trans_logps, ext_logps):
     target = batch.tgt[:, 1:]
     fwd_obs_logps = self._obs_logprobs(output, target, tgt_idx, mask_tgt)
     loss = -self._compute_bwd(fwd_obs_logps, init_logps, trans_logps,
                               ext_logps, ex_idx, states)
     normalization = batch.tgt_len
     batch_stats = Statistics(loss.clone().item(), normalization)
     return batch_stats
Example #5
 def _stats(self, loss, scores, target):
     pred = scores.max(1)[1]
     non_padding = target.ne(self.padding_idx)
     num_correct = pred.eq(target) \
                       .masked_select(non_padding) \
                       .sum() \
                       .item()
     num_non_padding = non_padding.sum().item()
     return Statistics(loss.item(), num_non_padding, num_correct)
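Every example on this page constructs and updates a `Statistics` object whose definition is not shown. A minimal sketch of the interface these snippets rely on, assuming the OpenNMT-style convention of tracking loss, token count and correct-prediction count (`MinimalStatistics` is a hypothetical stand-in, not the project's class):

import math

class MinimalStatistics:
    """Stripped-down stand-in for the Statistics class used above."""

    def __init__(self, loss=0.0, n_words=0, n_correct=0):
        self.loss = loss
        self.n_words = n_words
        self.n_correct = n_correct

    def update(self, stat):
        # accumulate another batch's statistics
        self.loss += stat.loss
        self.n_words += stat.n_words
        self.n_correct += stat.n_correct

    def accuracy(self):
        return 100.0 * (self.n_correct / self.n_words)

    def xent(self):
        return self.loss / self.n_words

    def ppl(self):
        return math.exp(min(self.loss / self.n_words, 100))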
Example #6
    def _compute_loss(self, batch, output, target):

        loss = -target * (output+self.eps).log() \
                 - (1-target) * (1-output+self.eps).log()
        loss = torch.sum(loss)

        num_correct = output.gt(0.5).float().eq(target).sum().item()
        num_all = target.size(0)
        stats = Statistics(loss.clone().item(), num_all, num_correct)

        return loss, stats
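The element-wise expression in Example #6 is binary cross-entropy written out by hand, with `self.eps` inside the logarithms for numerical stability. Without the epsilon it matches PyTorch's built-in; a small check, assuming `output` holds probabilities in (0, 1) and `target` holds 0/1 labels:

import torch
import torch.nn.functional as F

output = torch.tensor([0.9, 0.2, 0.7])
target = torch.tensor([1.0, 0.0, 1.0])

manual = torch.sum(-target * output.log() - (1 - target) * (1 - output).log())
builtin = F.binary_cross_entropy(output, target, reduction='sum')
assert torch.allclose(manual, builtin)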
Example #7
 def sharded_compute_loss(self,
                          batch,
                          output,
                          shard_size,
                          normalization,
                          copy_params=None):
     """
     Args:
       batch (batch) : batch of labeled examples
       output (:obj:`FloatTensor`) :
           output of decoder model `[tgt_len x batch x hidden]`
       attns (dict) : dictionary of attention distributions
           `[tgt_len x batch x src_len]`
       cur_trunc (int) : starting position of truncation window
       trunc_size (int) : length of truncation window
       shard_size (int) : maximum number of examples in a shard
       normalization (int) : Loss is divided by this number
     Returns:
         :obj:`onmt.utils.Statistics`: validation loss statistics
     """
     batch_stats = Statistics()
     shard_state = self._make_shard_state(batch, output, copy_params)
     for shard in shards(shard_state, shard_size):
         output = shard['output']
         target = shard['target']
         ext_loss = None
         if copy_params is not None:
             g = shard['copy_params[1]']
             ext_dist = shard['copy_params[0]']
             if len(copy_params) > 2:
                 ext_loss = shard['copy_params[2]']
                 loss, stats = self._compute_loss(batch, output, target, g,
                                                  ext_dist, ext_loss)
             else:
                 loss, stats = self._compute_loss(batch, output, target, g,
                                                  ext_dist)
         else:
             loss, stats = self._compute_loss(batch, output, target)
         # only add the auxiliary copy loss term when it was provided
         total = loss.div(float(normalization))
         if ext_loss is not None:
             total = total + ext_loss.mean() * 2
         total.backward()
         batch_stats.update(stats)
     return batch_stats
Example #8
    def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1):

        logger.info('Start training...')
        step = self.optims[0]._step + 1

        true_batchs = []
        accum = 0
        train_iter = train_iter_fct()

        total_stats = Statistics()
        report_stats = Statistics()
        self._start_report_manager(start_time=total_stats.start_time)

        while step <= train_steps:

            reduce_counter = 0
            for i, batch in enumerate(train_iter):
                if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):

                    true_batchs.append(batch)
                    accum += 1
                    if accum == self.grad_accum_count:
                        reduce_counter += 1
                        self._gradient_accumulation(true_batchs, total_stats, report_stats)

                        report_stats = self._maybe_report_training(
                            step, train_steps,
                            self.optims[0].learning_rate,
                            report_stats)

                        true_batchs = []
                        accum = 0
                        if (step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0):
                            self._save(step)

                        step += 1
                        if step > train_steps:
                            break
            train_iter = train_iter_fct()

        return total_stats
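`_gradient_accumulation` in Example #8 is defined elsewhere in the trainer; judging from Example #3, its general shape is to backpropagate a normalized loss for each accumulated batch and then step the optimizers. A generic, hypothetical sketch of that pattern (not the trainer's actual implementation; the normalization choice is an assumption):

def accumulate_and_step(model, optimizer, batches, compute_loss):
    """Accumulate gradients over several batches, then take one step."""
    model.zero_grad()
    for batch in batches:
        loss = compute_loss(model, batch)
        # gradients from each backward() call add up in .grad
        (loss / len(batches)).backward()
    optimizer.step()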
Example #9
    def _compute_loss(self, batch, output, target):

        loss = F.kl_div(output, target.float(), reduction="sum")

        prob_size = output.size(-1)
        pred = output.view(-1, prob_size).max(-1)[1]
        gold = target.view(-1, prob_size).max(-1)[1]
        non_padding = target.view(-1, prob_size).sum(-1).ne(0)
        num_correct = pred.eq(gold).masked_select(non_padding).sum().item()
        num_all = torch.sum(target).item()
        stats = Statistics(loss.clone().item(), num_all, num_correct)

        return loss, stats
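Note that `F.kl_div` expects its first argument to contain log-probabilities and its second to contain probabilities. A small usage sketch with dummy tensors:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
output = F.log_softmax(logits, dim=-1)            # log-probabilities
target = F.softmax(torch.randn(4, 10), dim=-1)    # probabilities
loss = F.kl_div(output, target, reduction="sum")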
Example #10
    def compute_loss(self, batch, output, states, ex_idx, tgt_idx, mask_tgt,
                     init_logps, trans_logps, ext_logps):

        target = batch.tgt[:, 1:]
        fwd_obs_logps = self._obs_logprobs(output, target, tgt_idx, mask_tgt)
        loss = -self._compute_bwd(fwd_obs_logps, init_logps, trans_logps,
                                  ext_logps, ex_idx, states)

        #normalization = self._get_normalization(tgt_idx)
        normalization = batch.tgt_len
        batch_stats = Statistics(loss.clone().item(), normalization)
        loss.div(float(normalization)).backward()
        return batch_stats
Example #11
    def validate(self, valid_iter, step=0):
        self.model.eval()
        stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                if self.args.mode == 'validate':
                    src = batch.src
                    tgt = batch.tgt

                    pmt_msk = batch.pmt_msk
                    states = batch.states
                    ex_idx = batch.ex_idx
                    tgt_idx = batch.tgt_idx

                    mask_src = batch.mask_src
                    mask_tgt = batch.mask_tgt

                    outputs, _ = self.model(src, tgt, mask_src, pmt_msk, ex_idx)
                    init_logps, trans_logps = self.model.trans_logprobs()
                    ext_logps = self.model.external_logprobs()
                    batch_stats = self.loss.mono_compute_loss(batch, outputs, states,
                                                            ex_idx, tgt_idx, mask_tgt,
                                                            init_logps, trans_logps,
                                                            ext_logps)
                else:
                    src = batch.src
                    tgt = batch.tgt
                    segs = batch.segs
                    mask_src = batch.mask_src
                    mask_tgt = batch.mask_tgt

                    outputs, _ = self.model(src, tgt, segs, mask_src, mask_tgt)
                    batch_stats = self.loss.monolithic_compute_loss(batch, outputs)

                stats.update(batch_stats)
            self._report_step(0, step, valid_stats=stats)

            return stats
Example #12
    def validate(self, valid_iter_fct, step=0):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()
        valid_iter = valid_iter_fct()
        # exhaust the iterator once just to count batches for the progress bar
        cn = 0
        for _ in valid_iter:
            cn += 1
        best_model_saved = False
        valid_iter = valid_iter_fct()
        with torch.no_grad():
            for batch in tqdm(valid_iter, total=cn):
                src = batch.src
                tgt = batch.tgt
                segs = batch.segs
                clss = batch.clss
                mask_src = batch.mask_src
                mask_tgt = batch.mask_tgt
                mask_cls = batch.mask_cls

                outputs, _ = self.model(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)

                batch_stats = self.loss.monolithic_compute_loss(batch, outputs)
                stats.update(batch_stats)

            self.valid_rgls.append(stats.accuracy())
            if len(self.valid_rgls) > 0:
                if self.max_rl < self.valid_rgls[-1]:
                    self.max_rl = self.valid_rgls[-1]
                    best_model_saved = True

            self._report_step(0, step, valid_stats=stats)
            return stats, best_model_saved
Example #13
    def validate(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                src = batch.src
                tgt = batch.tgt
                segs = batch.segs
                clss = batch.clss
                mask_src = batch.mask_src
                mask_tgt = batch.mask_tgt
                mask_cls = batch.mask_cls
                # pre
                # outputs, _ = self.model(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)
                # batch_stats = self.loss.monolithic_compute_loss(batch, outputs)
                ### graph
                # ent_list = batch.ent_list
                # rel_list = batch.rel_list
                # adj = batch.adj
                outputs, scores, src_context, graph_context, top_vec, ent_top_vec, _ = self.model(
                    src, tgt, segs, clss, mask_src, mask_tgt, mask_cls, batch)
                batch_stats = self.loss.monolithic_compute_loss(
                    batch, outputs, src_context, graph_context, batch.ent_src,
                    ent_top_vec, self.copy)

                stats.update(batch_stats)
            self._report_step(0, step, valid_stats=stats)
            with open('./score.txt', 'a+') as f:
                f.write('step: %g \n' % step)
                f.write('xent: %g \n' % stats.ppl())
            return stats
Example #14
 def _stats(self, loss, scores, target):
     """
     Args:
         loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
         scores (:obj:`FloatTensor`): a score for each possible output
         target (:obj:`FloatTensor`): true targets
     Returns:
         :obj:`onmt.utils.Statistics` : statistics for this batch.
     """
     pred = scores.max(1)[1]
     non_padding = target.ne(self.padding_idx)
     num_correct = pred.eq(target).masked_select(non_padding).sum().item()
     num_non_padding = non_padding.sum().item()
     return Statistics(loss.item(), num_non_padding, num_correct)