# Example no. 1 (scraped sample header: "Esempio n. 1", score 0)
def train(ev_inf: InferenceNet, train_Xy, val_Xy, test_Xy, inference_vectorizer, epochs=10, batch_size=16, shuffle=True):
    """Train `ev_inf`, select the best model by validation macro-F1, and evaluate it on the test split.

    Args:
        ev_inf: the InferenceNet to train (trained in place; the returned best
            model is a deep copy snapshotted at the best-validation-F1 epoch).
        train_Xy, val_Xy, test_Xy: lists of instance dicts with keys including
            'article', 'I', 'C', 'O', 'y' (and 'a_id'/'p_id' for test output).
            NOTE: these lists are mutated in place (sorted or shuffled).
        inference_vectorizer: provides str_to_idx, used to pick the padding index.
        epochs: maximum number of training epochs.
        batch_size: minibatch size for training (and passed to make_preds).
        shuffle: if True, reshuffle train_Xy each epoch; if False, sort all
            three splits by article length so batches need less padding.

    Returns:
        (best_val_model, inference_vectorizer, train_Xy, val_Xy, val_metrics,
        final_test_preds): val_metrics maps metric names to per-epoch lists
        (plus scalar dummy-baseline and final test entries); final_test_preds
        is a zip of (article id, prompt id, integer prediction) over test_Xy.

    NOTE(review): if epochs == 0 (or epochs <= the early-exit threshold is never
    reached but no epoch runs), best_val_model stays None and the test-time
    make_preds call below would fail — confirm callers always pass epochs >= 1.
    """
    # we sort these so batches all have approximately the same length (ish), which decreases the 
    # average amount of padding needed, and thus total number of steps in training.
    if not shuffle:
        train_Xy.sort(key=lambda x: len(x['article']))
        val_Xy.sort(key=lambda x: len(x['article']))
        test_Xy.sort(key=lambda x: len(x['article']))
    print("Using {} training examples, {} validation examples, {} testing examples".format(len(train_Xy), len(val_Xy), len(test_Xy)))
    # Majority label over the training set; used for the dummy baseline at epoch 0.
    # scipy.stats.mode returns (modes, counts); [0][0] extracts the scalar mode.
    most_common = stats.mode([_get_majority_label(inst) for inst in train_Xy])[0][0]

    best_val_model = None
    best_val_f1 = float('-inf')
    if USE_CUDA:
        ev_inf = ev_inf.cuda()

    optimizer = optim.Adam(ev_inf.parameters())
    criterion = nn.CrossEntropyLoss(reduction='sum')  # sum (not average) of the batch losses.

    # TODO add epoch timing information here
    epochs_since_improvement = 0
    # Per-epoch metric accumulators; scalar entries (baselines, test results)
    # are added to this same dict later.
    val_metrics = {
        "val_acc": [],
        "val_p": [],
        "val_r": [],
        "val_f1": [],
        "val_loss": [],
        'train_loss': [],
        'val_aucs': [],
        'train_aucs': [],
        'val_entropies': [],
        'val_evidence_token_mass': [],
        'val_evidence_token_err': [],
        'train_entropies': [],
        'train_evidence_token_mass': [],
        'train_evidence_token_err': []
    }
    for epoch in range(epochs):
        # Early stopping: quit once validation F1 has not improved for >10 epochs.
        if epochs_since_improvement > 10:
            print("Exiting early due to no improvement on validation after 10 epochs.")
            break
        if shuffle:
            random.shuffle(train_Xy)

        epoch_loss = 0
        for i in range(0, len(train_Xy), batch_size):
            instances = train_Xy[i:i+batch_size]
            # Concatenate per-instance label vectors into one 1-D target tensor.
            # _get_y_vec(..., as_vec=False) presumably yields class indices — TODO confirm.
            ys = torch.cat([_get_y_vec(inst['y'], as_vec=False) for inst in instances], dim=0)
            # TODO explain the use of padding here
            # NOTE(review): variable is named unk_idx but it fetches the PAD
            # token's index; it is used as padding_value below, so the value is
            # right even if the name is misleading.
            unk_idx = int(inference_vectorizer.str_to_idx[SimpleInferenceVectorizer.PAD])
            # Pad each field (article and I/C/O spans) to the batch max length.
            articles, Is, Cs, Os = [PaddedSequence.autopad([torch.LongTensor(inst[x]) for inst in instances], batch_first=True, padding_value=unk_idx) for x in ['article', 'I', 'C', 'O']]
            optimizer.zero_grad()
            if USE_CUDA:
                articles, Is, Cs, Os = articles.cuda(), Is.cuda(), Cs.cuda(), Os.cuda()
                ys = ys.cuda()
            # Print attention weights only for the first batch of the first and last epoch.
            verbose_attn = (epoch == epochs - 1 and i == 0) or (epoch == 0 and i == 0)
            if verbose_attn:
                print("Training attentions:")
            tags = ev_inf(articles, Is, Cs, Os, batch_size=len(instances), verbose_attn=verbose_attn)
            loss = criterion(tags, ys)
            #if loss.item() != loss.item():
            #    import pdb; pdb.set_trace()
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
        val_metrics['train_loss'].append(epoch_loss)

        # Validation pass (no gradients).
        with torch.no_grad():
            # Print validation attention for the first five batches on the
            # first and last epoch only; False disables the verbose output.
            verbose_attn_to_batches = set([0,1,2,3,4]) if epoch == epochs - 1 or epoch == 0 else False
            if verbose_attn_to_batches:
                print("Validation attention:")
            # make_preds runs in eval mode
            val_y, val_y_hat = make_preds(ev_inf, val_Xy, batch_size, inference_vectorizer, verbose_attn_to_batches=verbose_attn_to_batches)
            val_loss = criterion(val_y_hat, val_y.squeeze())
            y_hat = to_int_preds(val_y_hat)

            # On the first epoch only, record a majority-class dummy baseline
            # for comparison against the learned model.
            if epoch == 0:
                dummy_preds = [most_common] * len(val_y)
                dummy_acc = accuracy_score(val_y.cpu(), dummy_preds)
                val_metrics["baseline_val_acc"] = dummy_acc
                p, r, f1, _ = precision_recall_fscore_support(val_y.cpu(), dummy_preds, labels=None, beta=1, average='macro', pos_label=1, warn_for=('f-score',), sample_weight=None)
                val_metrics['p_dummy'] = p
                val_metrics['r_dummy'] = r
                val_metrics['f_dummy'] = f1

                print("val dummy accuracy: {:.3f}".format(dummy_acc))
                print("classification report for dummy on val: ")
                print(classification_report(val_y.cpu(), dummy_preds))
                print("\n\n")

            acc = accuracy_score(val_y.cpu(), y_hat)
            val_metrics["val_acc"].append(acc)
            val_loss = val_loss.cpu().item()
            val_metrics["val_loss"].append(val_loss)

            # f1 = f1_score(val_y, y_hat, average="macro")
            p, r, f1, _ = precision_recall_fscore_support(val_y.cpu(), y_hat, labels=None, beta=1, average='macro', pos_label=1, warn_for=('f-score',), sample_weight=None)
            val_metrics["val_f1"].append(f1)
            val_metrics["val_p"].append(p)
            val_metrics["val_r"].append(r)

            # Attention diagnostics (AUC / entropy / evidence-token mass) on
            # both splits, only when the article encoder uses attention.
            if ev_inf.article_encoder.use_attention:
                train_auc, train_entropies, train_evidence_token_masses, train_evidence_token_err = evaluate_model_attention_distribution(ev_inf, train_Xy, cuda=USE_CUDA, compute_attention_diagnostics=True)
                val_auc, val_entropies, val_evidence_token_masses, val_evidence_token_err = evaluate_model_attention_distribution(ev_inf, val_Xy, cuda=USE_CUDA, compute_attention_diagnostics=True)
                print("train auc: {:.3f}, entropy: {:.3f}, evidence mass: {:.3f}, err: {:.3f}".format(train_auc, train_entropies, train_evidence_token_masses, train_evidence_token_err))
                print("val auc: {:.3f}, entropy: {:.3f}, evidence mass: {:.3f}, err: {:.3f}".format(val_auc, val_entropies, val_evidence_token_masses, val_evidence_token_err))
            else:
                # Empty-string placeholders keep the metric lists aligned per epoch.
                train_auc, train_entropies, train_evidence_token_masses, train_evidence_token_err = "", "", "", ""
                val_auc, val_entropies, val_evidence_token_masses, val_evidence_token_err = "", "", "", ""
            val_metrics['train_aucs'].append(train_auc)
            val_metrics['train_entropies'].append(train_entropies)
            val_metrics['train_evidence_token_mass'].append(train_evidence_token_masses)
            val_metrics['train_evidence_token_err'].append(train_evidence_token_err)
            val_metrics['val_aucs'].append(val_auc)
            val_metrics['val_entropies'].append(val_entropies)
            val_metrics['val_evidence_token_mass'].append(val_evidence_token_masses)
            val_metrics['val_evidence_token_err'].append(val_evidence_token_err)
            # Model selection on validation macro-F1.
            if f1 > best_val_f1:
                print("New best model at {} with val f1 {:.3f}".format(epoch, f1))
                best_val_f1 = f1
                best_val_model = copy.deepcopy(ev_inf)
                epochs_since_improvement = 0
            else:
                epochs_since_improvement += 1

            #if val_loss != val_loss or epoch_loss != epoch_loss:
            #    import pdb; pdb.set_trace()

            print("epoch {}. train loss: {}; val loss: {}; val acc: {:.3f}".format(
                epoch, epoch_loss, val_loss, acc))

            print(classification_report(val_y.cpu(), y_hat))
            print("val macro f1: {0:.3f}".format(f1))
            print("\n\n")

    val_metrics['best_val_f1'] = best_val_f1
    # Final evaluation on the test split using the best validation model.
    with torch.no_grad():
        print("Test attentions:")
        verbose_attn_to_batches = set([0,1,2,3,4])
        # make_preds runs in eval mode
        test_y, test_y_hat = make_preds(best_val_model, test_Xy, batch_size, inference_vectorizer, verbose_attn_to_batches=verbose_attn_to_batches)
        test_loss = criterion(test_y_hat, test_y.squeeze())
        y_hat = to_int_preds(test_y_hat)
        # Lazily zipped (article id, prompt id, prediction) triples for the caller.
        final_test_preds = zip([t['a_id'] for t in test_Xy], [t['p_id'] for t in test_Xy], y_hat)

        acc = accuracy_score(test_y.cpu(), y_hat)
        val_metrics["test_acc"] = acc
        test_loss = test_loss.cpu().item()
        val_metrics["test_loss"] = test_loss

        # f1 = f1_score(test_y, y_hat, average="macro")
        p, r, f1, _ = precision_recall_fscore_support(test_y.cpu(), y_hat, labels=None, beta=1, average='macro', pos_label=1, warn_for=('f-score',), sample_weight=None)
        val_metrics["test_f1"] = f1
        val_metrics["test_p"] = p
        val_metrics["test_r"] = r
        # NOTE(review): the condition checks ev_inf's encoder but diagnostics
        # run on best_val_model — presumably both share the same use_attention
        # setting since best_val_model is a deepcopy of ev_inf; confirm.
        if ev_inf.article_encoder.use_attention:
            test_auc, test_entropies, test_evidence_token_masses, test_evidence_token_err = evaluate_model_attention_distribution(best_val_model, test_Xy, cuda=USE_CUDA, compute_attention_diagnostics=True)
            # NOTE(review): the labels in this print don't match the values
            # passed ("kl_to_uniform" prints the evidence token mass) — verify.
            print("test auc: {:.3f}, , entropy: {:.3f}, kl_to_uniform {:.3f}".format(test_auc, test_entropies, test_evidence_token_masses))
        else:
            test_auc, test_entropies, test_evidence_token_masses, test_evidence_token_err = "", "", "", ""
        val_metrics['test_auc'] = test_auc
        val_metrics['test_entropy'] = test_entropies
        val_metrics['test_evidence_token_mass'] = test_evidence_token_masses
        val_metrics['test_evidence_token_err'] = test_evidence_token_err

        print("test loss: {}; test acc: {:.3f}".format(test_loss, acc))

        print(classification_report(test_y.cpu(), y_hat))
        print("test macro f1: {}".format(f1))
        print("\n\n")

    return best_val_model, inference_vectorizer, train_Xy, val_Xy, val_metrics, final_test_preds
def run(real_train_Xy, real_val_Xy, real_test_Xy, inference_vectorizer, mangle_method, config, cuda=USE_CUDA, determinize=False):
    """Build, optionally attention-pretrain, train, and evaluate an InferenceNet.

    Args:
        real_train_Xy, real_val_Xy, real_test_Xy: raw data splits; they are
            passed through `mangle_method` before any training.
        inference_vectorizer: vocabulary/vectorizer (idx_to_str, str_to_idx, PAD)
            used for model construction and padding.
        mangle_method: callable(train, val, test, vectorizer) -> (train, val, test)
            producing the splits actually used.
        config: experiment configuration (ico_encoder, article_encoder, attn,
            cond_attn, tokenwise_attention, tune_embeddings, epochs, batch_size,
            no_pretrained_word_embeddings, attention pre-training options, ...).
        cuda: move the model to GPU when True.
        determinize: additionally seed torch/numpy and force cuDNN determinism.

    Returns:
        (best_model, val_metrics, attn_metrics, final_test_preds)
    """
    random.seed(177)  # fixed seed so shuffling/mangling is reproducible across runs
    if determinize:
        torch.manual_seed(360)
        torch.backends.cudnn.deterministic = True
        np.random.seed(2115)
    shuffle = False
    print("Running config {}".format(config))
    if config.no_pretrained_word_embeddings:
        # Train word embeddings from scratch rather than loading pre-trained ones.
        num_embeddings = len(inference_vectorizer.idx_to_str)
        embedding_dim = 200
        # BUG FIX: the original passed torch.FloatTensor((num_embeddings, embedding_dim)),
        # which creates a 1-D tensor holding those two numbers — not a
        # (num_embeddings x embedding_dim) weight matrix — so nn.Embedding's
        # _weight shape check would fail. Allocate the matrix by size and give
        # it nn.Embedding's default N(0, 1) init instead of uninitialized memory.
        init_weight = torch.empty(num_embeddings, embedding_dim).normal_()
        init_word_embeddings = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, padding_idx=inference_vectorizer.str_to_idx[inference_vectorizer.PAD], _weight=init_weight)
    else:
        init_word_embeddings = None
    initial_model = InferenceNet(inference_vectorizer, ICO_encoder=config.ico_encoder, article_encoder=config.article_encoder, attention_over_article_tokens=config.attn, condition_attention=config.cond_attn, tokenwise_attention=config.tokenwise_attention, tune_embeddings=config.tune_embeddings, init_embeddings=init_word_embeddings)
    if cuda:
        initial_model = initial_model.cuda()
    # Default (empty) attention metrics; overwritten if pre-training runs below.
    attn_metrics = {
        # note this loss is computed throughout the epoch, but the val loss is computed at the end of the epoch
        'attn_train_losses': [],
        'pretrain_attn_train_auc': [],
        'pretrain_attn_train_token_masses': [],
        'pretrain_attn_train_token_masses_err': [],
        'pretrain_attn_train_entropies': [],
        'attn_val_losses': [],  # TODO eventually rename this for consistency with the other metrics
        'pretrain_attn_val_auc_all': [],  # TODO eventually remove _all from this metric name
        'pretrain_attn_val_token_masses': [],
        'pretrain_attn_val_token_masses_err': [],
        'pretrain_attn_val_entropies': []
    }
    train_Xy, val_Xy, test_Xy = mangle_method(real_train_Xy, real_val_Xy, real_test_Xy, inference_vectorizer)
    if config.attn and config.pretrain_attention:
        print("pre-training attention")
        # Dispatch table instead of a long if/elif chain; every pre-trainer
        # shares the same signature.
        pretrainers = {
            'pretrain_attention_to_match_span': pretrain_attention_to_match_span,
            'pretrain_attention_with_concatenated_spans': pretrain_attention_with_concatenated_spans,
            'pretrain_attention_with_random_spans': pretrain_attention_with_random_spans,
            'pretrain_tokenwise_attention': pretrain_tokenwise_attention,
            'pretrain_tokenwise_attention_balanced': pretrain_tokenwise_attention_balanced,
            'pretrain_max_evidence_attention': pretrain_max_evidence_attention,
        }
        if config.pretrain_attention not in pretrainers:
            raise ValueError("Unknown pre-training configuration {}".format(config.pretrain_attention))
        pretrain_fn = pretrainers[config.pretrain_attention]
        ev_inf, attn_metrics = pretrain_fn(train_Xy, val_Xy, initial_model, epochs=config.attn_epochs, batch_size=config.attn_batch_size, tokenwise_attention=config.tokenwise_attention, cuda=cuda, attention_acceptance=config.attention_acceptance)
    else:
        ev_inf = initial_model

    best_model, _, _, _, val_metrics, final_test_preds = train(ev_inf, train_Xy, val_Xy, test_Xy, inference_vectorizer, batch_size=config.batch_size, epochs=config.epochs, shuffle=shuffle)
    if config.attn and config.article_sections == 'all' and config.data_config == 'vanilla':
        # NOTE(review): final AUCs are computed on ev_inf (the fully trained,
        # possibly non-best model), not best_model — presumably intentional;
        # confirm before changing.
        final_train_auc = evaluate_model_attention_distribution(ev_inf, train_Xy, cuda=cuda)
        final_val_auc = evaluate_model_attention_distribution(ev_inf, val_Xy, cuda=cuda)
        final_test_auc = evaluate_model_attention_distribution(ev_inf, test_Xy, cuda=cuda)
    else:
        # Empty-string placeholders mirror the convention used in train().
        final_train_auc = ""
        final_val_auc = ""
        final_test_auc = ""

    val_metrics['final_train_auc'] = final_train_auc
    val_metrics['final_val_auc'] = final_val_auc
    val_metrics['final_test_auc'] = final_test_auc

    return best_model, val_metrics, attn_metrics, final_test_preds