def testing(self, test_iter, step=0, gen_flag=False, tokenizer=None, info="", write_type=None, output_wrong_pred=False):
    """Evaluate the model on ``test_iter`` and optionally dump per-example predictions.

    Args:
        test_iter: iterable of batches exposing ``src``, ``segs``, ``mask_src``,
            ``edges``, ``node_batch``, ``id`` and ``label``.
        step: step number forwarded to ``self._report_step``.
        gen_flag: forwarded to the model as the ``gen_flag`` keyword argument.
        tokenizer: when truthy, ``batch.src`` is decoded with
            ``tokenizer.batch_decode`` and per-example predictions are
            inspected; when falsy, only the statistics are accumulated.
        info: free-form tag stored in the ``exp`` column of each written row.
        write_type: falsy -> nothing is written. ``"a"`` -> the row is added to
            the example's existing result file (if there is one). Any other
            truthy value -> the example's result file is overwritten.
        output_wrong_pred: when true, the ids of misclassified examples are
            returned alongside the statistics.

    Returns:
        ``stats`` or, if ``output_wrong_pred`` is true, ``(stats, wrong_predictions)``.
        ``wrong_predictions`` is only filled when ``tokenizer`` is given.
    """
    # Set model in validating mode.
    self.model.eval()
    stats = Statistics()
    wrong_predictions = []
    columns = ["exp", "predicted_label", "ground-truth", "generated", "source"]

    with torch.no_grad():
        for batch in test_iter:
            src = batch.src
            segs = batch.segs
            mask_src = batch.mask_src
            edges = batch.edges
            node_batch = batch.node_batch

            outputs = self.model(src, segs, mask_src, edges, node_batch, gen_flag=gen_flag)
            batch_stats, loss = self.loss.monolithic_compute_loss(batch, outputs, self.epoch)
            stats.update(batch_stats)

            if not tokenizer:
                continue

            # Index of the highest-scoring class for every example in the batch.
            predictions = outputs[0].max(axis=1)[1]
            sents = tokenizer.batch_decode(src, skip_special_tokens=True)
            for idx, raw_id in enumerate(batch.id):
                label = batch.label[idx].item()
                prediction = predictions[idx].item()
                example_id = raw_id[0].item()

                if write_type:
                    # Tabs are the column separator, so strip them from the text.
                    out_rows = [[info, prediction, label, "", sents[idx].replace("\t", " ")]]
                    out_df = pd.DataFrame(out_rows, columns=columns)
                    # The path must be resolved *before* it is probed; the
                    # previous ordering read `model_path` while still unbound.
                    model_path = pjoin(self.args.savepath, "gen_result/{}.txt".format(example_id))
                    if write_type == "a" and os.path.exists(model_path):
                        ori_df = pd.read_csv(model_path, delimiter="\t")
                        # DataFrame.append was removed in pandas 2.0.
                        out_df = pd.concat([ori_df, out_df])
                    out_df.to_csv(model_path, index=False, sep="\t")

                if label != prediction:
                    wrong_predictions.append(example_id)

    self._report_step(0, step, test_stats=stats)
    if output_wrong_pred:
        return stats, wrong_predictions
    return stats
def _stats(self, loss, loss_det, logits, label, loss_gen=None, scores=None, target=None):
    """Build the ``Statistics`` record for one batch.

    Args:
        loss: total loss; ``.item()`` is stored.
        loss_det: detection/classification loss; ``.item()`` is stored.
        logits: per-document class scores; the arg-max over dim 1 is the
            predicted label and ``len(logits)`` is the document count.
        label: ground-truth labels, flattened with ``view(-1)`` before scoring.
        loss_gen: optional generation loss. When given, token-level accuracy
            is computed from ``scores`` and ``target`` as well.
        scores: a score for each possible output token (used only when
            ``loss_gen`` is given).
        target: true target tokens (used only when ``loss_gen`` is given);
            positions equal to ``self.padding_idx`` are ignored.

    Returns:
        ``Statistics`` for this batch.

    Raises:
        ValueError: if ``self.label_num`` is not 2, 3 or 4.
    """
    n_docs = len(logits)
    pred = logits.max(1)[1]
    gold = label.view(-1)

    if self.label_num == 2:
        # The binary result is padded with eight zeros before being splatted
        # into Statistics -- presumably to fill the slots that the multi-class
        # evaluators populate; TODO confirm against Statistics' signature.
        results = (*evaluationclass(pred, gold), 0, 0, 0, 0, 0, 0, 0, 0)
    elif self.label_num == 4:
        results = evaluation4class(pred, gold)
    elif self.label_num == 3:
        results = evaluation3class(pred, gold)
    else:
        # Previously this fell through and failed later with an opaque
        # UnboundLocalError on `results`.
        raise ValueError(
            "unsupported label_num {!r}; expected 2, 3 or 4".format(self.label_num))

    if loss_gen is None:
        return Statistics(loss.item(), loss_det.item(), *results, n_docs)

    # Token-level accuracy of the generator, ignoring padding positions.
    token_pred = scores.max(1)[1]
    non_padding = target.ne(self.padding_idx)
    num_correct_token = token_pred.eq(target).masked_select(non_padding).sum().item()
    num_non_padding = non_padding.sum().item()
    # build statistic for later update to other statistic
    return Statistics(loss.item(), loss_det.item(), *results, n_docs,
                      loss_gen.item(), num_non_padding, num_correct_token)
def sharded_compute_loss(self, batch, output, shard_size, epoch, normalization=None):
    """Compute the loss shard by shard and backpropagate each shard.

    The shard state built by ``self._make_shard_state(batch, output)`` is split
    by ``shards(..., shard_size)``; every shard is unpacked as keyword
    arguments to ``self._compute_loss``, its loss is backpropagated right away
    and its statistics are merged into the running total.

    Args:
        batch: batch of labeled examples.
        output: model output for ``batch``.
        shard_size (int): forwarded to ``shards`` as the shard size.
        epoch: stored on ``self.epoch`` before any loss is computed.
        normalization: forwarded as the first argument of ``_compute_loss``.

    Returns:
        ``Statistics`` accumulated over all shards.
    """
    self.epoch = epoch
    # Starts as an all-default Statistics; each shard's stats are folded in.
    accumulated = Statistics()
    for piece in shards(self._make_shard_state(batch, output), shard_size):
        shard_loss, shard_stats = self._compute_loss(normalization, **piece)
        shard_loss.backward()
        accumulated.update(shard_stats)
    return accumulated
def exp_pos(self, test_iter, step=0):
    """Run the node-position experiment (not part of regular training/evaluation).

    For every batch, the model's ``exp_pos`` method is called once per node of
    ``node_batch[0]`` with a one-hot weight vector selecting that node. The
    probability assigned to the gold label (first entry of ``batch.label``)
    is tracked across positions; the spread ``highest - lowest`` is recorded
    per batch and a batch counts as wrong if any position gives the gold
    label a probability below 0.5.

    The maximum spread, mean spread and error rate are printed and appended
    as one line to ``exp_pos.txt`` under ``self.args.savepath``.

    Args:
        test_iter: sized iterable of batches. Only index 0 of ``node_batch``
            and ``label`` is read -- presumably batch size 1; TODO confirm.
        step: unused; kept for interface compatibility.
    """
    # Set model in validating mode.
    self.model.eval()
    diffs = []
    max_diff = 0
    num_wrong = 0

    with torch.no_grad():
        for batch in test_iter:
            src = batch.src
            node_batch = batch.node_batch
            num_nodes = len(node_batch[0])
            gold = batch.label.reshape(-1)[0]

            highest = 0
            lowest = 1
            wrong = False
            for node in range(num_nodes):
                # One-hot weight: all attention on a single node position.
                weight = torch.zeros(num_nodes, device=src.device)
                weight[node] = 1
                outputs = self.model.exp_pos(
                    src, batch.segs, batch.mask_src, [weight], batch.edges, node_batch)
                # torch.sigmoid is the overflow-safe form of exp(x) / (1 + exp(x)),
                # which turns into NaN once exp(x) overflows.
                prob = torch.sigmoid(outputs[0][0])[gold]

                # Calculate accuracy
                if prob < 0.5:
                    wrong = True
                # Find max and min over all possible positions
                if prob > highest:
                    highest = prob
                if prob < lowest:
                    lowest = prob

            if wrong:
                num_wrong += 1
            print(highest, lowest)
            diff = highest - lowest
            diffs.append(diff)
            if diff > max_diff:
                max_diff = diff

    if not diffs:
        # Nothing was evaluated; avoid dividing by zero below.
        print('exp_pos: test_iter yielded no batches, nothing to report')
        return

    mean_diff = sum(diffs) / len(diffs)
    error = num_wrong / len(test_iter)
    print(mean_diff)
    print(max_diff)
    print('{:.4f}'.format(error))
    with open(pjoin(self.args.savepath, 'exp_pos.txt'), 'a') as f:
        # Record the mean spread (the value printed above), not the last
        # batch's spread, and end the record with a newline because the file
        # is opened in append mode.
        f.write('[test-pos],{:.4f},{:.4f},{:.4f}\n'.format(max_diff, mean_diff, error))
def _maybe_gather_stats(self, stat):
    """Gather statistics across processes when running on several GPUs.

    Args:
        stat: a ``Statistics`` object, or ``None``.

    Returns:
        ``stat`` itself when it is ``None`` or when ``self.n_gpu`` is not
        greater than 1; otherwise the result of
        ``Statistics.all_gather_stats(stat)``.
    """
    # Nothing to gather: no stats at all, or no more than one GPU.
    if stat is None or self.n_gpu <= 1:
        return stat
    return Statistics.all_gather_stats(stat)
def validate(self, valid_iter, epoch=0):
    """Run the model over ``valid_iter`` without gradients and collect statistics.

    Args:
        valid_iter: iterable of validation batches.
        epoch: used for the progress-bar label and forwarded to
            ``self._report_step``. The loss itself is computed with
            ``self.epoch``, not this argument.

    Returns:
        ``Statistics`` accumulated over all validation batches. The mean
        per-batch loss is also printed when at least one batch was seen.
    """
    # Set model in validating mode.
    self.model.eval()
    stats = Statistics()
    losses = []

    with torch.no_grad():
        for batch in tqdm(valid_iter, desc='validating {}'.format(epoch)):
            model_inputs = [batch.src, batch.segs, batch.mask_src, batch.edges, batch.node_batch]
            if self.args.train_gen:
                # The generator additionally needs the target and its mask.
                model_inputs += [batch.tgt, batch.mask_tgt]
            outputs = self.model(*model_inputs)

            batch_stats, loss = self.loss.monolithic_compute_loss(batch, outputs, self.epoch)
            stats.update(batch_stats)
            losses.append(loss.item())

    self._report_step(0, epoch, valid_stats=stats)
    # An empty iterator used to raise ZeroDivisionError here.
    if losses:
        print(sum(losses)/len(losses))
    return stats
def test(self, test_iter, step, cal_lead=False, cal_oracle=False): """ Havn't modified by yunzhu !!!!! Validate model. valid_iter: validate data iterator Returns: :obj:`nmt.Statistics`: validation loss statistics """ # Set model in validating mode. def _get_ngrams(n, text): ngram_set = set() text_length = len(text) max_index_ngram_start = text_length - n for i in range(max_index_ngram_start + 1): ngram_set.add(tuple(text[i:i + n])) return ngram_set def _block_tri(c, p): tri_c = _get_ngrams(3, c.split()) for s in p: tri_s = _get_ngrams(3, s.split()) if len(tri_c.intersection(tri_s))>0: return True return False if (not cal_lead and not cal_oracle): self.model.eval() stats = Statistics() can_path = '%s_step%d.candidate'%(self.args.result_path,step) gold_path = '%s_step%d.gold' % (self.args.result_path, step) with open(can_path, 'w') as save_pred: with open(gold_path, 'w') as save_gold: with torch.no_grad(): for batch in test_iter: gold = [] pred = [] if (cal_lead): selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size for i, idx in enumerate(selected_ids): _pred = [] if(len(batch.src_str[i])==0): continue for j in selected_ids[i][:len(batch.src_str[i])]: if(j>=len( batch.src_str[i])): continue candidate = batch.src_str[i][j].strip() _pred.append(candidate) if ((not cal_oracle) and (not self.args.recall_eval) and len(_pred) == 3): break _pred = '<q>'.join(_pred) if(self.args.recall_eval): _pred = ' '.join(_pred.split()[:len(batch.tgt_str[i].split())]) pred.append(_pred) gold.append(batch.tgt_str[i]) for i in range(len(gold)): save_gold.write(gold[i].strip()+'\n') for i in range(len(pred)): save_pred.write(pred[i].strip()+'\n') if(step!=-1 and self.args.report_rouge): rouges = test_rouge(self.args.temp_dir, can_path, gold_path) logger.info('Rouges at step %d \n%s' % (step, rouge_results_to_str(rouges))) self._report_step(0, step, valid_stats=stats) return stats
def train(self, train_iter, train_steps, message=''):
    """Run one pass over ``train_iter`` with gradient accumulation.

    Batches are collected until ``self.grad_accum_count`` of them are pending,
    then handed to ``self._gradient_accumulation`` as one group. After each
    group the training report hook is invoked and the step counter advances.
    ``self.epoch`` is incremented once per call.

    NOTE(review): batches left over at the end of the iterator (when the
    number of batches is not a multiple of ``grad_accum_count``) are never
    passed to ``_gradient_accumulation``. Confirm this is intended.

    Args:
        train_iter: iterable of training batches. It no longer needs to
            support ``len()``.
        train_steps: forwarded to ``self._maybe_report_training``.
        message: label shown on the progress bar.

    Returns:
        ``Statistics`` passed to every ``_gradient_accumulation`` call as
        ``total_stats``.
    """
    self.model.train()
    self.epoch += 1
    step = self.optims[0]._step + 1

    true_batchs = []
    accum = 0
    normalization = 0
    total_stats = Statistics()
    report_stats = Statistics()
    self._start_report_manager(start_time=total_stats.start_time)

    for batch in tqdm(train_iter, desc=message):
        true_batchs.append(batch)
        if self.args.train_gen:
            # Count non-padding target tokens, skipping the first position.
            num_tokens = batch.tgt[:, 1:].ne(self.loss.padding_idx).sum()
            normalization += num_tokens.item()
        else:
            normalization = None

        accum += 1
        if accum == self.grad_accum_count:
            if self.n_gpu > 1 and normalization is not None:
                # Sum the per-process token counts.
                normalization = sum(distributed.all_gather_list(normalization))

            self._gradient_accumulation(
                true_batchs, normalization, total_stats, report_stats)
            report_stats = self._maybe_report_training(
                step, train_steps, self.optims[0].learning_rate, report_stats)

            # Start a fresh accumulation group.
            true_batchs = []
            accum = 0
            normalization = 0
            step += 1

    return total_stats