    def validate(self, model: T5FusionInDecoder, val_iter: BucketIterator, optimizer_dict=None,
                 log_results=False):
        """
        Does not compute validation loss for now
        """
        model = model.eval()
        it = tqdm(enumerate(val_iter), total=len(val_iter.data()) // val_iter.batch_size + 1)

        total = 0
        hits = 0
        losslist = []
        if log_results:
            import csv
            model_type = self.config['reader_transformer_type'].replace("/", "_")
            outf = open(f"results/gen_reader_{model_type}.csv", "w", encoding="utf-8")
            csvw = csv.writer(outf, delimiter=',')
            csvw.writerow(["Correct", "Question", "Predicted Answer", "GT Answer", "Input"])
        for i, batch in it:
            batch.src = batch.src[0]
            batch.src_mask = batch.src_mask[0]
            batch.doc_mask = batch.doc_mask[0] if hasattr(batch, "doc_mask") else None

            total += len(batch)
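            # Encode once: all (question, passage) inputs go through the encoder in a
            # single pass; both the loss computation and generate() below reuse the
            # cached encoder outputs.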
            concatenated_encoder_output, concatenated_encoder_attention = model(input_ids=batch.src,
                                                                                attention_mask=batch.src_mask,
                                                                                encode_only=True)
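            # Work on deep copies here, presumably because the decoding forward pass
            # can modify the encoder outputs in place, while generate() further down
            # still needs the untouched originals.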
            concatenated_encoder_output_copy = BaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=copy.deepcopy(concatenated_encoder_output['last_hidden_state']))
            concatenated_encoder_attention_copy = copy.deepcopy(concatenated_encoder_attention)
            outputs: Seq2SeqLMOutput = model(input_ids=None,
                                             attention_mask=concatenated_encoder_attention_copy,
                                             encoder_outputs=concatenated_encoder_output_copy,
                                             passage_mask=batch.doc_mask,
                                             decoder_input_ids=batch.target[:, :-1].contiguous(),
                                             decoder_attention_mask=batch.target_mask[:, :-1].contiguous())

            lm_logits = outputs.logits
            labels = batch.target[:, 1:].reshape(-1)
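            # Teacher forcing: the decoder consumed target[:, :-1], so the labels are
            # the targets shifted left by one position.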

            losses = F.cross_entropy(lm_logits.view(-1, get_model(model).config.vocab_size), labels,
                                     reduction='none')
            losslist += losses.tolist()

            # Hack: pass an arbitrary tensor (here the attention mask) as input_ids so
            # the batch dimension matches; the real input ids must not be used, since
            # they still have the pre-concatenation batch dimension.
            tokenized_answers = get_model(model).generate(input_ids=concatenated_encoder_attention,
                                                          # num_beams=5,
                                                          # num_return_sequences=5,
                                                          attention_mask=concatenated_encoder_attention,
                                                          encoder_outputs=concatenated_encoder_output,
                                                          decoder_start_token_id=batch.target[0][0])
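            # The first target token doubles as the decoder start token (for T5 this is
            # presumably the pad token id prepended during preprocessing).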

            predicted_answers = [self.tokenizer.decode(ans, skip_special_tokens=True) for ans in
                                 tokenized_answers]
            for j in range(len(batch)):
                hit = eval_utils.metric_max_over_ground_truths(
                    metric_fn=eval_utils.exact_match_score, prediction=predicted_answers[j],
                    ground_truths=batch.answers[j])
                hits += int(hit)
                if log_results:
                    csvw.writerow([
                        hit,
                        batch.question[j],
                        predicted_answers[j],
                        batch.answers[j],
                        self.tokenizer.decode(batch.src[j])
                    ])

            it.set_description(f"Val Loss: {sum(losslist) / len(losslist):.3f} EM: {hits / total:.3f}")

        EM = hits / total
        logging.info(f"S: {get_model(model).training_steps} Validation Loss: {sum(losslist) / len(losslist)}")
        logging.info(f"Validation EM: {EM}")
        if log_results:
            outf.close()
        if EM > self.best_em and not self.config['test_only']:
            logging.info(f"{EM} ---> New BEST!")
            self.best_em = EM
            serializable_model_name = self.config['reader_transformer_type'].replace("/", "_")
            saveable_model = get_model(model)
            saveable_model.optimizer_state_dict = optimizer_dict
            # Note that model training is fully resumable
            # it contains .optimizer_state_dict and .training_steps (=number of updates)
            saved_name = os.path.join(self.config['save_dir'], f"generative_reader_"
                                                               f"EM{EM:.4f}_"
                                                               f"S{get_model(model).training_steps}_"
                                                               f"M{serializable_model_name}_"
                                                               f"{get_timestamp()}_{socket.gethostname()}")
            self.best_ckpt_name = saved_name
            torch.save(saveable_model, saved_name)
        model = model.train()
        return EM
Example #2
    def introspect(self,
                   model: torch.nn.Module,
                   dev_iter: BucketIterator,
                   config: dict,
                   verbose=False,
                   log_results=True):
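        # Remember the current mode so training can be restored on exit.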
        train_flag = model.training
        model.eval()

        total_batches = len(dev_iter.data()) // dev_iter.batch_size
        if verbose:
            pbar = tqdm(total=total_batches)
        if log_results:
            csvf, writer = self.init_result_logging()
        examples_so_far = 0
        dev_loss = 0
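        # NOTE: dev_loss is never accumulated in the loop below, so the reported
        # dev loss stays zero.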
        total_correct = 0
        total_correct_per_level = Counter()
        total_per_level = defaultdict(int)
        total_labels = []
        total_preds = []
        for i, batch in enumerate(dev_iter):
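            # Raw (pre-softmax) stance logits for the whole batch.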
            pred_logits = model(batch)

            text_s = [
                ' '.join(
                    self.tokenizer.convert_ids_to_tokens(
                        batch.text[i].cpu().numpy()))
                for i in range(batch.text.shape[0])
            ]
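            # Debug aid: show the detokenized first example of the batch.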
            print(text_s[0])
            # 0: "support",
            # 1: "comment",
            # 2: "deny",
            # 3: "query"
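            # branch_id appears to encode "<tree id>.<level>": everything after the
            # first dot identifies the depth of the branch in the conversation tree.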
            branch_levels = [bid.split(".", 1)[-1] for bid in batch.branch_id]
            for branch_depth in branch_levels:
                total_per_level[branch_depth] += 1
            correct, correct_per_level = self.calculate_correct(
                pred_logits, batch.stance_label, levels=branch_levels)
            total_correct += correct
            total_correct_per_level += correct_per_level

            prefix = "exccomment"
            # works only with batch size 1
            if batch.stance_label != 1 and correct == True:
                c = raw_input('please confirm ')
                if c == "y":
                    self.generate_attention_images(batch, model, pred_logits,
                                                   prefix)
                    sys.exit()

            examples_so_far += len(batch.stance_label)
            if verbose:
                pbar.set_description(
                    f"dev loss: {dev_loss / (i + 1):.4f}, dev acc: {total_correct / examples_so_far:.4f}"
                )
                pbar.update(1)

            maxpreds, argmaxpreds = torch.max(F.softmax(pred_logits, -1),
                                              dim=1)
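            # argmaxpreds holds the predicted class ids, maxpreds their probabilities.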
            total_preds += list(argmaxpreds.cpu().numpy())
            total_labels += list(batch.stance_label.cpu().numpy())
            if log_results:
                # text_s was already detokenized above; reuse it instead of recomputing.
                pred_s = list(argmaxpreds.cpu().numpy())
                target_s = list(batch.stance_label.cpu().numpy())
                correct_s = list(
                    (argmaxpreds == batch.stance_label).cpu().numpy())
                prob_s = [
                    f"{x:.2f}" for x in list(maxpreds.cpu().detach().numpy())
                ]

                assert len(text_s) == len(pred_s) == len(correct_s) == len(
                    target_s) == len(prob_s)
                for j in range(len(text_s)):
                    writer.writerow([
                        correct_s[j], batch.id[j], batch.tweet_id[j],
                        branch_levels[j], map_stance_label_to_s[target_s[j]],
                        map_stance_label_to_s[pred_s[j]], prob_s[j],
                        batch.raw_text[j], text_s[j]
                    ])

        loss, acc = dev_loss / total_batches, total_correct / examples_so_far
        total_acc_per_level = {
            depth: total_correct_per_level.get(depth, 0) / total
            for depth, total in total_per_level.items()
        }
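        # Macro F1 weights all stance classes equally, regardless of their frequency.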
        F1 = metrics.f1_score(total_labels, total_preds, average="macro")
        if log_results:
            self.finalize_results_logging(csvf, loss, F1)
        if train_flag:
            model.train()
        return loss, acc, total_acc_per_level, F1