def test_serve():
    """Build a `TextClassifier`, attach pre/post-processing, and call `serve()` in eval mode."""
    classifier = TextClassifier(2, TEST_BACKBONE)

    # TODO: Currently only servable once a preprocess and postprocess have been attached
    preprocess = TextClassificationPreprocess(backbone=TEST_BACKBONE)
    classifier._preprocess = preprocess
    postprocess = TextClassificationPostprocess()
    classifier._postprocess = postprocess

    classifier.eval()
    classifier.serve()
def test_jit(tmpdir):
    """Trace a `TextClassifier`, save and reload it with `torch.jit`, then check the logits.

    The reloaded module is called on the same sample input used for tracing
    and must return a `(1, 2)` tensor under the ``"logits"`` key.
    """
    example = {"input_ids": torch.randint(1000, size=(1, 100))}
    script_path = os.path.join(tmpdir, "test.pt")

    classifier = TextClassifier(2, TEST_BACKBONE)
    classifier.eval()

    # Huggingface bert model only supports `torch.jit.trace` with `strict=False`
    traced = torch.jit.trace(classifier, example, strict=False)
    torch.jit.save(traced, script_path)

    reloaded = torch.jit.load(script_path)
    logits = reloaded(example)["logits"]

    assert isinstance(logits, torch.Tensor)
    assert tuple(logits.shape) == (1, 2)
Example #3
0
def test_init_train(tmpdir):
    """Fit a `TextClassifier` for one fast-dev run on a `DataLoader` over `DummyDataset`.

    On Windows (``os.name == "nt"``) the test is skipped instead of run.
    """
    if os.name == "nt":
        # TODO: huggingface stuff timing out on windows
        # Skip explicitly rather than `return True`: pytest warns about tests
        # that return a non-None value, and the early return was reported as a
        # pass even though nothing had been exercised.
        import pytest

        pytest.skip("huggingface stuff timing out on windows")
    model = TextClassifier(2, TEST_BACKBONE)
    train_dl = torch.utils.data.DataLoader(DummyDataset())
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model, train_dl)
Example #4
0
def test_classification(tmpdir):
    """Fit a `TextClassifier` on a CSV-backed datamodule for one fast-dev run."""
    train_csv = csv_data(tmpdir)

    # Columns are passed positionally: the input field first, then the target.
    datamodule = TextClassificationData.from_csv(
        "sentence", "label", train_file=train_csv, num_workers=0, batch_size=2
    )

    classifier = TextClassifier(2, TEST_BACKBONE)
    runner = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    runner.fit(classifier, datamodule=datamodule)
Example #5
0
def test_init_train_enable_ort(tmpdir):
    """Fit and test a `TextClassifier` built with ``enable_ort=True``.

    A callback asserts at train start that `pl_module.model` is an `ORTModule`.
    """

    class TestCallback(Callback):
        """Check the wrapped model type when training starts."""

        def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
            assert isinstance(pl_module.model, ORTModule)

    def make_loader():
        # Each call builds a fresh loader over a fresh dummy dataset.
        return torch.utils.data.DataLoader(DummyDataset())

    classifier = TextClassifier(2, TEST_BACKBONE, enable_ort=True)
    runner = Trainer(
        default_root_dir=tmpdir, fast_dev_run=True, callbacks=TestCallback()
    )
    runner.fit(
        classifier,
        train_dataloader=make_loader(),
        val_dataloaders=make_loader(),
    )
    runner.test(classifier, test_dataloaders=make_loader())
Example #6
0
def test_classification(tmpdir):
    """Fit a `TextClassifier` on a datamodule built with `from_files` for one fast-dev run.

    On Windows (``os.name == "nt"``) the test is skipped instead of run.
    """
    if os.name == "nt":
        # TODO: huggingface stuff timing out on windows
        # Skip explicitly rather than `return True`: pytest warns about tests
        # that return a non-None value, and the early return was reported as a
        # pass even though nothing had been exercised.
        import pytest

        pytest.skip("huggingface stuff timing out on windows")

    csv_path = csv_data(tmpdir)

    data = TextClassificationData.from_files(
        backbone=TEST_BACKBONE,
        train_file=csv_path,
        input="sentence",
        target="label",
        num_workers=0,
        batch_size=2,
    )
    model = TextClassifier(2, TEST_BACKBONE)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model, datamodule=data)
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.text import TextClassifier

# URL of the checkpoint the classifier is restored from.
checkpoint_url = "https://flash-weights.s3.amazonaws.com/0.7.0/text_classification_model.pt"

# Rebuild the classifier from that checkpoint, then serve it.
model = TextClassifier.load_from_checkpoint(checkpoint_url)
model.serve()
# limitations under the License.
from flash.text import TextClassificationData, TextClassifier

# Step 1: download the IMDB archive; the CSVs read below live under `data/imdb/`.
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")

# Step 2: build the datamodule from the train/valid/test CSV files.
imdb_files = {
    "train_file": "data/imdb/train.csv",
    "valid_file": "data/imdb/valid.csv",
    "test_file": "data/imdb/test.csv",
}
datamodule = TextClassificationData.from_files(
    **imdb_files, input="review", target="sentiment", batch_size=512
)

# Step 3: build the classifier with the class count reported by the datamodule.
model = TextClassifier(num_classes=datamodule.num_classes)

# Step 4: create a trainer limited to a single epoch.
trainer = flash.Trainer(max_epochs=1)

# Step 5: fine-tune with the "freeze" strategy.
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# Step 6: run the test loop.
trainer.test()

# Step 7: write the checkpoint to disk.
trainer.save_checkpoint("text_classification_model.pt")
# Step 2: build the datamodule from the toxic-comment CSV files.
toxicity_columns = [
    "toxic",
    "severe_toxic",
    "obscene",
    "threat",
    "insult",
    "identity_hate",
]
datamodule = TextClassificationData.from_csv(
    "comment_text",
    toxicity_columns,
    train_file="data/jigsaw_toxic_comments/train.csv",
    test_file="data/jigsaw_toxic_comments/test.csv",
    predict_file="data/jigsaw_toxic_comments/predict.csv",
    batch_size=16,
    val_split=0.1,
    backbone="unitary/toxic-bert",
)

# Step 3: multi-label classifier on the same backbone, scored with F1.
f1_metric = F1(num_classes=datamodule.num_classes)
model = TextClassifier(
    num_classes=datamodule.num_classes,
    multi_label=True,
    metrics=f1_metric,
    backbone="unitary/toxic-bert",
)

# Step 4: trainer in fast-dev-run mode.
trainer = flash.Trainer(fast_dev_run=True)

# Step 5: fine-tune with the "freeze" strategy.
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# Step 6: predict on a few raw comments.
sample_comments = [
    "No, he is an arrogant, self serving, immature idiot. Get it right.",
    "U SUCK HANNAH MONTANA",
    "Would you care to vote? Thx.",
]
predictions = model.predict(sample_comments)
def test_load_from_checkpoint_dependency_error():
    """Expect `load_from_checkpoint` to raise `ModuleNotFoundError` naming the text extra."""
    # Escaped because the brackets in the extra's name are regex metacharacters.
    expected_message = re.escape("'lightning-flash[text]'")

    with pytest.raises(ModuleNotFoundError, match=expected_message):
        TextClassifier.load_from_checkpoint("not_a_real_checkpoint.pt")
def test_init_train(tmpdir):
    """Fit a `TextClassifier` for one fast-dev run on a `DataLoader` over `DummyDataset`."""
    classifier = TextClassifier(2, TEST_BACKBONE)
    loader = torch.utils.data.DataLoader(DummyDataset())

    runner = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    runner.fit(classifier, loader)
# Download the toxic-comment archive; the CSV read below lives under `data/`.
download_data(
    "https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "./data"
)

# One text column in, six label columns out.
toxicity_columns = [
    "toxic",
    "severe_toxic",
    "obscene",
    "threat",
    "insult",
    "identity_hate",
]
datamodule = TextClassificationData.from_csv(
    "comment_text",
    toxicity_columns,
    train_file="data/jigsaw_toxic_comments/train.csv",
    val_split=0.1,
    batch_size=4,
)

# 2. Build the task, taking labels and the multi-label flag from the datamodule
model = TextClassifier(
    backbone="unitary/toxic-bert",
    labels=datamodule.labels,
    multi_label=datamodule.multi_label,
)

# 3. Create the trainer (one epoch, GPU count as reported by torch) and finetune
gpu_count = torch.cuda.device_count()
trainer = flash.Trainer(max_epochs=1, gpus=gpu_count)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Generate predictions for a few comments!
datamodule = TextClassificationData.from_lists(
    predict_data=[
        "No, he is an arrogant, self serving, immature idiot. Get it right.",
        "U SUCK HANNAH MONTANA",
        "Would you care to vote? Thx.",
    ],
    batch_size=4,
Example #13
0
def test_serve():
    """Put a fresh `TextClassifier` in eval mode and call `serve()` on it."""
    classifier = TextClassifier(2, backbone=TEST_BACKBONE)
    classifier.eval()
    classifier.serve()
Example #14
0
import flash
from flash.core.data import download_data
from flash.text import TextClassificationData, TextClassifier
# 1. Download the data (the CSV files read below live under `data/imdb/`)
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")

# 2. Load the data: `review` is the input column, `sentiment` the target
datamodule = TextClassificationData.from_files(
    train_file="data/imdb/train.csv",
    valid_file="data/imdb/valid.csv",
    test_file="data/imdb/test.csv",
    input="review",
    target="sentiment",
    batch_size=512)

# 3. Build the model with the class count reported by the datamodule
model = TextClassifier(num_classes=datamodule.num_classes)

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Test model
trainer.test()

# 7. Predict on new data
# Each sentence must be its own list element. Without the comma after the
# second string, Python joins adjacent literals and the last two sentences
# reach the model as one input.
predictions = model.predict([
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
    "This guy has done a great job with this movie!",
])
print(predictions)

# 8. Save the model!
trainer.save_checkpoint("text_classification_model.pt")
# Step 1: download the Label Studio export and build the training datamodule.
download_data(
    "https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/text_data.zip",
    "./data/",
)

# The same backbone name is given to both the datamodule and the task.
backbone = "prajjwal1/bert-medium"

datamodule = TextClassificationData.from_labelstudio(
    export_json="data/project.json", val_split=0.2, backbone=backbone
)

# Step 2: build the task with the class count reported by the datamodule.
model = TextClassifier(backbone=backbone, num_classes=datamodule.num_classes)

# Step 3: train for up to three epochs with the "freeze" strategy.
trainer = flash.Trainer(max_epochs=3)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# Step 4: classify a few sentences. `datamodule` is rebound to the prediction data.
review_samples = [
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
]
datamodule = TextClassificationData.from_lists(predict_data=review_samples)
predictions = trainer.predict(model, datamodule=datamodule)

# Step 5: write the checkpoint to disk.
trainer.save_checkpoint("text_classification_model.pt")
Example #16
0
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# Step 1: download the IMDB archive and build the datamodule from its CSVs.
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")

# Columns are passed positionally: the input field first, then the target.
datamodule = TextClassificationData.from_csv(
    "review",
    "sentiment",
    train_file="data/imdb/train.csv",
    val_file="data/imdb/valid.csv",
    batch_size=4,
)

# Step 2: build the task, taking the label set from the datamodule.
model = TextClassifier(
    backbone="prajjwal1/bert-medium", labels=datamodule.labels
)

# Step 3: train for up to three epochs on the GPU count reported by torch.
gpu_count = torch.cuda.device_count()
trainer = flash.Trainer(max_epochs=3, gpus=gpu_count)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# Step 4: classify a few sentences. `datamodule` is rebound to the prediction data.
review_samples = [
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
]
datamodule = TextClassificationData.from_lists(
    predict_data=review_samples, batch_size=4
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")