"""Predict the sentiment of IMDB movie reviews with a pretrained Flash ``TextClassifier``.

Downloads the IMDB sample archive into ``data/``, loads a published checkpoint,
then classifies a few in-line sentences and the ``review`` column of
``data/imdb/predict.csv``, printing both sets of predictions.
"""
from pytorch_lightning import Trainer

from flash.core.data import download_data
from flash.text import TextClassificationData, TextClassifier

if __name__ == "__main__":
    # 1. Download the data
    download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")

    # 2. Load the model from a checkpoint
    model = TextClassifier.load_from_checkpoint(
        "https://flash-weights.s3.amazonaws.com/text_classification_model.pt"
    )

    # 2a. Classify a few sentences! How was the movie?
    # Every entry needs its trailing comma: Python implicitly concatenates
    # adjacent string literals, which would silently merge several reviews
    # into a single prediction input.
    predictions = model.predict([
        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
        "The worst movie in the history of cinema.",
        "I come from Bulgaria where it 's almost impossible to have a tornado.",
        "Very, very afraid",
        "This guy has done a great job with this movie!",
    ])
    print(predictions)

    # 2b. Or generate predictions from a CSV file, reading text from its
    # "review" column.
    datamodule = TextClassificationData.from_file(
        predict_file="data/imdb/predict.csv",
        input="review",
    )
    predictions = Trainer().predict(model, datamodule=datamodule)
    print(predictions)
# Predict the sentiment of IMDB movie reviews with a pretrained Flash
# TextClassifier: download the sample data, load a published checkpoint, then
# classify a few in-line sentences and the "review" column of
# data/imdb/predict.csv, printing both sets of predictions.

# Trainer, download_data and Labels are all used below, so import them here
# explicitly; otherwise the script fails with NameError.
# NOTE(review): the download_data path matches the other script in this file;
# confirm it against the installed Flash version.
from pytorch_lightning import Trainer

from flash.core.classification import Labels
from flash.core.data import download_data
from flash.text import TextClassificationData, TextClassifier

# Only do the download / prediction work when run as a script, not on import.
if __name__ == "__main__":
    # 1. Download the data
    download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")

    # 2. Load the model from a checkpoint
    model = TextClassifier.load_from_checkpoint(
        "https://flash-weights.s3.amazonaws.com/text_classification_model.pt"
    )
    # Install a Labels serializer on the model; presumably this makes
    # predictions come back as label names rather than raw outputs
    # (NOTE(review): confirm against the Flash serializer docs).
    model.serializer = Labels()

    # 2a. Classify a few sentences! How was the movie?
    predictions = model.predict([
        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
        "The worst movie in the history of cinema.",
        "I come from Bulgaria where it 's almost impossible to have a tornado.",
        "Very, very afraid.",
        "This guy has done a great job with this movie!",
    ])
    print(predictions)

    # 2b. Or generate predictions from a CSV file!
    datamodule = TextClassificationData.from_file(
        predict_file="data/imdb/predict.csv",
        input="review",
        # use the same data pre-processing values we used to predict in 2a
        preprocess=model.preprocess,
    )
    predictions = Trainer().predict(model, datamodule=datamodule)
    print(predictions)