Ejemplo n.º 1
0
def test_text_module_not_found_error():
    """Expect ``from_json`` to raise ``ModuleNotFoundError`` mentioning ``[text]``.

    The call uses an empty ``train_file``; the test only cares that the error
    is raised and that its message contains the literal text ``[text]``.
    """
    # ``match`` is applied with ``re.search``. An unescaped "[text]" is a
    # character class that matches any single "t", "e" or "x", which would
    # accept almost any message, so the brackets are escaped.
    with pytest.raises(ModuleNotFoundError, match=r"\[text\]"):
        TextClassificationData.from_json(
            "sentence",
            "lab",
            backbone=TEST_BACKBONE,
            train_file="",
            batch_size=1,
        )
Ejemplo n.º 2
0
def test_from_parquet_multilabel(tmpdir):
    """Build a multi-label datamodule from one parquet file reused for every split.

    The train, val and test loaders must each yield a first target made only
    of 0/1 entries and a string input; the predict loader is only checked for
    a string input.
    """
    source = parquet_data(tmpdir, True)
    datamodule = TextClassificationData.from_parquet(
        "sentence",
        ["lab1", "lab2"],
        train_file=source,
        val_file=source,
        test_file=source,
        predict_file=source,
        batch_size=1,
    )

    assert datamodule.multi_label

    # Same checks for each labelled split, in train -> val -> test order.
    for split in ("train", "val", "test"):
        sample = next(iter(getattr(datamodule, split + "_dataloader")()))
        assert all([value in [0, 1] for value in sample[DataKeys.TARGET][0]])
        assert isinstance(sample[DataKeys.INPUT][0], str)

    sample = next(iter(datamodule.predict_dataloader()))
    assert isinstance(sample[DataKeys.INPUT][0], str)
Ejemplo n.º 3
0
def test_from_parquet(tmpdir):
    """Build a single-label datamodule from one parquet file reused for every split.

    The train, val and test loaders must each yield a target of 0 or 1 and a
    string input; the predict loader is only checked for a string input.
    """
    source = parquet_data(tmpdir, False)
    datamodule = TextClassificationData.from_parquet(
        "sentence",
        "lab1",
        train_file=source,
        val_file=source,
        test_file=source,
        predict_file=source,
        batch_size=1,
    )

    # Same checks for each labelled split, in train -> val -> test order.
    for split in ("train", "val", "test"):
        sample = next(iter(getattr(datamodule, split + "_dataloader")()))
        assert sample[DataKeys.TARGET].item() in [0, 1]
        assert isinstance(sample[DataKeys.INPUT][0], str)

    sample = next(iter(datamodule.predict_dataloader()))
    assert isinstance(sample[DataKeys.INPUT][0], str)
Ejemplo n.º 4
0
def test_from_json_with_field_multilabel(tmpdir):
    """Build a multi-label datamodule from a JSON file read through ``field="data"``.

    The same file backs every split. Labelled splits must yield a first target
    made only of 0/1 entries and a string input; the predict split is only
    checked for a string input.
    """
    source = json_data_with_field(tmpdir, multilabel=True)
    datamodule = TextClassificationData.from_json(
        "sentence",
        ["lab1", "lab2"],
        train_file=source,
        val_file=source,
        test_file=source,
        predict_file=source,
        batch_size=1,
        field="data",
    )

    assert datamodule.multi_label

    # Same checks for each labelled split, in train -> val -> test order.
    for split in ("train", "val", "test"):
        sample = next(iter(getattr(datamodule, split + "_dataloader")()))
        assert all([value in [0, 1] for value in sample[DataKeys.TARGET][0]])
        assert isinstance(sample[DataKeys.INPUT][0], str)

    sample = next(iter(datamodule.predict_dataloader()))
    assert isinstance(sample[DataKeys.INPUT][0], str)
Ejemplo n.º 5
0
def test_from_json_with_field(tmpdir):
    """Build a single-label datamodule from a JSON file read through ``field="data"``.

    The same file backs every split. Labelled splits must yield a target of
    0 or 1 and a string input; the predict split is only checked for a string
    input.
    """
    source = json_data_with_field(tmpdir, multilabel=False)
    datamodule = TextClassificationData.from_json(
        "sentence",
        "lab",
        train_file=source,
        val_file=source,
        test_file=source,
        predict_file=source,
        batch_size=1,
        field="data",
    )

    # Same checks for each labelled split, in train -> val -> test order.
    for split in ("train", "val", "test"):
        sample = next(iter(getattr(datamodule, split + "_dataloader")()))
        assert sample[DataKeys.TARGET].item() in [0, 1]
        assert isinstance(sample[DataKeys.INPUT][0], str)

    sample = next(iter(datamodule.predict_dataloader()))
    assert isinstance(sample[DataKeys.INPUT][0], str)
Ejemplo n.º 6
0
def test_from_hf_datasets():
    """Build a single-label datamodule from a ``Dataset`` made from the test data frame.

    The same dataset backs every split. Labelled splits must yield a target of
    0 or 1 and a string input; the predict split is only checked for a string
    input.
    """
    hf_dataset = Dataset.from_pandas(TEST_DATA_FRAME_DATA)
    datamodule = TextClassificationData.from_hf_datasets(
        "sentence",
        "lab1",
        train_hf_dataset=hf_dataset,
        val_hf_dataset=hf_dataset,
        test_hf_dataset=hf_dataset,
        predict_hf_dataset=hf_dataset,
        batch_size=1,
    )

    # Same checks for each labelled split, in train -> val -> test order.
    for split in ("train", "val", "test"):
        sample = next(iter(getattr(datamodule, split + "_dataloader")()))
        assert sample[DataKeys.TARGET].item() in [0, 1]
        assert isinstance(sample[DataKeys.INPUT][0], str)

    sample = next(iter(datamodule.predict_dataloader()))
    assert isinstance(sample[DataKeys.INPUT][0], str)
Ejemplo n.º 7
0
def test_from_hf_datasets_multilabel():
    """Build a multi-label datamodule from a ``Dataset`` made from the multi-label frame.

    The same dataset backs every split. Labelled splits must yield a first
    target made only of 0/1 entries and a string input; the predict split is
    only checked for a string input.
    """
    hf_dataset = Dataset.from_pandas(TEST_DATA_FRAME_DATA_MULTILABEL)
    datamodule = TextClassificationData.from_hf_datasets(
        "sentence",
        ["lab1", "lab2"],
        train_hf_dataset=hf_dataset,
        val_hf_dataset=hf_dataset,
        test_hf_dataset=hf_dataset,
        predict_hf_dataset=hf_dataset,
        batch_size=1,
    )

    assert datamodule.multi_label

    # Same checks for each labelled split, in train -> val -> test order.
    for split in ("train", "val", "test"):
        sample = next(iter(getattr(datamodule, split + "_dataloader")()))
        assert all([value in [0, 1] for value in sample[DataKeys.TARGET][0]])
        assert isinstance(sample[DataKeys.INPUT][0], str)

    sample = next(iter(datamodule.predict_dataloader()))
    assert isinstance(sample[DataKeys.INPUT][0], str)
Ejemplo n.º 8
0
def test_from_lists():
    """Build a single-label datamodule from in-memory lists reused for every split.

    Labelled splits must yield a target of 0 or 1 and a string input; the
    predict split (data only, no targets) is only checked for a string input.
    """
    datamodule = TextClassificationData.from_lists(
        train_data=TEST_LIST_DATA,
        train_targets=TEST_LIST_TARGETS,
        val_data=TEST_LIST_DATA,
        val_targets=TEST_LIST_TARGETS,
        test_data=TEST_LIST_DATA,
        test_targets=TEST_LIST_TARGETS,
        predict_data=TEST_LIST_DATA,
        batch_size=1,
    )

    # Same checks for each labelled split, in train -> val -> test order.
    for split in ("train", "val", "test"):
        sample = next(iter(getattr(datamodule, split + "_dataloader")()))
        assert sample[DataKeys.TARGET].item() in [0, 1]
        assert isinstance(sample[DataKeys.INPUT][0], str)

    sample = next(iter(datamodule.predict_dataloader()))
    assert isinstance(sample[DataKeys.INPUT][0], str)
Ejemplo n.º 9
0
def test_from_lists_multilabel():
    """Build a multi-label datamodule from in-memory lists reused for every split.

    Labelled splits must yield a first target made only of 0/1 entries and a
    string input; the predict split (data only, no targets) is only checked
    for a string input.
    """
    datamodule = TextClassificationData.from_lists(
        train_data=TEST_LIST_DATA,
        train_targets=TEST_LIST_TARGETS_MULTILABEL,
        val_data=TEST_LIST_DATA,
        val_targets=TEST_LIST_TARGETS_MULTILABEL,
        test_data=TEST_LIST_DATA,
        test_targets=TEST_LIST_TARGETS_MULTILABEL,
        predict_data=TEST_LIST_DATA,
        batch_size=1,
    )

    assert datamodule.multi_label

    # Same checks for each labelled split, in train -> val -> test order.
    for split in ("train", "val", "test"):
        sample = next(iter(getattr(datamodule, split + "_dataloader")()))
        assert all([value in [0, 1] for value in sample[DataKeys.TARGET][0]])
        assert isinstance(sample[DataKeys.INPUT][0], str)

    sample = next(iter(datamodule.predict_dataloader()))
    assert isinstance(sample[DataKeys.INPUT][0], str)
Ejemplo n.º 10
0
def test_from_data_frame():
    """Build a single-label datamodule from one data frame reused for every split.

    Labelled splits must yield a target of 0 or 1 and a string input; the
    predict split is only checked for a string input.
    """
    datamodule = TextClassificationData.from_data_frame(
        "sentence",
        "lab1",
        train_data_frame=TEST_DATA_FRAME_DATA,
        val_data_frame=TEST_DATA_FRAME_DATA,
        test_data_frame=TEST_DATA_FRAME_DATA,
        predict_data_frame=TEST_DATA_FRAME_DATA,
        batch_size=1,
    )

    # Same checks for each labelled split, in train -> val -> test order.
    for split in ("train", "val", "test"):
        sample = next(iter(getattr(datamodule, split + "_dataloader")()))
        assert sample[DataKeys.TARGET].item() in [0, 1]
        assert isinstance(sample[DataKeys.INPUT][0], str)

    sample = next(iter(datamodule.predict_dataloader()))
    assert isinstance(sample[DataKeys.INPUT][0], str)
Ejemplo n.º 11
0
def test_from_json_with_field(tmpdir):
    """Load a JSON file through ``field="data"`` and inspect the first train batch.

    The batch must carry a ``"labels"`` entry equal to 0 or 1 and an
    ``"input_ids"`` key.
    """
    source = json_data_with_field(tmpdir)
    datamodule = TextClassificationData.from_json(
        "sentence",
        "lab",
        backbone=TEST_BACKBONE,
        train_file=source,
        batch_size=1,
        field="data",
    )
    first_batch = next(iter(datamodule.train_dataloader()))
    assert first_batch["labels"].item() in [0, 1]
    assert "input_ids" in first_batch
Ejemplo n.º 12
0
def test_from_csv(tmpdir):
    """Load a CSV file with ``from_files`` and inspect the first train batch.

    The batch must carry a ``"labels"`` entry equal to 0 or 1 and an
    ``"input_ids"`` key.
    """
    source = csv_data(tmpdir)
    datamodule = TextClassificationData.from_files(
        backbone=TEST_BACKBONE,
        train_file=source,
        input="sentence",
        target="label",
        batch_size=1,
    )
    first_batch = next(iter(datamodule.train_dataloader()))
    assert first_batch["labels"].item() in [0, 1]
    assert "input_ids" in first_batch
Ejemplo n.º 13
0
def test_predict(tmpdir):
    """Run ``TextEmbedder`` prediction and check the size of each output.

    The first entry of the predictions must hold exactly three tensors, each
    of size ``[384]``.
    """
    data = TextClassificationData.from_lists(predict_data=predict_data, batch_size=4)
    embedder = TextEmbedder(backbone=TEST_BACKBONE)

    # ``gpus`` is set to the number of CUDA devices torch reports.
    runner = flash.Trainer(gpus=torch.cuda.device_count())
    outputs = runner.predict(embedder, datamodule=data)

    observed_sizes = [tensor.size() for tensor in outputs[0]]
    assert observed_sizes == [torch.Size([384])] * 3
Ejemplo n.º 14
0
def test_from_csv(tmpdir):
    """Load a CSV file with ``from_files`` and inspect the first train batch.

    The batch must carry a ``"labels"`` entry equal to 0 or 1 and an
    ``"input_ids"`` key. On Windows (``os.name == "nt"``) the test is skipped.
    """
    if os.name == "nt":
        # TODO: huggingface stuff timing out on windows
        # Report an explicit skip rather than ``return True``: pytest counts a
        # returning test as passed and warns about the non-None return value.
        import pytest  # local import so the block does not rely on a module-level one

        pytest.skip("huggingface stuff timing out on windows")
    csv_path = csv_data(tmpdir)
    dm = TextClassificationData.from_files(
        backbone=TEST_BACKBONE,
        train_file=csv_path,
        input="sentence",
        target="label",
        batch_size=1,
    )
    batch = next(iter(dm.train_dataloader()))
    assert batch["labels"].item() in [0, 1]
    assert "input_ids" in batch
Ejemplo n.º 15
0
def from_imdb(
    batch_size: int = 4,
    **data_module_kwargs,
) -> TextClassificationData:
    """Download the IMDB archive and build a datamodule from its CSV files.

    Args:
        batch_size: Forwarded to ``TextClassificationData.from_csv``.
        **data_module_kwargs: Extra keyword arguments forwarded unchanged to
            ``TextClassificationData.from_csv``.

    Returns:
        The result of ``TextClassificationData.from_csv`` for the ``review``
        input column and ``sentiment`` target column, using the train and
        valid CSV files.
    """
    archive_url = "https://pl-flash-data.s3.amazonaws.com/imdb.zip"
    download_data(archive_url, "./data/")

    input_field, target_field = "review", "sentiment"
    datamodule = TextClassificationData.from_csv(
        input_field,
        target_field,
        train_file="data/imdb/train.csv",
        val_file="data/imdb/valid.csv",
        batch_size=batch_size,
        **data_module_kwargs,
    )
    return datamodule
Ejemplo n.º 16
0
def test_classification(tmpdir):
    """Smoke-test fitting a two-class ``TextClassifier`` on a small CSV datamodule.

    The trainer runs with ``fast_dev_run=True`` and ``tmpdir`` as its root
    directory; the test passes if ``fit`` completes without raising.
    """
    train_csv = csv_data(tmpdir)

    datamodule = TextClassificationData.from_csv(
        "sentence",
        "label",
        train_file=train_csv,
        num_workers=0,
        batch_size=2,
    )
    classifier = TextClassifier(2, TEST_BACKBONE)
    runner = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    runner.fit(classifier, datamodule=datamodule)
Ejemplo n.º 17
0
def from_toxic(
    val_split: float = 0.1,
    batch_size: int = 4,
    **data_module_kwargs,
) -> TextClassificationData:
    """Download the Jigsaw toxic comments archive and build a datamodule from it.

    Args:
        val_split: Forwarded to ``TextClassificationData.from_csv``.
        batch_size: Forwarded to ``TextClassificationData.from_csv``.
        **data_module_kwargs: Extra keyword arguments forwarded unchanged to
            ``TextClassificationData.from_csv``.

    Returns:
        The result of ``TextClassificationData.from_csv`` for the
        ``comment_text`` input column and the six target columns listed below.
    """
    archive_url = "https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip"
    download_data(archive_url, "./data")

    target_fields = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
    datamodule = TextClassificationData.from_csv(
        "comment_text",
        target_fields,
        train_file="data/jigsaw_toxic_comments/train.csv",
        val_split=val_split,
        batch_size=batch_size,
        **data_module_kwargs,
    )
    return datamodule
Ejemplo n.º 18
0
def test_test_valid(tmpdir):
    """Check the first batch of the val and test loaders built from one CSV file.

    Each batch must carry a ``"labels"`` entry equal to 0 or 1 and an
    ``"input_ids"`` key.
    """
    source = csv_data(tmpdir)
    datamodule = TextClassificationData.from_csv(
        "sentence",
        "label",
        backbone=TEST_BACKBONE,
        train_file=source,
        val_file=source,
        test_file=source,
        batch_size=1,
    )

    # Same checks for both held-out splits, val first and then test.
    for split in ("val", "test"):
        first_batch = next(iter(getattr(datamodule, split + "_dataloader")()))
        assert first_batch["labels"].item() in [0, 1]
        assert "input_ids" in first_batch
Ejemplo n.º 19
0
def from_imdb(
    backbone: str = "prajjwal1/bert-medium",
    batch_size: int = 4,
    num_workers: int = 0,
    **preprocess_kwargs,
) -> TextClassificationData:
    """Download the IMDB archive and build a datamodule from its CSV files.

    Args:
        backbone: Forwarded to ``TextClassificationData.from_csv``.
        batch_size: Forwarded to ``TextClassificationData.from_csv``.
        num_workers: Forwarded to ``TextClassificationData.from_csv``.
        **preprocess_kwargs: Extra keyword arguments forwarded unchanged to
            ``TextClassificationData.from_csv``.

    Returns:
        The result of ``TextClassificationData.from_csv`` for the ``review``
        input column and ``sentiment`` target column, using the train and
        valid CSV files.
    """
    archive_url = "https://pl-flash-data.s3.amazonaws.com/imdb.zip"
    download_data(archive_url, "./data/")

    input_field, target_field = "review", "sentiment"
    datamodule = TextClassificationData.from_csv(
        input_field,
        target_field,
        train_file="data/imdb/train.csv",
        val_file="data/imdb/valid.csv",
        backbone=backbone,
        batch_size=batch_size,
        num_workers=num_workers,
        **preprocess_kwargs,
    )
    return datamodule
Ejemplo n.º 20
0
def test_classification(tmpdir):
    """Smoke-test fitting a two-class ``TextClassifier`` on a small CSV datamodule.

    The trainer runs with ``fast_dev_run=True`` and ``tmpdir`` as its root
    directory; the test passes if ``fit`` completes without raising. On
    Windows (``os.name == "nt"``) the test is skipped.
    """
    if os.name == "nt":
        # TODO: huggingface stuff timing out on windows
        # Report an explicit skip rather than ``return True``: pytest counts a
        # returning test as passed and warns about the non-None return value.
        import pytest  # local import so the block does not rely on a module-level one

        pytest.skip("huggingface stuff timing out on windows")

    csv_path = csv_data(tmpdir)

    data = TextClassificationData.from_files(
        backbone=TEST_BACKBONE,
        train_file=csv_path,
        input="sentence",
        target="label",
        num_workers=0,
        batch_size=2,
    )
    model = TextClassifier(2, TEST_BACKBONE)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model, datamodule=data)
Ejemplo n.º 21
0
def from_toxic(
    backbone: str = "unitary/toxic-bert",
    val_split: float = 0.1,
    batch_size: int = 4,
    num_workers: int = 0,
    **preprocess_kwargs,
) -> TextClassificationData:
    """Download the Jigsaw toxic comments archive and build a datamodule from it.

    Args:
        backbone: Forwarded to ``TextClassificationData.from_csv``.
        val_split: Forwarded to ``TextClassificationData.from_csv``.
        batch_size: Forwarded to ``TextClassificationData.from_csv``.
        num_workers: Forwarded to ``TextClassificationData.from_csv``.
        **preprocess_kwargs: Extra keyword arguments forwarded unchanged to
            ``TextClassificationData.from_csv``.

    Returns:
        The result of ``TextClassificationData.from_csv`` for the
        ``comment_text`` input column and the six target columns listed below.
    """
    archive_url = "https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip"
    download_data(archive_url, "./data")

    target_fields = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
    datamodule = TextClassificationData.from_csv(
        "comment_text",
        target_fields,
        train_file="data/jigsaw_toxic_comments/train.csv",
        backbone=backbone,
        val_split=val_split,
        batch_size=batch_size,
        num_workers=num_workers,
        **preprocess_kwargs,
    )
    return datamodule
Ejemplo n.º 22
0
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.text import TextClassificationData, TextEmbedder

# 1. Create the DataModule from three in-memory sentences.
# Only predict data is supplied; no train/val/test splits are passed.
datamodule = TextClassificationData.from_lists(
    predict_data=[
        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
        "The worst movie in the history of cinema.",
        "I come from Bulgaria where it 's almost impossible to have a tornado.",
    ],
    batch_size=4,
)

# 2. Build a TextEmbedder on the sentence-transformers backbone named below
model = TextEmbedder(backbone="sentence-transformers/all-MiniLM-L6-v2")

# 3. Generate embeddings for the 3 sentences above and print them.
# `gpus` is set to the number of CUDA devices torch reports.
trainer = flash.Trainer(gpus=torch.cuda.device_count())
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.text import TextClassificationData, TextClassifier

# 1. Download the IMDB archive; the CSV paths below are read from data/imdb/
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')

# 2. Load the data: "review" is the input column, "sentiment" the target column
datamodule = TextClassificationData.from_files(
    train_file="data/imdb/train.csv",
    valid_file="data/imdb/valid.csv",
    test_file="data/imdb/test.csv",
    input="review",
    target="sentiment",
    batch_size=512
)

# 3. Build the model, taking the number of classes from the datamodule
model = TextClassifier(num_classes=datamodule.num_classes)

# 4. Create the trainer, limited to a single epoch
trainer = flash.Trainer(max_epochs=1)

# 5. Fine-tune the model with the 'freeze' strategy
trainer.finetune(model, datamodule=datamodule, strategy='freeze')

# 6. Test model
# NOTE(review): no model or datamodule is passed here - confirm the trainer
# reuses the ones given to finetune() above.
trainer.test()
Ejemplo n.º 24
0
from pytorch_lightning import Trainer

from flash.core.data import download_data
from flash.text import TextClassificationData, TextClassifier

if __name__ == "__main__":

    # 1. Download the IMDB archive; the predict CSV below is read from data/imdb/
    download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')

    # 2. Load the model from a checkpoint
    model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")

    # 2a. Classify a few sentences! How was the movie?
    # Each sentence must end with a comma: Python concatenates adjacent string
    # literals, so a missing comma silently merges sentences into one input.
    predictions = model.predict([
        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
        "The worst movie in the history of cinema.",
        "I come from Bulgaria where it 's almost impossible to have a tornado.",
        "Very, very afraid",
        "This guy has done a great job with this movie!",
    ])
    print(predictions)

    # 2b. Or generate predictions from a CSV file, using "review" as the input column
    datamodule = TextClassificationData.from_file(
        predict_file="data/imdb/predict.csv",
        input="review",
    )
    predictions = Trainer().predict(model, datamodule=datamodule)
    print(predictions)
Ejemplo n.º 25
0
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Download the data from the Kaggle Toxic Comment Classification Challenge:
# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
# The predict CSV below is read from data/jigsaw_toxic_comments/.
download_data(
    "https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip",
    "data/")

# 2. Load the multi-label model from a checkpoint
model = TextClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/text_classification_multi_label_model.pt"
)

# 2a. Classify a few comments passed directly as strings
predictions = model.predict([
    "No, he is an arrogant, self serving, immature idiot. Get it right.",
    "U SUCK HANNAH MONTANA",
    "Would you care to vote? Thx.",
])
print(predictions)

# 2b. Or generate predictions from a whole file, using "comment_text" as the input column
datamodule = TextClassificationData.from_csv(
    "comment_text",
    predict_file="data/jigsaw_toxic_comments/predict.csv",
)
# NOTE(review): `Trainer` is not imported in the visible part of this script -
# confirm it is imported above.
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
Ejemplo n.º 26
0
from flash.core.classification import Labels
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Download the IMDB archive; the predict CSV below is read from data/imdb/
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")

# 2. Load the model from a checkpoint
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")

# Attach a `Labels` serializer to the model.
# NOTE(review): presumably this makes predictions come back as class labels - confirm.
model.serializer = Labels()

# 2a. Classify a few sentences! How was the movie?
predictions = model.predict([
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
    "Very, very afraid.",
    "This guy has done a great job with this movie!",
])
print(predictions)

# 2b. Or generate predictions from a CSV file, using "review" as the input column
datamodule = TextClassificationData.from_csv(
    "review",
    predict_file="data/imdb/predict.csv",
)
# NOTE(review): `Trainer` is not imported in the visible part of this script -
# confirm it is imported above.
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Download the data from the Kaggle Toxic Comment Classification Challenge:
# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
# The CSV paths below are read from data/jigsaw_toxic_comments/.
download_data(
    "https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip",
    "data/")

# 2. Load the data: "comment_text" is the input column and the six listed
# columns are the targets. No val file is given; `val_split=0.1` is passed instead.
datamodule = TextClassificationData.from_csv(
    "comment_text",
    ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"],
    train_file="data/jigsaw_toxic_comments/train.csv",
    test_file="data/jigsaw_toxic_comments/test.csv",
    predict_file="data/jigsaw_toxic_comments/predict.csv",
    batch_size=16,
    val_split=0.1,
    backbone="unitary/toxic-bert",
)

# 3. Build the multi-label model on the same backbone as the datamodule.
# NOTE(review): `F1` is not imported in the visible part of this script -
# confirm it is imported above.
model = TextClassifier(
    num_classes=datamodule.num_classes,
    multi_label=True,
    metrics=F1(num_classes=datamodule.num_classes),
    backbone="unitary/toxic-bert",
)

# 4. Create the trainer with `fast_dev_run=True`
trainer = flash.Trainer(fast_dev_run=True)
import flash
from flash.core.data.utils import download_data
from flash.core.integrations.labelstudio.visualizer import launch_app
from flash.text import TextClassificationData, TextClassifier

# 1. Create the DataModule
# Download the Label Studio test archive; the export JSON below is read from data/.
download_data(
    "https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/text_data.zip",
    "./data/")

# The same backbone name is passed to both the datamodule and the model.
backbone = "prajjwal1/bert-medium"

datamodule = TextClassificationData.from_labelstudio(
    export_json="data/project.json",
    val_split=0.2,
    backbone=backbone,
)

# 2. Build the task, taking the number of classes from the datamodule
model = TextClassifier(backbone=backbone, num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model (3 epochs, "freeze" strategy)
trainer = flash.Trainer(max_epochs=3)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Classify a few sentences! How was the movie?
# Rebinds `datamodule` to a predict-only datamodule built from three sentences.
datamodule = TextClassificationData.from_lists(predict_data=[
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
])
Ejemplo n.º 29
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flash
from flash.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Download the IMDB archive; the CSV paths below are read from data/imdb/
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")

# 2. Load the data: "review" is the input field, "sentiment" the target field
datamodule = TextClassificationData.from_csv(
    train_file="data/imdb/train.csv",
    val_file="data/imdb/valid.csv",
    test_file="data/imdb/test.csv",
    input_fields="review",
    target_fields="sentiment",
    batch_size=16,
)

# 3. Build the model, taking the number of classes from the datamodule
model = TextClassifier(num_classes=datamodule.num_classes)

# 4. Create the trainer with `fast_dev_run=True`
trainer = flash.Trainer(fast_dev_run=True)

# 5. Fine-tune the model with the "freeze" strategy
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Test model
# NOTE(review): no datamodule is passed here - confirm the trainer reuses the
# one given to finetune() above.
trainer.test(model)
import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Create the DataModule
# Data from the Kaggle Toxic Comment Classification Challenge:
# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
# The train CSV below is read from data/jigsaw_toxic_comments/.
download_data(
    "https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip",
    "./data")

# "comment_text" is the input column and the six listed columns are the targets.
# No val file is given; `val_split=0.1` is passed instead.
datamodule = TextClassificationData.from_csv(
    "comment_text",
    ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"],
    train_file="data/jigsaw_toxic_comments/train.csv",
    val_split=0.1,
    batch_size=4,
)

# 2. Build the task, taking the labels and the multi-label flag from the datamodule
model = TextClassifier(
    backbone="unitary/toxic-bert",
    labels=datamodule.labels,
    multi_label=datamodule.multi_label,
)

# 3. Create the trainer and finetune the model (1 epoch, "freeze" strategy).
# `gpus` is set to the number of CUDA devices torch reports.
# NOTE(review): `torch` is not imported in the visible part of this script -
# confirm it is imported above.
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")