predict_file="data/jigsaw_toxic_comments/predict.csv", batch_size=16, val_split=0.1, backbone="unitary/toxic-bert", ) # 3. Build the model model = TextClassifier( num_classes=datamodule.num_classes, multi_label=True, metrics=F1(num_classes=datamodule.num_classes), backbone="unitary/toxic-bert", ) # 4. Create the trainer trainer = flash.Trainer(fast_dev_run=True) # 5. Fine-tune the model trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 6. Generate predictions for a few comments! predictions = model.predict([ "No, he is an arrogant, self serving, immature idiot. Get it right.", "U SUCK HANNAH MONTANA", "Would you care to vote? Thx.", ]) print(predictions) # 7. Save it! trainer.save_checkpoint("text_classification_multi_label_model.pt")
import flash
from flash.core.data import download_data
from flash.text import TextClassificationData, TextClassifier

# Example script: fine-tune a Flash text classifier on the IMDB review CSVs,
# evaluate it, predict on a few hand-written reviews, and save a checkpoint.

# 1. Download the data (the archive is fetched into the local "data/" folder)
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")

# 2. Load the data
# "review" is the text column and "sentiment" the label column of the CSVs.
datamodule = TextClassificationData.from_files(
    train_file="data/imdb/train.csv",
    valid_file="data/imdb/valid.csv",
    test_file="data/imdb/test.csv",
    input="review",
    target="sentiment",
    batch_size=512,
)

# 3. Build the model
model = TextClassifier(num_classes=datamodule.num_classes)

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Test model
trainer.test()

# 7. Predict on new data
# Every review must be its own list element: adjacent string literals with no
# comma between them are concatenated by Python into a single string, which
# previously fused the last two reviews into one input.
predictions = model.predict([
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
    "This guy has done a great job with this movie!",
])
print(predictions)

# 8. Save the model!
trainer.save_checkpoint("text_classification_model.pt")