# Example #1
def test_is_running():
    """Smoke-test: a perplexity-metric training run completes without error."""
    # Applies only to transformers < 3.0.0; newer releases changed the
    # DataCollatorForLanguageModeling API, so skip via early return.
    if version.parse(transformers_version) >= version.parse("3.0.0"):
        return

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    model = AutoModelWithLMHead.from_pretrained("distilbert-base-uncased")
    loader = torch.utils.data.DataLoader(
        LanguageModelingDataset(texts, tokenizer),
        collate_fn=DataCollatorForLanguageModeling(tokenizer).collate_batch,
    )

    runner = HuggingFaceRunner()
    runner.train(
        model=model,
        optimizer=torch.optim.Adam(model.parameters(), lr=5e-5),
        loaders={"train": loader},
        callbacks={
            "optimizer": dl.OptimizerCallback(),
            "perplexity": PerplexityMetricCallback(),
        },
        check=True,
    )
def test_tokenizer_type_error():
    """Test if tokenizer neither hf nor string"""

    # PEP 8 (E731): use a def instead of assigning a lambda to a name,
    # so tracebacks show a real function name. Still a plain callable —
    # neither a HF tokenizer instance nor a string.
    def tok(x):
        return x

    dataset = LanguageModelingDataset(texts, tok)  # noqa: F841
def test_exception_with_sort():
    """Exercise the lazy=True together with sort=True construction path."""
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    dataset = LanguageModelingDataset(  # noqa: F841
        texts,
        tokenizer,
        lazy=True,
        sort=True,
    )
def test_tokenizer_tokenizer():
    """Dataset built from an already-instantiated HF tokenizer object."""
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    dataset = LanguageModelingDataset(texts, tokenizer)
    # Two input texts -> two items, and the first item is materialized.
    assert dataset[0] is not None
    assert len(dataset) == 2
def test_tokenizer_str():
    """Dataset built from a tokenizer name given as a plain string."""
    dataset = LanguageModelingDataset(texts, "bert-base-uncased")
    # Same expectations as the instantiated-tokenizer case.
    assert dataset[0] is not None
    assert len(dataset) == 2
# Example #6
def test_runner():
    """Smoke-test: the teacher/student distillation runner executes a check run."""
    train_df = pd.read_csv("data/train.csv")
    valid_df = pd.read_csv("data/valid.csv")

    # Teacher (BERT) and student (DistilBERT), both exposing hidden states
    # and logits so the distillation losses can read them.
    teacher = BertForMaskedLM.from_pretrained(
        "bert-base-uncased",
        config=AutoConfig.from_pretrained(
            "bert-base-uncased",
            output_hidden_states=True,
            output_logits=True,
        ),
    )
    student = DistilBertForMaskedLM.from_pretrained(
        "distilbert-base-uncased",
        config=AutoConfig.from_pretrained(
            "distilbert-base-uncased",
            output_hidden_states=True,
            output_logits=True,
        ),
    )

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    collate_fn = DataCollatorForLanguageModeling(tokenizer)
    loaders = {
        "train": DataLoader(
            LanguageModelingDataset(train_df["text"], tokenizer),
            collate_fn=collate_fn,
            batch_size=2,
        ),
        "valid": DataLoader(
            LanguageModelingDataset(valid_df["text"], tokenizer),
            collate_fn=collate_fn,
            batch_size=2,
        ),
    }

    # Four component losses combined into a single weighted-sum "loss" metric.
    callbacks = {
        "masked_lm_loss": MaskedLanguageModelCallback(),
        "mse_loss": MSELossCallback(),
        "cosine_loss": CosineLossCallback(),
        "kl_div_loss": KLDivLossCallback(),
        "loss": MetricAggregationCallback(
            prefix="loss",
            mode="weighted_sum",
            metrics={
                "cosine_loss": 1.0,
                "masked_lm_loss": 1.0,
                "kl_div_loss": 1.0,
                "mse_loss": 1.0,
            },
        ),
        "optimizer": dl.OptimizerCallback(),
        "perplexity": PerplexityMetricCallbackDistillation(),
    }

    model = torch.nn.ModuleDict({"teacher": teacher, "student": student})
    runner = DistilMLMRunner()
    runner.train(
        model=model,
        optimizer=torch.optim.Adam(model.parameters(), lr=5e-5),
        loaders=loaders,
        verbose=True,
        check=True,
        callbacks=callbacks,
    )
    assert True
# Example #7
def test_dataset():
    """A single short text should yield five attended tokens."""
    sample = LanguageModelingDataset(["Hello, world"])[0]
    # "Hello, world" -> [CLS] hello , world [SEP] = 5 non-padding positions.
    assert sample["attention_mask"].sum() == 5