predict_file="data/jigsaw_toxic_comments/predict.csv",
    batch_size=16,
    val_split=0.1,
    backbone="unitary/toxic-bert",
)

# 3. Build the model
# Read the class count once and share it between the classifier head and the
# F1 metric so both are configured from the same value.
num_classes = datamodule.num_classes
f1_metric = F1(num_classes=num_classes)
model = TextClassifier(
    num_classes=num_classes,
    multi_label=True,
    metrics=f1_metric,
    backbone="unitary/toxic-bert",
)

# 4. Create the trainer
# NOTE(review): fast_dev_run=True is presumably the Lightning smoke-test mode
# (a short sanity pass rather than full training) - confirm in the Trainer docs.
trainer = flash.Trainer(fast_dev_run=True)

# 5. Fine-tune the model
# NOTE(review): strategy="freeze" presumably keeps the backbone weights fixed
# while fine-tuning - confirm against the Flash finetuning docs.
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Generate predictions for a few comments!
sample_comments = [
    "No, he is an arrogant, self serving, immature idiot. Get it right.",
    "U SUCK HANNAH MONTANA",
    "Would you care to vote? Thx.",
]
predictions = model.predict(sample_comments)
print(predictions)

# 7. Save it!
trainer.save_checkpoint("text_classification_multi_label_model.pt")
Exemplo n.º 2
0
import flash
from flash.core.data import download_data
from flash.text import TextClassificationData, TextClassifier
# 1. Download the data
# Fetches the IMDB archive from the given URL into the local "data/" directory.
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")

# 2. Load the data
# The CSV column "review" is used as the model input and "sentiment" as the
# target label.
datamodule = TextClassificationData.from_files(
    train_file="data/imdb/train.csv",
    valid_file="data/imdb/valid.csv",
    test_file="data/imdb/test.csv",
    input="review",
    target="sentiment",
    batch_size=512,
)

# 3. Build the model
model = TextClassifier(num_classes=datamodule.num_classes)

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Test model
trainer.test()

# 7. Predict on new data
# Each list element is one review. Every entry needs its own trailing comma:
# adjacent string literals without a comma are silently concatenated by
# Python, which would merge two reviews into a single input.
predictions = model.predict([
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
    "This guy has done a great job with this movie!",
])
print(predictions)

# 8. Save the model!
trainer.save_checkpoint("text_classification_model.pt")