def validate(self, model: T5FusionInDecoder, val_iter: BucketIterator, optimizer_dict=None, log_results=False):
    """
    Runs validation: computes per-token cross-entropy loss and exact-match (EM) accuracy
    over greedily decoded answers, and checkpoints the model whenever EM improves.
    """
    model = model.eval()
    it = tqdm(enumerate(val_iter), total=len(val_iter.data()) // val_iter.batch_size + 1)

    total = 0
    hits = 0
    losslist = []

    if log_results:
        import csv
        model_type = self.config['reader_transformer_type'].replace("/", "_")
        outf = open(f"results/gen_reader_{model_type}.csv", "w", encoding="utf-8")
        csvw = csv.writer(outf, delimiter=',')
        csvw.writerow(["Correct", "Question", "Predicted Answer", "GT Answer", "Input"])

    for i, batch in it:
        # Strip the extra leading dimension added by the iterator.
        batch.src = batch.src[0]
        batch.src_mask = batch.src_mask[0]
        batch.doc_mask = batch.doc_mask[0] if hasattr(batch, "doc_mask") else None

        total += len(batch)

        # Encode all passages once; the encoder outputs are reused for both the
        # teacher-forced loss computation and generation.
        concatenated_encoder_output, concatenated_encoder_attention = model(
            input_ids=batch.src, attention_mask=batch.src_mask, encode_only=True)
        # Deep-copy the encoder outputs so generate() below still sees the originals.
        concatenated_encoder_output_copy = BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=copy.deepcopy(concatenated_encoder_output['last_hidden_state']))
        concatenated_encoder_attention_copy = copy.deepcopy(concatenated_encoder_attention)

        outputs: Seq2SeqLMOutput = model(input_ids=None,
                                         attention_mask=concatenated_encoder_attention_copy,
                                         encoder_outputs=concatenated_encoder_output_copy,
                                         passage_mask=batch.doc_mask,
                                         decoder_input_ids=batch.target[:, :-1].contiguous(),
                                         decoder_attention_mask=batch.target_mask[:, :-1].contiguous())

        lm_logits = outputs.logits
        labels = batch.target[:, 1:].reshape(-1)

        losses = F.cross_entropy(lm_logits.view(-1, get_model(model).config.vocab_size), labels,
                                 reduction='none')
        losslist += losses.tolist()

        # Hacky: pass the attention mask as input_ids just to give generate() a tensor
        # with the post-concatenation batch dimension; real input ids are not needed
        # here (and they have the pre-concatenation batch dimension).
        tokenized_answers = get_model(model).generate(input_ids=concatenated_encoder_attention,
                                                      # num_beams=5,
                                                      # num_return_sequences=5,
                                                      attention_mask=concatenated_encoder_attention,
                                                      encoder_outputs=concatenated_encoder_output,
                                                      decoder_start_token_id=batch.target[0][0])

        predicted_answers = [self.tokenizer.decode(ans, skip_special_tokens=True)
                             for ans in tokenized_answers]
        for j in range(len(batch)):
            hit = eval_utils.metric_max_over_ground_truths(
                metric_fn=eval_utils.exact_match_score,
                prediction=predicted_answers[j],
                ground_truths=batch.answers[j])
            hits += int(hit)
            if log_results:
                csvw.writerow([
                    hit,
                    batch.question[j],
                    predicted_answers[j],
                    batch.answers[j],
                    self.tokenizer.decode(batch.src[j])
                ])

        it.set_description(f"Val Loss: {sum(losslist) / len(losslist):.3f} EM: {hits / total:.3f}")

    EM = hits / total
    logging.info(f"S: {get_model(model).training_steps} Validation Loss: {sum(losslist) / len(losslist)}")
    logging.info(f"Validation EM: {EM}")

    if log_results:
        outf.close()

    if EM > self.best_em and not self.config['test_only']:
        logging.info(f"{EM} ---> New BEST!")
        self.best_em = EM
        serializable_model_name = self.config['reader_transformer_type'].replace("/", "_")
        saveable_model = get_model(model)
        saveable_model.optimizer_state_dict = optimizer_dict
        # Model training is fully resumable: the checkpoint carries
        # .optimizer_state_dict and .training_steps (= number of updates).
        saved_name = os.path.join(self.config['save_dir'], f"generative_reader_"
                                                           f"EM{EM:.4f}_"
                                                           f"S{get_model(model).training_steps}_"
                                                           f"M{serializable_model_name}_"
                                                           f"{get_timestamp()}_{socket.gethostname()}")
        self.best_ckpt_name = saved_name
        torch.save(saveable_model, saved_name)

    model = model.train()
    return EM
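# For reference, a minimal sketch of the SQuAD-style exact-match helpers that
# validate() calls through eval_utils. This is an assumption about their shape
# (the real eval_utils implementation is not shown here and may differ):
# answers are normalized before comparison, and EM is taken as the max over
# all ground-truth answers.
import re
import string


def _normalize_answer(s: str) -> str:
    """Lowercase, drop punctuation and articles, and collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def exact_match_score(prediction: str, ground_truth: str) -> bool:
    """True iff the normalized prediction equals the normalized ground truth."""
    return _normalize_answer(prediction) == _normalize_answer(ground_truth)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Best metric value of the prediction against any ground-truth answer."""
    return max(metric_fn(prediction, gt) for gt in ground_truths)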
def introspect(self, model: torch.nn.Module, dev_iter: BucketIterator, config: dict, verbose=False,
               log_results=True):
    """
    Runs the model over the dev set, collecting accuracy (overall and per branch depth)
    and macro-F1, and interactively dumps attention images for selected examples.
    """
    train_flag = model.training
    model.eval()

    total_batches = len(dev_iter.data()) // dev_iter.batch_size
    if verbose:
        pbar = tqdm(total=total_batches)
    if log_results:
        csvf, writer = self.init_result_logging()

    examples_so_far = 0
    # NOTE: dev_loss is never accumulated below, so the reported loss stays 0.
    dev_loss = 0
    total_correct = 0
    total_correct_per_level = Counter()
    total_per_level = defaultdict(int)
    total_labels = []
    total_preds = []

    for i, batch in enumerate(dev_iter):
        pred_logits = model(batch)

        text_s = [' '.join(self.tokenizer.convert_ids_to_tokens(batch.text[k].cpu().numpy()))
                  for k in range(batch.text.shape[0])]
        print(text_s[0])

        # Stance label ids:
        # 0: "support",
        # 1: "comment",
        # 2: "deny",
        # 3: "query"
        branch_levels = [bid.split(".", 1)[-1] for bid in batch.branch_id]
        for branch_depth in branch_levels:
            total_per_level[branch_depth] += 1
        correct, correct_per_level = self.calculate_correct(
            pred_logits, batch.stance_label, levels=branch_levels)
        total_correct += correct
        total_correct_per_level += correct_per_level

        prefix = "exccomment"
        # Works only with batch size 1: on a correctly classified non-"comment"
        # example, optionally dump attention images and stop.
        if batch.stance_label != 1 and correct:
            c = input('please confirm ')
            if c == "y":
                self.generate_attention_images(batch, model, pred_logits, prefix)
                sys.exit()

        examples_so_far += len(batch.stance_label)
        if verbose:
            pbar.set_description(
                f"dev loss: {dev_loss / (i + 1):.4f}, dev acc: {total_correct / examples_so_far:.4f}")
            pbar.update(1)

        maxpreds, argmaxpreds = torch.max(F.softmax(pred_logits, -1), dim=1)
        total_preds += list(argmaxpreds.cpu().numpy())
        total_labels += list(batch.stance_label.cpu().numpy())

        if log_results:
            pred_s = list(argmaxpreds.cpu().numpy())
            target_s = list(batch.stance_label.cpu().numpy())
            correct_s = list((argmaxpreds == batch.stance_label).cpu().numpy())
            prob_s = [f"{x:.2f}" for x in list(maxpreds.cpu().detach().numpy())]

            assert len(text_s) == len(pred_s) == len(correct_s) == len(target_s) == len(prob_s)
            for j in range(len(text_s)):
                writer.writerow([
                    correct_s[j], batch.id[j], batch.tweet_id[j], branch_levels[j],
                    map_stance_label_to_s[target_s[j]], map_stance_label_to_s[pred_s[j]],
                    prob_s[j], batch.raw_text[j], text_s[j]
                ])

    loss, acc = dev_loss / total_batches, total_correct / examples_so_far
    total_acc_per_level = {depth: total_correct_per_level.get(depth, 0) / total
                           for depth, total in total_per_level.items()}
    F1 = metrics.f1_score(total_labels, total_preds, average="macro")

    if log_results:
        self.finalize_results_logging(csvf, loss, F1)
    if train_flag:
        model.train()
    return loss, acc, total_acc_per_level, F1
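# For reference, a minimal sketch (an assumption, not the repo's actual code) of
# the two helpers introspect() leans on: map_stance_label_to_s, with label ids
# taken from the comment inside introspect(), and a free-function version of
# what self.calculate_correct() plausibly computes — overall and per-branch-level
# correct counts.
from collections import Counter

import torch

map_stance_label_to_s = {0: "support", 1: "comment", 2: "deny", 3: "query"}


def calculate_correct(pred_logits: torch.Tensor, labels: torch.Tensor, levels=None):
    """Count correct predictions overall and, if branch levels are given, per level."""
    preds = pred_logits.argmax(dim=-1)
    match = preds == labels
    correct = int(match.sum().item())
    correct_per_level = Counter()
    if levels is not None:
        for level, hit in zip(levels, match.tolist()):
            correct_per_level[level] += int(hit)
    return correct, correct_per_level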