Пример #1
0
    def testing(self, test_iter, step=0, gen_flag=False, tokenizer=None, info="", write_type=None, output_wrong_pred=False):
        """Evaluate the model on a test set, optionally dumping per-example predictions.

        Args:
            test_iter: iterable of batches exposing ``src``, ``segs``,
                ``mask_src``, ``edges``, ``node_batch``, ``id`` and ``label``.
            step (int): step number forwarded to ``_report_step``.
            gen_flag (bool): forwarded to the model's forward pass.
            tokenizer: optional object with ``batch_decode``. When given, each
                example's prediction is compared with its label and, if
                ``write_type`` is set, written out together with the decoded
                source text. Without it, no per-example work is done.
            info (str): value stored in the ``exp`` column of written rows.
            write_type (str or None): falsy disables writing; ``"a"`` appends to
                an existing ``gen_result/<id>.txt`` under ``self.args.savepath``
                (creating it if absent); any other truthy value overwrites it.
            output_wrong_pred (bool): if True, also return the ids of
                mispredicted examples (only collected when ``tokenizer`` is given).

        Returns:
            :obj:`Statistics` of the test loss, or ``(stats, wrong_predictions)``
            when ``output_wrong_pred`` is True.
        """
        # Set model in evaluation mode.
        self.model.eval()
        stats = Statistics()
        wrong_predictions = []
        columns = ["exp", "predicted_label", "ground-truth", "generated", "source"]

        with torch.no_grad():
            for batch in test_iter:
                src = batch.src
                node_batch = batch.node_batch

                outputs = self.model(src, batch.segs, batch.mask_src, batch.edges, node_batch, gen_flag=gen_flag)
                batch_stats, _ = self.loss.monolithic_compute_loss(batch, outputs, self.epoch)
                stats.update(batch_stats)

                if not tokenizer:
                    continue

                # Predicted class = argmax of the first model output over dim 1.
                predictions = outputs[0].max(axis=1)[1]
                sents = tokenizer.batch_decode(src, skip_special_tokens=True)
                for idx, raw_id in enumerate(batch.id):
                    label = batch.label[idx].item()
                    prediction = predictions[idx].item()
                    example_id = raw_id[0].item()

                    if write_type:
                        # Tabs are replaced so the source text cannot break the TSV layout.
                        row = [info, prediction, label, "", sents[idx].replace("\t", " ")]
                        out_df = pd.DataFrame([row], columns=columns)
                        # Resolve the output path before checking for / reading an existing file.
                        result_path = pjoin(self.args.savepath, "gen_result/{}.txt".format(example_id))
                        if write_type == "a" and os.path.exists(result_path):
                            previous_df = pd.read_csv(result_path, delimiter="\t")
                            # DataFrame.append was removed in pandas 2.0; concat is the replacement.
                            out_df = pd.concat([previous_df, out_df])
                        out_df.to_csv(result_path, index=False, sep="\t")

                    if label != prediction:
                        wrong_predictions.append(example_id)

            self._report_step(0, step, test_stats=stats)

        if output_wrong_pred:
            return stats, wrong_predictions
        return stats
Пример #2
0
    def _stats(self,
               loss,
               loss_det,
               logits,
               label,
               loss_gen=None,
               scores=None,
               target=None):
        """Build a statistics object for one batch.

        Args:
            loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
            loss_det (:obj:`FloatTensor`): a second loss term, stored via
                ``loss_det.item()``.
            logits (:obj:`FloatTensor`): per-document class scores; the
                predicted class is the argmax over dim 1.
            label (:obj:`LongTensor`): gold labels, flattened with ``view(-1)``.
            loss_gen (:obj:`FloatTensor`, optional): generation loss. When
                given, token-level accuracy is also computed from ``scores``
                and ``target``.
            scores (:obj:`FloatTensor`, optional): a score for each possible output.
            target (:obj:`FloatTensor`, optional): true targets; positions equal
                to ``self.padding_idx`` are ignored.

        Returns:
            :obj:`onmt.utils.Statistics` : statistics for this batch.

        Raises:
            ValueError: if ``self.label_num`` is not 2, 3 or 4.
        """
        n_docs = len(logits)
        pred = logits.max(1)[1]
        gold = label.view(-1)

        if self.label_num == 2:
            # The binary evaluator returns fewer values; pad with eight zeros to
            # fill the remaining positional fields of Statistics (presumably
            # per-class metrics that do not apply here — confirm).
            results = (*evaluationclass(pred, gold), 0, 0, 0, 0, 0, 0, 0, 0)
        elif self.label_num == 4:
            results = evaluation4class(pred, gold)
        elif self.label_num == 3:
            results = evaluation3class(pred, gold)
        else:
            raise ValueError("Unsupported label_num: {}".format(self.label_num))

        if loss_gen is not None:
            token_pred = scores.max(1)[1]
            non_padding = target.ne(self.padding_idx)
            num_correct_token = token_pred.eq(target) \
                                          .masked_select(non_padding) \
                                          .sum() \
                                          .item()
            num_non_padding = non_padding.sum().item()

            # build statistic for later update to other statistic
            return Statistics(loss.item(), loss_det.item(), *results, n_docs,
                              loss_gen.item(), num_non_padding,
                              num_correct_token)
        return Statistics(loss.item(), loss_det.item(), *results, n_docs)
Пример #3
0
    def sharded_compute_loss(self,
                             batch,
                             output,
                             shard_size,
                             epoch,
                             normalization=None):
        """Compute the training loss shard by shard, backpropagating each shard.

        ``self._make_shard_state`` builds the loss state for ``(batch, output)``.
        ``shards`` splits it into pieces bounded by ``shard_size``. For every
        piece, ``self._compute_loss`` returns a ``(loss, stats)`` pair, and
        ``loss.backward()`` is called right away so gradients from all shards
        accumulate in the parameters.

        Args:
            batch: batch of labeled examples.
            output: model output for ``batch``.
            shard_size (int): maximum shard size, passed to ``shards``.
            epoch (int): current epoch; stored on ``self.epoch`` before any
                loss is computed.
            normalization: forwarded unchanged to ``_compute_loss``.

        Returns:
            :obj:`Statistics`: statistics merged over all shards.
        """
        self.epoch = epoch
        merged_stats = Statistics()  # starts from all-zero statistics
        shard_state = self._make_shard_state(batch, output)
        for shard_kwargs in shards(shard_state, shard_size):
            shard_loss, shard_stats = self._compute_loss(normalization, **shard_kwargs)
            shard_loss.backward()
            merged_stats.update(shard_stats)
        return merged_stats
Пример #4
0
    def exp_pos(self, test_iter, step=0):
        """Position-sensitivity experiment (not a regular evaluation routine).

        For each batch, ``self.model.exp_pos`` runs once per node of the first
        example. Each run passes a one-hot ``weight`` vector selecting that node.
        The sigmoid probability of the gold class is tracked for every run:

        - the batch's ``diff`` is the highest minus the lowest probability;
        - the batch counts as wrong if any run gives a probability below 0.5.

        NOTE(review): only element 0 of each batch is inspected
        (``node_batch[0]``, ``outputs[0][0]``, ``label[0]``), so this presumably
        expects batch size 1 — confirm.

        Args:
            test_iter: iterable of batches exposing ``src``, ``segs``,
                ``mask_src``, ``edges``, ``node_batch`` and ``label``.
            step (int): unused; kept for interface compatibility.

        Side effects:
            Prints each batch's (highest, lowest) pair, then the mean diff, max
            diff and error rate. Appends one line
            ``[test-pos],<max_diff>,<mean_diff>,<error>`` to
            ``<savepath>/exp_pos.txt``. Nothing is written if ``test_iter``
            yields no batches.
        """
        # Set model in evaluation mode.
        self.model.eval()

        with torch.no_grad():
            diffs = []
            max_diff = 0
            num_wrong = 0
            for batch in test_iter:
                src = batch.src
                node_batch = batch.node_batch
                num_nodes = len(node_batch[0])

                highest = 0
                lowest = 1
                wrong = False
                for node in range(num_nodes):
                    # One-hot weight vector selecting the current node.
                    weight = torch.zeros(num_nodes, device=src.device)
                    weight[node] = 1
                    outputs = self.model.exp_pos(src, batch.segs, batch.mask_src, [weight], batch.edges, node_batch)
                    # torch.sigmoid is the numerically stable form of exp(x) / (1 + exp(x)).
                    probs = torch.sigmoid(outputs[0][0])
                    prob = probs[batch.label.reshape(-1)[0]].item()

                    if prob < 0.5:
                        wrong = True

                    # Track the extremes over all possible positions.
                    highest = max(highest, prob)
                    lowest = min(lowest, prob)

                if wrong:
                    num_wrong += 1

                print(highest, lowest)
                diff = highest - lowest
                diffs.append(diff)
                max_diff = max(max_diff, diff)

            if not diffs:
                print('exp_pos: no batches to evaluate')
                return

            # One diff is recorded per batch, so len(diffs) is the batch count.
            mean_diff = sum(diffs) / len(diffs)
            error = num_wrong / len(diffs)
            print(mean_diff)
            print(max_diff)
            print('{:.4f}'.format(error))
            with open(pjoin(self.args.savepath, 'exp_pos.txt'), 'a') as f:
                f.write('[test-pos],{:.4f},{:.4f},{:.4f}\n'.format(max_diff, mean_diff, error))
Пример #5
0
    def _maybe_gather_stats(self, stat):
        """Combine statistics across processes when running on several GPUs.

        Args:
            stat (:obj:`onmt.utils.Statistics` or None): statistics to gather.

        Returns:
            The result of ``Statistics.all_gather_stats(stat)`` when ``stat`` is
            not None and ``self.n_gpu`` is greater than 1; otherwise ``stat``
            itself (including None).
        """
        should_gather = stat is not None and self.n_gpu > 1
        if not should_gather:
            return stat
        return Statistics.all_gather_stats(stat)
Пример #6
0
    def validate(self, valid_iter, epoch=0):
        """Run the model over a validation set and report the accumulated loss.

        Args:
            valid_iter: iterable of validation batches exposing ``src``, ``segs``,
                ``mask_src``, ``edges`` and ``node_batch``. When
                ``self.args.train_gen`` is set, batches must also expose ``tgt``
                and ``mask_tgt``.
            epoch (int): shown in the progress-bar label and passed to
                ``_report_step``.

        Returns:
            :obj:`nmt.Statistics`: validation loss statistics.
        """
        # Set model in evaluation mode.
        self.model.eval()
        stats = Statistics()
        losses = []

        with torch.no_grad():
            for batch in tqdm(valid_iter, desc='validating {}'.format(epoch)):
                model_inputs = [batch.src, batch.segs, batch.mask_src, batch.edges, batch.node_batch]
                if self.args.train_gen:
                    # Generation training also feeds the target sequence and its mask.
                    model_inputs += [batch.tgt, batch.mask_tgt]
                outputs = self.model(*model_inputs)
                # NOTE: the loss receives self.epoch, not the ``epoch`` argument.
                batch_stats, loss = self.loss.monolithic_compute_loss(batch, outputs, self.epoch)

                stats.update(batch_stats)
                losses.append(loss.item())

            self._report_step(0, epoch, valid_stats=stats)
            # An empty iterator would otherwise raise ZeroDivisionError here.
            if losses:
                print(sum(losses) / len(losses))

            return stats
Пример #7
0
    def test(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """Write extractive candidate and gold summaries, optionally scoring ROUGE.

        Only the ``cal_lead`` path is implemented. It selects source sentences
        in document order and keeps the first three, unless ``cal_oracle`` or
        ``self.args.recall_eval`` is set. With ``recall_eval``, the joined
        candidate is truncated to the gold summary's word count.

        Candidates go to ``<result_path>_step<step>.candidate`` and gold
        summaries to ``<result_path>_step<step>.gold``, one per line. Sentences
        within a candidate are joined with ``<q>``.

        Args:
            test_iter: iterable of batches exposing ``clss``, ``batch_size``,
                ``src_str`` and ``tgt_str``.
            step (int): step number used in file names; ROUGE is computed only
                when it is not -1 and ``self.args.report_rouge`` is set.
            cal_lead (bool): use the in-order sentence selection described above.
            cal_oracle (bool): disables the three-sentence cap.

        Returns:
            :obj:`nmt.Statistics`: an empty statistics object (nothing is
            accumulated here).

        Raises:
            NotImplementedError: if a batch is processed with ``cal_lead`` False;
                no other sentence-selection strategy is available.
        """
        if not cal_lead and not cal_oracle:
            # Set model in evaluation mode.
            self.model.eval()
        stats = Statistics()

        can_path = '%s_step%d.candidate' % (self.args.result_path, step)
        gold_path = '%s_step%d.gold' % (self.args.result_path, step)
        with open(can_path, 'w') as save_pred, open(gold_path, 'w') as save_gold:
            with torch.no_grad():
                for batch in test_iter:
                    if not cal_lead:
                        raise NotImplementedError(
                            "test() only supports cal_lead=True; sentence selection "
                            "for oracle or model predictions is not implemented.")

                    # Every example considers its sentences in document order.
                    selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
                    gold = []
                    pred = []
                    for i, ids in enumerate(selected_ids):
                        src_sents = batch.src_str[i]
                        if len(src_sents) == 0:
                            continue
                        _pred = []
                        for j in ids[:len(src_sents)]:
                            if j >= len(src_sents):
                                continue
                            _pred.append(src_sents[j].strip())
                            if (not cal_oracle) and (not self.args.recall_eval) and len(_pred) == 3:
                                break

                        _pred = '<q>'.join(_pred)
                        if self.args.recall_eval:
                            _pred = ' '.join(_pred.split()[:len(batch.tgt_str[i].split())])

                        pred.append(_pred)
                        gold.append(batch.tgt_str[i])

                    for gold_text in gold:
                        save_gold.write(gold_text.strip() + '\n')
                    for pred_text in pred:
                        save_pred.write(pred_text.strip() + '\n')
        if step != -1 and self.args.report_rouge:
            rouges = test_rouge(self.args.temp_dir, can_path, gold_path)
            logger.info('Rouges at step %d \n%s' % (step, rouge_results_to_str(rouges)))
        self._report_step(0, step, valid_stats=stats)

        return stats
Пример #8
0
    def train(self, train_iter, train_steps, message=''):
        """Run one epoch of training over ``train_iter``.

        Increments ``self.epoch`` and puts the model in training mode. Batches
        are collected into groups of ``self.grad_accum_count``. Each group is
        passed to ``_gradient_accumulation``, then ``_maybe_report_training``
        is called with the current step and learning rate.

        NOTE(review): if the number of batches is not a multiple of
        ``grad_accum_count``, the trailing partial group is discarded at the
        end of the epoch.

        Args:
            train_iter: iterable of training batches (wrapped in ``tqdm``).
            train_steps (int): forwarded to ``_maybe_report_training``.
            message (str): description shown on the progress bar.

        Returns:
            :obj:`Statistics`: statistics accumulated over the epoch.
        """
        self.model.train()
        self.epoch += 1
        step = self.optims[0]._step + 1

        true_batchs = []
        accum = 0
        normalization = 0

        total_stats = Statistics()
        report_stats = Statistics()
        self._start_report_manager(start_time=total_stats.start_time)

        for batch in tqdm(train_iter, desc=message):
            true_batchs.append(batch)
            if self.args.train_gen:
                # Count non-padding target tokens (skipping the first position);
                # the total is passed to _gradient_accumulation as `normalization`.
                num_tokens = batch.tgt[:, 1:].ne(self.loss.padding_idx).sum()
                normalization += num_tokens.item()
            else:
                normalization = None
            accum += 1

            if accum == self.grad_accum_count:
                if self.n_gpu > 1 and normalization is not None:
                    # Presumably gathers each worker's token count; summed here.
                    normalization = sum(distributed.all_gather_list(normalization))

                self._gradient_accumulation(
                    true_batchs, normalization, total_stats, report_stats)
                report_stats = self._maybe_report_training(
                    step, train_steps, self.optims[0].learning_rate, report_stats)

                true_batchs = []
                accum = 0
                normalization = 0
                step += 1

        return total_stats