Example No. 1
    def _gradient_calculation(self, true_batchs, examples, total_stats,
                              report_stats, step):
        self.model.zero_grad()

        for batch in true_batchs:
            loss = self.model(batch)

            # Topic Model loss
            topic_stats = Statistics(topic_loss=loss.clone().item() /
                                     float(examples))
            loss.div(float(examples)).backward(retain_graph=False)
            total_stats.update(topic_stats)
            report_stats.update(topic_stats)

        if step % 1000 == 0:
            for k in range(self.args.topic_num):
                logger.info(','.join([
                    self.model.voc_id_wrapper.i2w(i)
                    for i in self.model.topic_model.tm1.beta.topk(20, dim=-1)
                    [1][k].tolist()
                ]))
        # in case of multi step gradient accumulation,
        # update only after accum batches
        if self.n_gpu > 1:
            grads = [
                p.grad.data for p in self.model.parameters()
                if p.requires_grad and p.grad is not None
            ]
            distributed.all_reduce_and_rescale_tensors(grads, float(1))
        for o in self.optims:
            o.step()
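
The multi-GPU branch above averages gradients across workers with `distributed.all_reduce_and_rescale_tensors` before the optimizers step. Below is a minimal sketch of that idea using plain `torch.distributed`; the helper name, the rescale argument and the process-group setup here are assumptions for illustration, not the project's actual implementation.

import torch
import torch.distributed as dist

def all_reduce_and_rescale(tensors, rescale_denom=1.0):
    # Sketch: sum each gradient tensor across all workers, then rescale.
    # Assumes dist.init_process_group(...) has already been called elsewhere.
    for t in tensors:
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        t.div_(rescale_denom)

With a denominator of 1.0, as in the call above, the per-worker gradients are summed rather than averaged.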
Example No. 2
    def validate(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                src = batch.src
                tgt = batch.tgt
                segs = batch.segs
                clss = batch.clss
                mask_src = batch.mask_src
                mask_tgt = batch.mask_tgt
                mask_cls = batch.mask_cls

                outputs, _ = self.model(src, tgt, segs, clss, mask_src,
                                        mask_tgt, mask_cls)

                batch_stats = self.loss.monolithic_compute_loss(batch, outputs)
                stats.update(batch_stats)
            self._report_step(0, step, valid_stats=stats)
            return stats
Example No. 3
    def validate(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                src = batch.src
                tgt = batch.tgt
                segs = batch.segs
                clss = batch.clss
                mask_src = batch.mask_src
                mask_tgt = batch.mask_tgt
                mask_cls = batch.mask_cls

                if self.args.task == 'hybrid':
                    if self.args.oracle or self.args.hybrid_loss:
                        labels = batch.src_sent_labels
                        outputs, scores, copy_params = self.model(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls,
                                                                  labels)
                    else:
                        outputs, scores, copy_params = self.model(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)
                    batch_stats = self.loss.monolithic_compute_loss(batch, outputs, copy_params)
                else:
                    outputs, _ = self.model(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)
                    batch_stats = self.loss.monolithic_compute_loss(batch, outputs)
                stats.update(batch_stats)
            self._report_step(0, step, valid_stats=stats)
            return stats
Example No. 4
    def validate(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                src = batch.src
                tgt = batch.tgt
                segs = batch.segs
                clss = batch.clss
                mask_src = batch.mask_src
                mask_tgt = batch.mask_tgt
                mask_cls = batch.mask_cls

                if self.args.task == 'hybrid':
                    if self.args.oracle or self.args.hybrid_loss:
                        labels = batch.src_sent_labels
                        outputs, scores, copy_params = self.model(
                            src, tgt, segs, clss, mask_src, mask_tgt, mask_cls,
                            labels)
                    else:
                        outputs, scores, copy_params = self.model(
                            src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)
                    batch_stats = self.loss.monolithic_compute_loss(
                        batch, outputs, copy_params)
                else:
                    outputs, _ = self.model(src, tgt, segs, clss, mask_src,
                                            mask_tgt, mask_cls)
                    batch_stats = self.loss.monolithic_compute_loss(
                        batch, outputs)

                stats.update(batch_stats)
            self._report_step(0, step, valid_stats=stats)
            return stats
Example No. 5
 def mono_compute_loss(self, batch, output, states, ex_idx, tgt_idx,
                       mask_tgt, init_logps, trans_logps, ext_logps):
     target = batch.tgt[:, 1:]
     fwd_obs_logps = self._obs_logprobs(output, target, tgt_idx, mask_tgt)
     loss = -self._compute_bwd(fwd_obs_logps, init_logps, trans_logps,
                               ext_logps, ex_idx, states)
     normalization = batch.tgt_len
     batch_stats = Statistics(loss.clone().item(), normalization)
     return batch_stats
Example No. 6
 def _stats(self, loss, scores, target):
     pred = scores.max(1)[1]
     non_padding = target.ne(self.padding_idx)
     num_correct = pred.eq(target) \
                       .masked_select(non_padding) \
                       .sum() \
                       .item()
     num_non_padding = non_padding.sum().item()
     return Statistics(loss.item(), num_non_padding, num_correct)
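
`_stats` converts raw scores into token-level accuracy while ignoring padding. A small self-contained illustration (the padding index and the tensors are made up for this example):

import torch

padding_idx = 0  # arbitrary choice for this illustration
scores = torch.tensor([[0.1, 0.7, 0.2],
                       [0.6, 0.3, 0.1],
                       [0.2, 0.2, 0.6]])
target = torch.tensor([1, 2, padding_idx])

pred = scores.max(1)[1]                # predicted class per token
non_padding = target.ne(padding_idx)   # mask out padding positions
num_correct = pred.eq(target).masked_select(non_padding).sum().item()
num_non_padding = non_padding.sum().item()
print(num_correct, num_non_padding)    # 1 correct out of 2 real tokens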
Example No. 7
 def sharded_compute_loss(self, batch, output, shard_size):
     batch_stats = Statistics()
     shard_state = self._make_shard_state(batch, output)
     normalization = batch.tgt[:, 1:].ne(self.padding_idx).sum().item()
     for shard in shards(shard_state, shard_size):
         loss, stats = self._compute_loss(batch, **shard)
         loss.div(float(normalization)).backward()
         batch_stats.update(stats)
     return batch_stats
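
The `shards` helper splits the shard state so that the generator's softmax is never materialized for the whole batch at once. A heavily simplified sketch of the splitting idea; the real helper also detaches the pieces and backpropagates the accumulated gradients, which is omitted here:

import torch

def simple_shards(state, shard_size):
    # Sketch only: split every tensor in `state` along dim 0 into chunks of
    # at most `shard_size` and yield one dict per chunk.
    keys = list(state.keys())
    pieces = [torch.split(state[k], shard_size, dim=0) for k in keys]
    for chunk in zip(*pieces):
        yield dict(zip(keys, chunk))

state = {'output': torch.randn(10, 4), 'target': torch.randint(0, 4, (10,))}
for shard in simple_shards(state, shard_size=4):
    print(shard['output'].shape, shard['target'].shape)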
Example No. 8
    def attn_debug(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()

        json_objs = []
        with torch.no_grad():
            for batch in valid_iter:
                # batch size has to be 1
                assert batch.src.size(0) == 1
                src = batch.src
                tgt = batch.tgt
                segs = batch.segs
                clss = batch.clss
                alignment = batch.alignment

                mask_src = batch.mask_src
                mask_tgt = batch.mask_tgt
                mask_cls = batch.mask_cls
                mask_alg = batch.mask_alg

                outputs, state, src_mem_bank = self.model(src,
                                                          tgt,
                                                          segs,
                                                          clss,
                                                          mask_src,
                                                          mask_tgt,
                                                          mask_cls,
                                                          attn_debug=True)
                batch_stats = self.loss.monolithic_compute_loss(batch, \
                        outputs, \
                        src_mem_bank, \
                        alignment, \
                        mask_alg, \
                        mask_tgt, \
                        saved_attn = state.s2t_attn, \
                        loss_ws=state.loss_ws)

                json_obj = {}
                ex_loss = batch_stats.loss
                str_src = self.ids_2_toks(src)[0]
                str_tgt = self.ids_2_toks(tgt[:, :-1])[0]
                json_obj["ex_loss"] = ex_loss
                json_obj["src"] = str_src
                json_obj["tgt"] = str_tgt
                json_obj["attn"] = [
                    attn.cpu().tolist() for attn in state.s2t_attn
                ]
                json_objs.append(json_obj)

        return json_objs
Example No. 9
    def sharded_compute_loss(self,
                             batch,
                             output,
                             shard_size,
                             normalization,
                             copy_params=None):
        """Compute the forward loss and backpropagate.  Computation is done
        with shards and optionally truncation for memory efficiency.

        Also supports truncated BPTT for long sequences by taking a
        range in the decoder output sequence to back propagate in.
        Range is from `(cur_trunc, cur_trunc + trunc_size)`.

        Note sharding is an exact efficiency trick to relieve memory
        required for the generation buffers. Truncation is an
        approximate efficiency trick to relieve the memory required
        in the RNN buffers.

        Args:
          batch (batch) : batch of labeled examples
          output (:obj:`FloatTensor`) :
              output of decoder model `[tgt_len x batch x hidden]`
          attns (dict) : dictionary of attention distributions
              `[tgt_len x batch x src_len]`
          cur_trunc (int) : starting position of truncation window
          trunc_size (int) : length of truncation window
          shard_size (int) : maximum number of examples in a shard
          normalization (int) : Loss is divided by this number

        Returns:
            :obj:`onmt.utils.Statistics`: validation loss statistics

        """
        batch_stats = Statistics()
        shard_state = self._make_shard_state(batch, output, copy_params)
        for shard in shards(shard_state, shard_size):
            output = shard['output']
            target = shard['target']
            ext_loss = None
            if copy_params is not None:
                g = shard['copy_params[1]']
                ext_dist = shard['copy_params[0]']
                if len(copy_params) > 2:
                    # Optional auxiliary loss carried alongside the copy params.
                    ext_loss = shard['copy_params[2]']
                    loss, stats = self._compute_loss(batch, output, target, g,
                                                     ext_dist, ext_loss)
                else:
                    loss, stats = self._compute_loss(batch, output, target, g,
                                                     ext_dist)
            else:
                loss, stats = self._compute_loss(batch, output, target)
            loss = loss.div(float(normalization))
            if ext_loss is not None:
                # Add the auxiliary loss only when it was actually produced.
                loss = loss + ext_loss.mean() / 2
            loss.backward()
            batch_stats.update(stats)

        return batch_stats
Example No. 10
    def sharded_compute_loss(self,
                             batch,
                             output,
                             shard_size,
                             normalization,
                             src_mem_bank=None,
                             alignment=None,
                             mask_alg=None,
                             mask_tgt=None):
        """Compute the forward loss and backpropagate.  Computation is done
        with shards and optionally truncation for memory efficiency.

        Also supports truncated BPTT for long sequences by taking a
        range in the decoder output sequence to back propagate in.
        Range is from `(cur_trunc, cur_trunc + trunc_size)`.

        Note sharding is an exact efficiency trick to relieve memory
        required for the generation buffers. Truncation is an
        approximate efficiency trick to relieve the memory required
        in the RNN buffers.

        Args:
          batch (batch) : batch of labeled examples
          output (:obj:`FloatTensor`) :
              output of decoder model `[tgt_len x batch x hidden]`
          attns (dict) : dictionary of attention distributions
              `[tgt_len x batch x src_len]`
          cur_trunc (int) : starting position of truncation window
          trunc_size (int) : length of truncation window
          shard_size (int) : maximum number of examples in a shard
          normalization (int) : Loss is divided by this number

        Returns:
            :obj:`onmt.utils.Statistics`: validation loss statistics

        """
        '''
        batch_stats = Statistics()
        shard_state = self._make_shard_state(batch, output)
        for shard in shards(shard_state, shard_size):
            loss, stats = self._compute_loss(batch, **shard)
            loss.div(float(normalization)).backward()
            batch_stats.update(stats)
        '''

        batch_stats = Statistics()

        loss, stats = self._compute_loss(batch, output, batch.tgt[:, 1:],
                                         src_mem_bank, alignment, mask_alg,
                                         mask_tgt)

        loss.div(float(normalization)).backward()
        batch_stats.update(stats)

        return batch_stats
Example No. 11
    def _compute_loss(self, batch, output, target):

        loss = -target * (output+self.eps).log() \
                 - (1-target) * (1-output+self.eps).log()
        loss = torch.sum(loss)

        num_correct = output.gt(0.5).float().eq(target).sum().item()
        num_all = target.size(0)
        stats = Statistics(loss.clone().item(), num_all, num_correct)

        return loss, stats
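
The loss here is elementwise binary cross-entropy written out by hand, with `self.eps` guarding against `log(0)`. As a sanity check, it agrees with PyTorch's built-in up to the epsilon term; `eps` below is a stand-in for `self.eps`:

import torch
import torch.nn.functional as F

eps = 1e-8  # stand-in for self.eps
output = torch.tensor([0.9, 0.2, 0.7, 0.4])   # predicted probabilities
target = torch.tensor([1.0, 0.0, 1.0, 1.0])   # binary labels

manual = torch.sum(-target * (output + eps).log()
                   - (1 - target) * (1 - output + eps).log())
builtin = F.binary_cross_entropy(output, target, reduction='sum')
print(manual.item(), builtin.item())  # nearly identical values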
Example No. 12
    def train(self, train_iter_fct, train_steps, valid_iter_fct=None, valid_steps=-1):

        logger.info('Start training...')
        step = self.optims[0]._step + 1

        true_batchs = []
        accum = 0
        train_iter = train_iter_fct()

        total_stats = Statistics()
        report_stats = Statistics()
        self._start_report_manager(start_time=total_stats.start_time)

        while step <= train_steps:

            reduce_counter = 0
            for i, batch in enumerate(train_iter):
                if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):

                    true_batchs.append(batch)
                    accum += 1
                    if accum == self.grad_accum_count:
                        reduce_counter += 1
                        self._gradient_accumulation(true_batchs, total_stats, report_stats)

                        report_stats = self._maybe_report_training(
                            step, train_steps,
                            self.optims[0].learning_rate,
                            report_stats)

                        true_batchs = []
                        accum = 0
                        if (step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0):
                            self._save(step)

                        step += 1
                        if step > train_steps:
                            break
            train_iter = train_iter_fct()

        return total_stats
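
The loop above collects `grad_accum_count` batches into one logical update before `_gradient_accumulation` runs. A stripped-down sketch of the same accumulate-then-step pattern on a throwaway model (the model, optimizer and data are placeholders, not the trainer's):

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
grad_accum_count = 4
batches = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(8)]

model.zero_grad()
for i, (x, y) in enumerate(batches, start=1):
    loss = nn.functional.mse_loss(model(x), y)
    (loss / grad_accum_count).backward()   # gradients accumulate across batches
    if i % grad_accum_count == 0:
        optimizer.step()                   # one update per accumulated group
        model.zero_grad()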
Example No. 13
    def validate_rouge(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()

        preds, golds = self.val_abs(self.args, valid_iter, step)
        best_model_saved = False
        preds_sorted = [s[1] for s in sorted(preds.items())]
        golds_sorted = [g[1] for g in sorted(golds.items())]
        logger.info('Some samples...')
        logger.info('[' + preds_sorted[random.randint(0, len(preds_sorted)-1)] + ']')
        logger.info('[' + preds_sorted[random.randint(0, len(preds_sorted)-1)] + ']')
        logger.info('[' + preds_sorted[random.randint(0, len(preds_sorted)-1)] + ']')
        r1, r2, rl = self._report_rouge(preds_sorted, golds_sorted)

        stats.set_rl(r1, r2, rl)
        self.valid_rgls.append(rl)

        if len(self.valid_rgls) > 0:
            if self.max_rl < self.valid_rgls[-1]:
                self.max_rl = self.valid_rgls[-1]
                best_model_saved = True

        return stats, best_model_saved
Example No. 14
    def _compute_loss(self, batch, output, target):

        loss = F.kl_div(output, target.float(), reduction="sum")

        prob_size = output.size(-1)
        pred = output.view(-1, prob_size).max(-1)[1]
        gold = target.view(-1, prob_size).max(-1)[1]
        non_padding = target.view(-1, prob_size).sum(-1).ne(0)
        num_correct = pred.eq(gold).masked_select(non_padding).sum().item()
        num_all = torch.sum(target).item()
        stats = Statistics(loss.clone().item(), num_all, num_correct)

        return loss, stats
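
Note that `F.kl_div` expects its first argument in log-space, so `output` in this example must already be log-probabilities (e.g. from `log_softmax`) while `target` is an ordinary distribution. A minimal illustration of that convention:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5)
log_probs = F.log_softmax(logits, dim=-1)       # first argument: log-probabilities
target = F.softmax(torch.randn(2, 5), dim=-1)   # second argument: probabilities

loss = F.kl_div(log_probs, target, reduction='sum')
print(loss.item())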
Example No. 15
    def compute_loss(self, batch, output, states, ex_idx, tgt_idx, mask_tgt,
                     init_logps, trans_logps, ext_logps):

        target = batch.tgt[:, 1:]
        fwd_obs_logps = self._obs_logprobs(output, target, tgt_idx, mask_tgt)
        loss = -self._compute_bwd(fwd_obs_logps, init_logps, trans_logps,
                                  ext_logps, ex_idx, states)

        #normalization = self._get_normalization(tgt_idx)
        normalization = batch.tgt_len
        batch_stats = Statistics(loss.clone().item(), normalization)
        loss.div(float(normalization)).backward()
        return batch_stats
Example No. 16
 def _stats(self, loss, scores, target):
     """
     Args:
         loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
         scores (:obj:`FloatTensor`): a score for each possible output
         target (:obj:`FloatTensor`): true targets
     Returns:
         :obj:`onmt.utils.Statistics` : statistics for this batch.
     """
     pred = scores.max(1)[1]
     non_padding = target.ne(self.padding_idx)
     num_correct = pred.eq(target).masked_select(non_padding).sum().item()
     num_non_padding = non_padding.sum().item()
     return Statistics(loss.item(), num_non_padding, num_correct)
Example No. 17
    def sharded_compute_loss(self, batch, output,
                              shard_size, normalization, optim):
        """Compute the forward loss and backpropagate.  Computation is done
        with shards and optionally truncation for memory efficiency.

        Also supports truncated BPTT for long sequences by taking a
        range in the decoder output sequence to back propagate in.
        Range is from `(cur_trunc, cur_trunc + trunc_size)`.

        Note sharding is an exact efficiency trick to relieve memory
        required for the generation buffers. Truncation is an
        approximate efficiency trick to relieve the memory required
        in the RNN buffers.

        Args:
          batch (batch) : batch of labeled examples
          output (:obj:`FloatTensor`) :
              output of decoder model `[tgt_len x batch x hidden]`
          attns (dict) : dictionary of attention distributions
              `[tgt_len x batch x src_len]`
          cur_trunc (int) : starting position of truncation window
          trunc_size (int) : length of truncation window
          shard_size (int) : maximum number of examples in a shard
          normalization (int) : Loss is divided by this number

        Returns:
            :obj:`onmt.utils.Statistics`: validation loss statistics

        """
        batch_stats = Statistics()
        shard_state = self._make_shard_state(batch, output)
        for shard in shards(shard_state, shard_size):
            loss, stats = self._compute_loss(batch, **shard)
            loss.div(float(normalization)).backward()

            batch_stats.update(stats)

        return batch_stats
Example No. 18
    def validate(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                src = batch.src
                tgt = batch.tgt
                segs = batch.segs
                clss = batch.clss

                tgt_eng = batch.tgt_eng
                if not hasattr(batch, 'tgt_segs'):
                    # print("this ")
                    tgt_segs = torch.ones(tgt.size()).long().cuda()
                else:
                    tgt_segs = batch.tgt_segs

                mask_src = batch.mask_src
                mask_tgt = batch.mask_tgt
                mask_cls = batch.mask_cls

                outputs, _, _ = self.model(src,
                                           tgt,
                                           segs,
                                           clss,
                                           mask_src,
                                           mask_tgt,
                                           mask_cls,
                                           tgt_eng=tgt_eng,
                                           tgt_segs=tgt_segs)

                batch_stats = self.loss.monolithic_compute_loss(batch, outputs)
                stats.update(batch_stats)
            self._report_step(0, step, valid_stats=stats)
            return stats
Example No. 19
 def sharded_compute_loss(self,
                          batch,
                          output,
                          shard_size,
                          normalization,
                          copy_params=None):
     """
     Args:
       batch (batch) : batch of labeled examples
       output (:obj:`FloatTensor`) :
           output of decoder model `[tgt_len x batch x hidden]`
       attns (dict) : dictionary of attention distributions
           `[tgt_len x batch x src_len]`
       cur_trunc (int) : starting position of truncation window
       trunc_size (int) : length of truncation window
       shard_size (int) : maximum number of examples in a shard
       normalization (int) : Loss is divided by this number
     Returns:
         :obj:`onmt.utils.Statistics`: validation loss statistics
     """
     batch_stats = Statistics()
     shard_state = self._make_shard_state(batch, output, copy_params)
     for shard in shards(shard_state, shard_size):
         output = shard['output']
         target = shard['target']
         ext_loss = None
         if copy_params is not None:
             g = shard['copy_params[1]']
             ext_dist = shard['copy_params[0]']
             if len(copy_params) > 2:
                 # Optional auxiliary loss carried alongside the copy params.
                 ext_loss = shard['copy_params[2]']
                 loss, stats = self._compute_loss(batch, output, target, g,
                                                  ext_dist, ext_loss)
             else:
                 loss, stats = self._compute_loss(batch, output, target, g,
                                                  ext_dist)
         else:
             loss, stats = self._compute_loss(batch, output, target)
         loss = loss.div(float(normalization))
         if ext_loss is not None:
             # Add the auxiliary loss only when it was actually produced.
             loss = loss + ext_loss.mean() * 2
         loss.backward()
         batch_stats.update(stats)
     return batch_stats
Example No. 20
    def validate(self, valid_iter, step=0):
        self.model.eval()
        stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                if self.args.mode == 'validate':
                    src = batch.src
                    tgt = batch.tgt

                    pmt_msk = batch.pmt_msk
                    states = batch.states
                    ex_idx = batch.ex_idx
                    tgt_idx = batch.tgt_idx

                    mask_src = batch.mask_src
                    mask_tgt = batch.mask_tgt

                    outputs, _ = self.model(src, tgt, mask_src, pmt_msk, ex_idx)
                    init_logps, trans_logps = self.model.trans_logprobs()
                    ext_logps = self.model.external_logprobs()
                    batch_stats = self.loss.mono_compute_loss(batch, outputs, states,
                                                            ex_idx, tgt_idx, mask_tgt,
                                                            init_logps, trans_logps,
                                                            ext_logps)
                else:
                    src = batch.src
                    tgt = batch.tgt
                    segs = batch.segs
                    mask_src = batch.mask_src
                    mask_tgt = batch.mask_tgt

                    outputs, _ = self.model(src, tgt, segs, mask_src, mask_tgt)
                    batch_stats = self.loss.monolithic_compute_loss(batch, outputs)

                stats.update(batch_stats)
            self._report_step(0, step, valid_stats=stats)

            return stats
Example No. 21
    def validate(self, valid_iter_fct, step=0):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()
        valid_iter = valid_iter_fct()
        cn = sum(1 for _ in valid_iter)  # count batches for the progress bar
        best_model_saved = False
        valid_iter = valid_iter_fct()
        with torch.no_grad():
            for batch in tqdm(valid_iter, total=cn):
                src = batch.src
                tgt = batch.tgt
                segs = batch.segs
                clss = batch.clss
                mask_src = batch.mask_src
                mask_tgt = batch.mask_tgt
                mask_cls = batch.mask_cls

                outputs, _ = self.model(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)

                batch_stats = self.loss.monolithic_compute_loss(batch, outputs)
                stats.update(batch_stats)

            self.valid_rgls.append(stats.accuracy())
            if len(self.valid_rgls) > 0:
                if self.max_rl < self.valid_rgls[-1]:
                    self.max_rl = self.valid_rgls[-1]
                    best_model_saved = True

            self._report_step(0, step, valid_stats=stats)
            return stats, best_model_saved
Example No. 22
    def validate(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                src = batch.src
                tgt = batch.tgt
                segs = batch.segs
                clss = batch.clss
                mask_src = batch.mask_src
                mask_tgt = batch.mask_tgt
                mask_cls = batch.mask_cls
                outputs, scores, src_context, graph_context, top_vec, ent_top_vec, _ = self.model(
                    src, tgt, segs, clss, mask_src, mask_tgt, mask_cls, batch)
                batch_stats = self.loss.monolithic_compute_loss(
                    batch, outputs, src_context, graph_context, batch.ent_src,
                    ent_top_vec, self.copy)

                stats.update(batch_stats)
            self._report_step(0, step, valid_stats=stats)
            with open('./score.txt', 'a+') as f:
                f.write('step: %g \n' % step)
                f.write('xent: %g \n' % stats.ppl())
            return stats
Example No. 23
    def validate(self, valid_iter, step=0):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                src = batch.src
                tgt = batch.tgt
                segs = batch.segs
                clss = batch.clss
                alignment = batch.alignment

                mask_src = batch.mask_src
                mask_tgt = batch.mask_tgt
                mask_cls = batch.mask_cls
                mask_alg = batch.mask_alg

                outputs, state, src_mem_bank = self.model(
                    src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)

                batch_stats = self.loss.monolithic_compute_loss(batch, \
                        outputs, \
                        src_mem_bank, \
                        alignment, \
                        mask_alg, \
                        mask_tgt, \
                        saved_attn = state.s2t_attn, \
                        loss_ws=state.loss_ws)

                stats.update(batch_stats)
            self._report_step(0, step, valid_stats=stats)
            return stats
Example No. 24
    def _gradient_calculation(self, true_batchs, normalization, total_stats,
                              report_stats, step):
        self.model.zero_grad()

        for batch in true_batchs:
            outputs, _, topic_loss = self.model(batch)

            tgt_tokens, src_tokens, sents, examples = normalization

            if self.args.topic_model:
                # Topic Model loss
                topic_stats = Statistics(topic_loss=topic_loss.clone().item() /
                                         float(examples))
                topic_loss.div(float(examples)).backward(retain_graph=True)
                total_stats.update(topic_stats)
                report_stats.update(topic_stats)

            # Auto-encoder loss
            abs_stats = self.abs_loss(batch,
                                      outputs,
                                      self.args.generator_shard_size,
                                      tgt_tokens,
                                      retain_graph=False)
            abs_stats.n_docs = len(batch)
            total_stats.update(abs_stats)
            report_stats.update(abs_stats)

        # in case of multi step gradient accumulation,
        # update only after accum batches
        if self.n_gpu > 1:
            grads = [
                p.grad.data for p in self.model.parameters()
                if p.requires_grad and p.grad is not None
            ]
            distributed.all_reduce_and_rescale_tensors(grads, float(1))
        for o in self.optims:
            o.step()
Example No. 25
    def train(self,
              train_iter_fct,
              train_steps,
              valid_iter_fct=None,
              valid_steps=-1):
        """
        The main training loop:
        iterate over the training data (i.e. `train_iter_fct`)
        and periodically run validation (i.e. iterate over `valid_iter_fct`).

        Args:
            train_iter_fct(function): a function that returns the train
                iterator. e.g. something like
                train_iter_fct = lambda: generator(*args, **kwargs)
            valid_iter_fct(function): same as train_iter_fct, for valid data
            train_steps(int):
            valid_steps(int):
            save_checkpoint_steps(int):

        Return:
            None
        """
        logger.info('Start training...')

        step = self.optims[0]._step + 1
        true_batchs = []
        accum = 0
        tgt_tokens = 0
        src_tokens = 0
        sents = 0
        examples = 0

        train_iter = train_iter_fct()
        total_stats = Statistics()
        report_stats = Statistics()
        self._start_report_manager(start_time=total_stats.start_time)

        while step <= train_steps:

            for i, batch in enumerate(train_iter):
                if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):

                    true_batchs.append(batch)
                    tgt_tokens += batch.tgt[:, 1:].ne(
                        self.abs_loss.padding_idx).sum().item()
                    src_tokens += batch.src[:, 1:].ne(
                        self.abs_loss.padding_idx).sum().item()
                    sents += batch.src.size(0)
                    examples += batch.tgt.size(0)
                    accum += 1
                    if accum == self.grad_accum_count:
                        if self.n_gpu > 1:
                            tgt_tokens = sum(
                                distributed.all_gather_list(tgt_tokens))
                            src_tokens = sum(
                                distributed.all_gather_list(src_tokens))
                            sents = sum(distributed.all_gather_list(sents))
                            examples = sum(
                                distributed.all_gather_list(examples))

                        normalization = (tgt_tokens, src_tokens, sents,
                                         examples)
                        self._gradient_calculation(true_batchs, normalization,
                                                   total_stats, report_stats,
                                                   step)

                        report_stats = self._maybe_report_training(
                            step, train_steps, self.optims[0].learning_rate,
                            report_stats)

                        true_batchs = []
                        accum = 0
                        src_tokens = 0
                        tgt_tokens = 0
                        sents = 0
                        examples = 0
                        if (step % self.save_checkpoint_steps == 0
                                and self.gpu_rank == 0):
                            self._save(step)
                        step += 1
                        if step > train_steps:
                            break
            train_iter = train_iter_fct()

        return total_stats
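
The per-GPU token and example counts are summed with the project's `distributed.all_gather_list` before forming the normalization tuple. A rough equivalent with plain `torch.distributed` might look like the sketch below (assumes an already initialized process group; this is not the project's helper):

import torch.distributed as dist

def gather_and_sum(count):
    # Sketch: collect a Python scalar from every worker and sum the results.
    # Assumes dist.init_process_group(...) has already been called.
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, count)
    return sum(gathered)

# e.g. tgt_tokens = gather_and_sum(tgt_tokens) on each rank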
Example No. 26
    def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """

        # Set model in validating mode.
        def _get_ngrams(n, text):
            ngram_set = set()
            text_length = len(text)
            max_index_ngram_start = text_length - n
            for i in range(max_index_ngram_start + 1):
                ngram_set.add(tuple(text[i:i + n]))
            return ngram_set

        def _block_tri(c, p):
            tri_c = _get_ngrams(3, c.split())
            for s in p:
                tri_s = _get_ngrams(3, s.split())
                if len(tri_c.intersection(tri_s)) > 0:
                    return True
            return False

        if (not cal_lead and not cal_oracle):
            self.model.eval()
        stats = Statistics()

        can_path = '%s_step%d.candidate' % (self.args.result_path, step)
        gold_path = '%s_step%d.gold' % (self.args.result_path, step)
        with open(can_path, 'w') as save_pred:
            with open(gold_path, 'w') as save_gold:
                with torch.no_grad():
                    for batch in test_iter:
                        gold = []
                        pred = []
                        if (cal_lead):
                            selected_ids = [list(range(batch.clss.size(1)))
                                            ] * batch.batch_size
                        for i, idx in enumerate(selected_ids):
                            _pred = []
                            if (len(batch.src_str[i]) == 0):
                                continue
                            for j in selected_ids[i][:len(batch.src_str[i])]:
                                if (j >= len(batch.src_str[i])):
                                    continue
                                candidate = batch.src_str[i][j].strip()
                                _pred.append(candidate)

                                if ((not cal_oracle)
                                        and (not self.args.recall_eval)
                                        and len(_pred) == 3):
                                    break

                            _pred = '<q>'.join(_pred)
                            if (self.args.recall_eval):
                                _pred = ' '.join(
                                    _pred.split()
                                    [:len(batch.tgt_str[i].split())])

                            pred.append(_pred)
                            gold.append(batch.tgt_str[i])

                        for i in range(len(gold)):
                            save_gold.write(gold[i].strip() + '\n')
                        for i in range(len(pred)):
                            save_pred.write(pred[i].strip() + '\n')
        if (step != -1 and self.args.report_rouge):
            rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
            logger.info('Rouges at step %d \n%s' %
                        (step, rouge_results_to_str(rouges)))
        self._report_step(0, step, valid_stats=stats)

        return stats
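
The `_get_ngrams` / `_block_tri` helpers implement standard trigram blocking: a candidate sentence is skipped if it shares any trigram with the sentences already selected. The same logic can be exercised in isolation, as in this small standalone rewrite:

def get_ngrams(n, text):
    # Same logic as _get_ngrams above: the set of n-grams over a token list.
    return {tuple(text[i:i + n]) for i in range(len(text) - n + 1)}

def block_tri(candidate, selected):
    # True if the candidate shares at least one trigram with any selected sentence.
    tri_c = get_ngrams(3, candidate.split())
    return any(tri_c & get_ngrams(3, s.split()) for s in selected)

selected = ["the cat sat on the mat"]
print(block_tri("a dog sat on the mat today", selected))       # True: shares "on the mat"
print(block_tri("an entirely different sentence", selected))   # False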
Example No. 27
    def train(self, train_iter_fct):
        """Main training process of MTL.

        Args:
            train_iter_fct (function):
                returns an instance of data.data_loader.MetaDataloader.
        """

        logger.info('Start training... (' + str(self.args.maml_type) + ')')
        step = self.optims[0]._step + 1  # resume the step recorded in optims
        true_sup_batchs = []
        true_qry_batchs = []
        accum = 0
        task_accum = 0
        sup_normalization = 0
        qry_normalization = 0

        # Dataloader
        train_iter = train_iter_fct()  # class Dataloader

        # Reporter and Statistics
        report_outer_stats = Statistics()
        report_inner_stats = Statistics()
        self._start_report_manager(start_time=report_outer_stats.start_time)

        # Currently only MAML is supported
        assert self.args.maml_type == 'maml'

        # Make sure the accumulation of gradient is correct
        assert self.args.accum_count == self.args.num_batch_in_task

        while step <= self.args.train_steps:  # NOTE: Outer loop
            for i, (sup_batch, qry_batch) in enumerate(train_iter):
                if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):

                    # Collect batches (= self.grad_accum_count) as real batch
                    true_sup_batchs.append(sup_batch)
                    true_qry_batchs.append(qry_batch)

                    # Count non-padding words in batches
                    sup_num_tokens = sup_batch.tgt[:, 1:].ne(
                        self.loss.padding_idx).sum()
                    qry_num_tokens = qry_batch.tgt[:, 1:].ne(
                        self.loss.padding_idx).sum()
                    sup_normalization += sup_num_tokens.item()
                    qry_normalization += qry_num_tokens.item()

                    accum += 1
                    if accum == self.args.num_batch_in_task:
                        task_accum += 1

                        #=============== Inner Update ================
                        # Sum-up non-padding words from multi-GPU
                        if self.n_gpu > 1:
                            sup_normalization = sum(
                                distributed.all_gather_list(sup_normalization))

                        inner_step = 1
                        while inner_step <= self.args.inner_train_steps:  # NOTE: Inner loop

                            # Compute gradient and update
                            self._maml_inner_gradient_accumulation(
                                true_sup_batchs, sup_normalization,
                                report_inner_stats, inner_step, task_accum)

                            # Call self.report_manager to report training progress (if args.report_every is reached)
                            report_inner_stats = self._maybe_report_inner_training(
                                inner_step, self.args.inner_train_steps,
                                self.optims_inner[task_accum -
                                                  1][0].learning_rate,
                                self.optims_inner[task_accum -
                                                  1][1].learning_rate,
                                report_inner_stats)

                            inner_step += 1
                            if inner_step > self.args.inner_train_steps:
                                break

                        #=============== Outer Update ================

                        # Sum-up non-padding words from multi-GPU
                        if self.n_gpu > 1:
                            qry_normalization = sum(
                                distributed.all_gather_list(qry_normalization))

                        # Compute gradient and update
                        self._maml_outter_gradient_accumulation(
                            true_qry_batchs, qry_normalization,
                            report_outer_stats, step, inner_step, task_accum)

                        if (task_accum == self.args.num_task):
                            # Calculate gradient norm
                            total_norm = 0.0
                            for p in self.model.parameters():
                                if (p.grad is not None):
                                    param_norm = p.grad.data.norm(2)
                                    total_norm += param_norm.item()**2
                            total_norm = total_norm**(1. / 2)

                        #===============================================

                        # Reset
                        true_sup_batchs = []
                        true_qry_batchs = []
                        accum = 0
                        sup_normalization = 0
                        qry_normalization = 0

                if (task_accum == self.args.num_task):

                    # Call self.report_manager to report training progress (if args.report_every is reached)
                    report_outer_stats = self._maybe_report_training(
                        step, self.args.train_steps,
                        self.optims[0].learning_rate,
                        self.optims[1].learning_rate, report_outer_stats)

                    # Reset
                    task_accum = 0

                    # Save
                    if (step % self.save_checkpoint_steps == 0
                            and self.gpu_rank == 0):
                        self._save(step)

                    # Check steps to stop
                    step += 1
                    if step > self.args.train_steps:
                        break

            # End for an epoch, reload and reset
            train_iter = train_iter_fct()

        self.report_manager.tensorboard_writer.flush()  # force the log to be written
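
The manual loop near the end computes the total gradient 2-norm. When no clipping is wanted, the same value can be read from `torch.nn.utils.clip_grad_norm_` with an infinite threshold; a tiny check on a throwaway model:

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

model = nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()

# Manual computation, mirroring the loop in the example above.
total_norm = 0.0
for p in model.parameters():
    if p.grad is not None:
        total_norm += p.grad.data.norm(2).item() ** 2
total_norm = total_norm ** 0.5

# Built-in helper: with max_norm=inf it only reports the norm, no rescaling.
reported = clip_grad_norm_(model.parameters(), max_norm=float('inf'))
print(total_norm, reported.item())  # same value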
Example No. 28
    def validate(self, valid_iter_fct, step=0):
        """Main validation process of MTL.

        Args:
            valid_iter_fct (function):
                returns an instance of data.data_loader.MetaDataloader.
        """
        logger.info('Start validating...')

        step = 0
        ckpt_step = self.optims[0]._step  # resume the step recorded in optims
        true_sup_batchs = []
        true_qry_batchs = []
        accum = 0
        task_accum = 0
        sup_normalization = 0
        qry_normalization = 0

        # Dataloader
        valid_iter = valid_iter_fct()  # class Dataloader

        # Reporter and Statistics
        report_outer_stats = Statistics()
        report_inner_stats = Statistics()
        self._start_report_manager(start_time=report_outer_stats.start_time)

        # Make sure the accumulation of gradient is correct
        assert self.args.accum_count == self.args.num_batch_in_task

        while step <= self.args.train_steps:

            for i, (sup_batch, qry_batch) in enumerate(valid_iter):
                if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):

                    # Collect batches (= self.grad_accum_count) as real batch
                    true_sup_batchs.append(sup_batch)
                    true_qry_batchs.append(qry_batch)

                    # Count non-padding words in batches
                    sup_num_tokens = sup_batch.tgt[:, 1:].ne(
                        self.loss.padding_idx).sum()
                    qry_num_tokens = qry_batch.tgt[:, 1:].ne(
                        self.loss.padding_idx).sum()
                    sup_normalization += sup_num_tokens.item()
                    qry_normalization += qry_num_tokens.item()

                    # Gradient normalize for tasks
                    qry_normalization = qry_normalization * self.args.num_task

                    accum += 1
                    if accum == self.args.num_batch_in_task:
                        task_accum += 1

                        # NOTE: Clear optimizer state
                        self.optims_inner[task_accum -
                                          1][0].optimizer.clear_states()
                        self.optims_inner[task_accum -
                                          1][1].optimizer.clear_states()

                        #=============== Inner Update ================
                        # Sum-up non-padding words from multi-GPU
                        if self.n_gpu > 1:
                            sup_normalization = sum(
                                distributed.all_gather_list(sup_normalization))

                        inner_step = 1
                        while inner_step <= self.args.inner_train_steps:
                            # Compute gradient and update
                            self._maml_inner_gradient_accumulation(
                                true_sup_batchs,
                                sup_normalization,
                                report_inner_stats,
                                inner_step,
                                task_accum,
                                inference_mode=True)

                            # Call self.report_manager to report training progress (if args.report_every is reached)
                            report_inner_stats = self._maybe_report_inner_training(
                                inner_step, self.args.inner_train_steps,
                                self.optims_inner[task_accum -
                                                  1][0].learning_rate,
                                self.optims_inner[task_accum -
                                                  1][1].learning_rate,
                                report_inner_stats)

                            inner_step += 1
                            if inner_step > self.args.inner_train_steps:
                                break
                        #===============================================

                        #=============== Outer No Update ================
                        self.model.eval()

                        # Calculate loss only, no update for the initialization
                        self._valid(true_qry_batchs, report_outer_stats,
                                    ckpt_step)

                        # Clean fast weight
                        self.model._clean_fast_weights_mode()

                        self.model.train()
                        #===============================================

                        # Reset
                        true_sup_batchs = []
                        true_qry_batchs = []
                        accum = 0
                        sup_normalization = 0
                        qry_normalization = 0

                if (task_accum == self.args.num_task):

                    # Reset
                    task_accum = 0

                    # Check steps to stop
                    step += 1
                    if step > self.args.train_steps:
                        break

            # End for an epoch, reload & reset
            valid_iter = valid_iter_fct()

        # Report the average result after all validation steps
        self._report_step(0, ckpt_step,
                          valid_stats=report_outer_stats)  # first arg is lr
        self.report_manager.tensorboard_writer.flush()  # force the log to be written

        return report_outer_stats
Example No. 29
    def train(self,
              train_iter_fct,
              train_steps,
              valid_iter_fct=None,
              valid_steps=-1):
        """
        The main training loop:
        iterate over the training data (i.e. `train_iter_fct`)
        and periodically run validation (i.e. iterate over `valid_iter_fct`).

        Args:
            train_iter_fct(function): a function that returns the train
                iterator. e.g. something like
                train_iter_fct = lambda: generator(*args, **kwargs)
            valid_iter_fct(function): same as train_iter_fct, for valid data
            train_steps(int):
            valid_steps(int):
            save_checkpoint_steps(int):

        Return:
            None
        """
        logger.info('Start training...')

        step = self.optims[0]._step + 1

        true_batchs = []
        accum = 0
        normalization = 0
        train_iter = train_iter_fct()

        total_stats = Statistics()
        report_stats = Statistics()
        self._start_report_manager(start_time=total_stats.start_time)
        print(f'Step={step}, Train_steps={train_steps}')
        while step <= train_steps:

            reduce_counter = 0
            for i, batch in enumerate(train_iter):
                if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):

                    true_batchs.append(batch)
                    num_tokens = batch.tgt[:,
                                           1:].ne(self.loss.padding_idx).sum()
                    normalization += num_tokens.item()
                    accum += 1
                    if accum == self.grad_accum_count:
                        reduce_counter += 1
                        if self.n_gpu > 1:
                            normalization = sum(
                                distributed.all_gather_list(normalization))

                        self._gradient_accumulation(true_batchs, normalization,
                                                    total_stats, report_stats)

                        report_stats = self._maybe_report_training(
                            step, train_steps, self.optims[0].learning_rate,
                            report_stats)

                        true_batchs = []
                        accum = 0
                        normalization = 0
                        if (step % self.save_checkpoint_steps == 0
                                and self.gpu_rank == 0):
                            self._save(step)

                        step += 1
                        if step > train_steps:
                            break
            train_iter = train_iter_fct()

        return total_stats
Example No. 30
    def train(self,
              train_iter_fct,
              train_steps,
              valid_iter_fct=None,
              valid_steps=-1):
        """
        The main training loop:
        iterate over the training data (i.e. `train_iter_fct`)
        and periodically run validation (i.e. iterate over `valid_iter_fct`).

        Args:
            train_iter_fct(function): a function that returns the train
                iterator. e.g. something like
                train_iter_fct = lambda: generator(*args, **kwargs)
            valid_iter_fct(function): same as train_iter_fct, for valid data
            train_steps(int):
            valid_steps(int):
            save_checkpoint_steps(int):

        Return:
            None
        """
        logger.info('Start training...')

        step = self.optims[0]._step + 1

        true_batchs = []
        accum = 0
        normalization = 0
        train_iter = train_iter_fct()

        total_stats = Statistics()
        report_stats = Statistics()
        self._start_report_manager(start_time=total_stats.start_time)

        while step <= train_steps:

            reduce_counter = 0
            for i, batch in enumerate(train_iter):
                if self.n_gpu == 0 or (i % self.n_gpu == self.gpu_rank):

                    true_batchs.append(batch)
                    num_tokens = batch.tgt[:,
                                           1:].ne(self.loss.padding_idx).sum()
                    normalization += num_tokens.item()
                    accum += 1
                    if accum == self.grad_accum_count:
                        reduce_counter += 1
                        if self.n_gpu > 1:
                            normalization = sum(
                                distributed.all_gather_list(normalization))

                        self._gradient_accumulation(true_batchs, normalization,
                                                    total_stats, report_stats)

                        report_stats = self._maybe_report_training(
                            step, train_steps, self.optims[0].learning_rate,
                            report_stats)

                        if step % self.args.report_every == 0:
                            self.model.eval()
                            logger.info('Model in set eval state')

                            valid_iter = data_loader.Dataloader(
                                self.args,
                                load_dataset(self.args, 'test', shuffle=False),
                                self.args.batch_size,
                                "cuda",
                                shuffle=False,
                                is_test=True)

                            tokenizer = BertTokenizer.from_pretrained(
                                self.args.model_path, do_lower_case=True)
                            symbols = {
                                'BOS': tokenizer.vocab['[unused1]'],
                                'EOS': tokenizer.vocab['[unused2]'],
                                'PAD': tokenizer.vocab['[PAD]'],
                                'EOQ': tokenizer.vocab['[unused3]']
                            }

                            valid_loss = abs_loss(self.model.generator,
                                                  symbols,
                                                  self.model.vocab_size,
                                                  train=False,
                                                  device="cuda")

                            trainer = build_trainer(self.args, 0, self.model,
                                                    None, valid_loss)
                            stats = trainer.validate(valid_iter, step)
                            self.report_manager.report_step(
                                self.optims[0].learning_rate,
                                step,
                                train_stats=None,
                                valid_stats=stats)
                            self.model.train()

                        true_batchs = []
                        accum = 0
                        normalization = 0
                        if (step % self.save_checkpoint_steps == 0
                                and self.gpu_rank == 0):
                            self._save(step)

                        step += 1
                        if step > train_steps:
                            break
            train_iter = train_iter_fct()

        return total_stats