import pytest

# `train_dep_parsing` is assumed to live alongside `train_struct`; the module
# path for the corpus helpers is likewise an assumption.
from runners.trainers import train_struct, train_dep_parsing
from models.model_inits import make_pretrained_lstm_and_tokenizer, make_pretrained_transformer_and_tokenizer
from data_tools.dataloaders import make_struct_dataloaders
from data_tools.corpora import parse_all_corpora, create_corrupted_dep_vocab


# `model` is parametrized here so the test is runnable standalone;
# 'distilgpt2' matches the checkpoint used by the other tests.
@pytest.mark.parametrize("model", ["distilgpt2"])
def test_train_struct_for_transformer(model):
    path_to_train = 'data/sample/en_ewt-ud-train.conllu'
    path_to_valid = 'data/sample/en_ewt-ud-dev.conllu'
    nr_epochs = 3

    print(f"Loading in {model}")
    transformer, transformer_tokenizer = make_pretrained_transformer_and_tokenizer(
        model)

    train_dataloader, valid_dataloader = make_struct_dataloaders(
        path_to_train,
        path_to_valid,
        feature_model=transformer,
        feature_model_tokenizer=transformer_tokenizer,
    )

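    # train_struct returns the trained probe (unused here) together with the
    # per-epoch validation losses and UUAS scores.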
    _, losses, acc = train_struct(
        train_dataloader,
        valid_dataloader,
        nr_epochs=nr_epochs,
        struct_emb_dim=768,  # transformer hidden size (768 for distilgpt2)
        struct_lr=1e-3,
        struct_rank=64,
    )

    # Let's test that the first loss is higher than the last loss
    assert losses[0] > losses[-1], "Loss does not seem to improve"

    # Let's test that the last accuracy is higher than the first
    assert acc[0] < acc[-1], "Accuracy does not seem to improve"


def test_train_struct_for_LSTM():
    path_to_train = 'data/sample/en_ewt-ud-train.conllu'
    path_to_valid = 'data/sample/en_ewt-ud-dev.conllu'
    nr_epochs = 3

    lstm, lstm_tokenizer = make_pretrained_lstm_and_tokenizer()

    train_dataloader, valid_dataloader = make_struct_dataloaders(
        path_to_train,
        path_to_valid,
        feature_model=lstm,
        feature_model_tokenizer=lstm_tokenizer,
    )

    probe, losses, acc = train_struct(
        train_dataloader,
        valid_dataloader,
        nr_epochs=nr_epochs,
        struct_emb_dim=lstm.nhid,  # probe input size read off the LSTM
        struct_lr=1e-3,
        struct_rank=64,
    )

    # Let's test that the first loss is higher than the last loss
    assert losses[0] > losses[-1], "Loss does not seem to improve"

    # Let's test that the last accuracy is higher than the first
    assert acc[0] < acc[-1], "Accuracy does not seem to improve"


def test_dep_control_task_training():
    path_to_train = 'data/sample/en_ewt-ud-train.conllu'
    path_to_valid = 'data/sample/en_ewt-ud-dev.conllu'

    transformer, transformer_tokenizer = make_pretrained_transformer_and_tokenizer('distilgpt2')

    # Parse all corpora and build the corrupted dependency vocabulary for the control task
    all_corpora = parse_all_corpora(True)
    corrupted_dep_vocab = create_corrupted_dep_vocab(all_corpora)

    train_dataloader, valid_dataloader = make_struct_dataloaders(
        path_to_train,
        path_to_valid,
        feature_model=transformer,
        feature_model_tokenizer=transformer_tokenizer,
        use_dependencies=True,
        use_corrupted=True,
        corrupted_vocab=corrupted_dep_vocab
    )

    # Positional args mirror train_struct's struct_emb_dim, struct_rank, struct_lr.
    probe, losses, acc = train_dep_parsing(
        train_dataloader,
        valid_dataloader,
        768,
        64,
        1e-3,
    )

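    # The control task should still be learnable: accuracy rises as the probe
    # memorizes the corrupted mapping. Comparing it with the regular dependency
    # task gives the probe's selectivity (Hewitt & Liang, 2019).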
    assert losses[0] > losses[-1], "Loss did not decrease over the training"
    assert acc[0] < acc[-1], "Accuracy did not increase over the training"


def test_dep_regular_training():
    path_to_train = 'data/sample/en_ewt-ud-train.conllu'
    path_to_valid = 'data/sample/en_ewt-ud-dev.conllu'

    transformer, transformer_tokenizer = make_pretrained_transformer_and_tokenizer('distilgpt2')

    train_dataloader, valid_dataloader = make_struct_dataloaders(
        path_to_train,
        path_to_valid,
        feature_model=transformer,
        feature_model_tokenizer=transformer_tokenizer,
        use_dependencies=True
    )

    # Positional args mirror train_struct's struct_emb_dim, struct_rank, struct_lr.
    probe, losses, acc = train_dep_parsing(
        train_dataloader,
        valid_dataloader,
        768,
        64,
        1e-3,
    )

    assert losses[0] > losses[-1], "Loss did not decrease over the training"
    assert acc[0] < acc[-1], "Accuracy did not increase over the training"


def test_dep_dataloader_returns_corrupted_idxs():
    path_to_train = 'data/sample/en_ewt-ud-train.conllu'
    path_to_valid = 'data/sample/en_ewt-ud-dev.conllu'

    # Parse all corpora and build the corrupted dependency vocabulary for the control task
    all_corpora = parse_all_corpora(True)
    corrupted_dep_vocab = create_corrupted_dep_vocab(all_corpora)

    transformer, transformer_tokenizer = make_pretrained_transformer_and_tokenizer('distilgpt2')

    train_dataloader, _ = make_struct_dataloaders(
        path_to_train,
        path_to_valid,
        feature_model=transformer,
        feature_model_tokenizer=transformer_tokenizer,
        use_dependencies=True,
        use_corrupted=True,
        corrupted_vocab=corrupted_dep_vocab
    )

    # Take a single batch from the training dataloader
    train_sample = next(iter(train_dataloader))
    for train_item in train_sample:
        _, parent_edges = train_item

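        # Corrupted parents follow the control-task construction of Hewitt &
        # Liang (2019): each token attaches to itself, to the first token, or
        # to the last token, hence idx, 0, or len under this loader's indexing;
        # -1 still marks the root.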
        for idx, parent_tensor in enumerate(parent_edges):
            parent = parent_tensor.item()
            assert parent in (idx, -1, 0, len(parent_edges)), \
                f"Unexpected corrupted parent index {parent}"


def test_dep_dataloader_returns_parent():
    path_to_train = 'data/sample/en_ewt-ud-train.conllu'
    path_to_valid = 'data/sample/en_ewt-ud-dev.conllu'

    transformer, transformer_tokenizer = make_pretrained_transformer_and_tokenizer('distilgpt2')

    train_dataloader, _ = make_struct_dataloaders(
        path_to_train,
        path_to_valid,
        feature_model=transformer,
        feature_model_tokenizer=transformer_tokenizer,
        use_dependencies=True
    )

    train_sample = next(iter(train_dataloader))

    for train_item in train_sample:
        _, parent_edges = train_item

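        # -1 marks the sentence root; a well-formed dependency tree has exactly one.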
        root_nodes = [i for i in parent_edges if i == -1]
        assert len(root_nodes) == 1, f"Encountered root_nodes of size {len(root_nodes)}"


from data_tools.datasets import ProbingDataset
import numpy as np

# %%
from runners.trainers import train_struct
from models.model_inits import make_pretrained_lstm_and_tokenizer, make_pretrained_transformer_and_tokenizer
from data_tools.dataloaders import make_struct_dataloaders

# Train the structural probe only if `config.will_train_structural_probe` is set.
if config.will_train_structural_probe:
    print("-- Training: Structural Probe --")

    print("Loading in data for structural probe!")
    train_dataloader, valid_dataloader = make_struct_dataloaders(
        path_to_train=config.path_to_data_train,
        path_to_valid=config.path_to_data_valid,
        feature_model=model,
        feature_model_tokenizer=w2i,
    )

    struct_probe, losses, uuas = train_struct(
        train_dataloader,
        valid_dataloader,
        nr_epochs=config.struct_probe_train_epoch,
        struct_emb_dim=config.feature_model_dimensionality,
        struct_lr=config.struct_probe_lr,
        struct_rank=config.struct_probe_rank,
    )

    struct_probe_results = {
        'probe_valid_losses': losses,
        'probe_valid_uuas_scores': uuas,
    }
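
    # A minimal sketch of persisting the results; `config.path_to_results` is a
    # hypothetical attribute, so adjust it to wherever the project stores
    # outputs. float() unwraps any 0-dim tensors before JSON serialisation.
    import json

    with open(config.path_to_results, 'w') as f:
        json.dump(
            {k: [float(v) for v in vals] for k, vals in struct_probe_results.items()},
            f,
        )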