# Example #1
0
from pytorch_lightning import Trainer

from flash.core.data import download_data
from flash.text import TextClassificationData, TextClassifier

if __name__ == "__main__":
    # Script entry point: fetch the IMDB archive, restore a text classifier
    # from a published checkpoint, then predict on in-memory sentences and on
    # a CSV file.

    # 1. Download the data
    download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")

    # 2. Load the model from a checkpoint
    model = TextClassifier.load_from_checkpoint(
        "https://flash-weights.s3.amazonaws.com/text_classification_model.pt"
    )

    # 2a. Classify a few sentences! How was the movie?
    # Every entry must end with a comma: adjacent string literals with no
    # comma between them are silently concatenated into a single element.
    sentences = [
        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
        "The worst movie in the history of cinema.",
        "I come from Bulgaria where it 's almost impossible to have a tornado.",
        "Very, very afraid",
        "This guy has done a great job with this movie!",
    ]
    predictions = model.predict(sentences)
    print(predictions)

    # 2b. Or generate predictions from a sheet file!
    # NOTE(review): `input="review"` presumably names the CSV column that
    # holds the text to classify -- confirm against the Flash version in use.
    datamodule = TextClassificationData.from_file(
        predict_file="data/imdb/predict.csv",
        input="review",
    )
    predictions = Trainer().predict(model, datamodule=datamodule)
    print(predictions)
# `Labels` is instantiated below, so it has to be imported here.
# NOTE(review): module path taken from the Lightning Flash API that exposes
# `model.serializer` -- confirm against the installed Flash version.
from flash.core.classification import Labels
from flash.text import TextClassificationData, TextClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")

# 2. Load the model from a checkpoint
model = TextClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/text_classification_model.pt"
)

# Replace the model's serializer with a `Labels` instance.
# NOTE(review): presumably this makes predictions come back as class labels
# rather than raw outputs -- verify in the Flash documentation.
model.serializer = Labels()

# 2a. Classify a few sentences! How was the movie?
sentences = [
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
    "Very, very afraid.",
    "This guy has done a great job with this movie!",
]
predictions = model.predict(sentences)
print(predictions)

# 2b. Or generate predictions from a sheet file!
datamodule = TextClassificationData.from_file(
    predict_file="data/imdb/predict.csv",
    input="review",
    # use the same data pre-processing values we used to predict in 2a
    preprocess=model.preprocess,
)
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)