Пример #1
0
def test_from_csv_multilabel(tmpdir):
    """Check ``from_csv`` with a list of target columns.

    The same CSV file backs every split. The data module must report
    ``multi_label``; train/val/test batches must carry 0/1 labels and string
    inputs, while the predict batch is only checked for string inputs.
    """
    csv_path = csv_data(tmpdir, multilabel=True)
    split_files = dict(
        train_file=csv_path,
        val_file=csv_path,
        test_file=csv_path,
        predict_file=csv_path,
    )
    dm = TextClassificationData.from_csv("sentence", ["lab1", "lab2"], **split_files, batch_size=1)

    assert dm.multi_label

    # Splits that come with targets share the exact same checks.
    for make_loader in (dm.train_dataloader, dm.val_dataloader, dm.test_dataloader):
        batch = next(iter(make_loader()))
        first_sample_labels = batch[DataKeys.TARGET][0]
        assert all(label in [0, 1] for label in first_sample_labels)
        assert isinstance(batch[DataKeys.INPUT][0], str)

    # The predict batch is only checked for its input.
    predict_batch = next(iter(dm.predict_dataloader()))
    assert isinstance(predict_batch[DataKeys.INPUT][0], str)
Пример #2
0
def test_from_csv(tmpdir):
    """Check ``from_csv`` with a single target column.

    The same CSV file backs every split. Train/val/test batches (batch size 1)
    must carry a scalar target equal to 0 or 1 and a string input; the predict
    batch is only checked for a string input.
    """
    csv_path = csv_data(tmpdir, multilabel=False)
    split_files = dict(
        train_file=csv_path,
        val_file=csv_path,
        test_file=csv_path,
        predict_file=csv_path,
    )
    dm = TextClassificationData.from_csv("sentence", "label", **split_files, batch_size=1)

    # Splits that come with targets share the exact same checks.
    for make_loader in (dm.train_dataloader, dm.val_dataloader, dm.test_dataloader):
        batch = next(iter(make_loader()))
        target_value = batch[DataKeys.TARGET].item()
        assert target_value in [0, 1]
        assert isinstance(batch[DataKeys.INPUT][0], str)

    # The predict batch is only checked for its input.
    predict_batch = next(iter(dm.predict_dataloader()))
    assert isinstance(predict_batch[DataKeys.INPUT][0], str)
Пример #3
0
def test_from_csv(tmpdir):
    """Check that a train-only ``from_csv`` data module yields tokenised batches.

    The first training batch (batch size 1) must hold a 0/1 entry under
    ``"labels"`` and contain an ``"input_ids"`` key.
    """
    train_csv = csv_data(tmpdir)
    dm = TextClassificationData.from_csv(
        "sentence",
        "label",
        backbone=TEST_BACKBONE,
        train_file=train_csv,
        batch_size=1,
    )

    train_loader = dm.train_dataloader()
    batch = next(iter(train_loader))

    label_value = batch["labels"].item()
    assert label_value in [0, 1]
    assert "input_ids" in batch
Пример #4
0
def from_imdb(
    batch_size: int = 4,
    **data_module_kwargs,
) -> TextClassificationData:
    """Download the IMDB sentiment archive and build a data module from its CSVs.

    Args:
        batch_size: Passed through to ``TextClassificationData.from_csv``.
        **data_module_kwargs: Extra keyword arguments forwarded unchanged to
            ``TextClassificationData.from_csv``.

    Returns:
        The data module created from the IMDB train and validation CSV files,
        using ``"review"`` as input column and ``"sentiment"`` as target column.
    """
    archive_url = "https://pl-flash-data.s3.amazonaws.com/imdb.zip"
    download_data(archive_url, "./data/")

    input_column = "review"
    target_column = "sentiment"
    datamodule = TextClassificationData.from_csv(
        input_column,
        target_column,
        train_file="data/imdb/train.csv",
        val_file="data/imdb/valid.csv",
        batch_size=batch_size,
        **data_module_kwargs,
    )
    return datamodule
Пример #5
0
def test_classification(tmpdir):
    """Smoke-test fitting a ``TextClassifier`` on a CSV-backed data module.

    Runs the trainer with ``fast_dev_run=True``; the test passes as long as
    ``fit`` completes without raising.
    """
    train_csv = csv_data(tmpdir)
    datamodule = TextClassificationData.from_csv(
        "sentence",
        "label",
        train_file=train_csv,
        num_workers=0,
        batch_size=2,
    )

    classifier = TextClassifier(2, TEST_BACKBONE)

    Trainer(default_root_dir=tmpdir, fast_dev_run=True).fit(classifier, datamodule=datamodule)
Пример #6
0
def from_toxic(
    val_split: float = 0.1,
    batch_size: int = 4,
    **data_module_kwargs,
) -> TextClassificationData:
    """Download the Jigsaw toxic comments archive and build a data module from it.

    Args:
        val_split: Passed through to ``TextClassificationData.from_csv``.
        batch_size: Passed through to ``TextClassificationData.from_csv``.
        **data_module_kwargs: Extra keyword arguments forwarded unchanged to
            ``TextClassificationData.from_csv``.

    Returns:
        The data module created from the training CSV, using ``"comment_text"``
        as input column and the six toxicity columns as targets.
    """
    archive_url = "https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip"
    download_data(archive_url, "./data")

    target_columns = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
    datamodule = TextClassificationData.from_csv(
        "comment_text",
        target_columns,
        train_file="data/jigsaw_toxic_comments/train.csv",
        val_split=val_split,
        batch_size=batch_size,
        **data_module_kwargs,
    )
    return datamodule
Пример #7
0
def from_imdb(
    backbone: str = "prajjwal1/bert-medium",
    batch_size: int = 4,
    num_workers: int = 0,
    **preprocess_kwargs,
) -> TextClassificationData:
    """Download the IMDB sentiment archive and build a data module from its CSVs.

    Args:
        backbone: Passed through to ``TextClassificationData.from_csv``.
        batch_size: Passed through to ``TextClassificationData.from_csv``.
        num_workers: Passed through to ``TextClassificationData.from_csv``.
        **preprocess_kwargs: Extra keyword arguments forwarded unchanged to
            ``TextClassificationData.from_csv``.

    Returns:
        The data module created from the IMDB train and validation CSV files,
        using ``"review"`` as input column and ``"sentiment"`` as target column.
    """
    archive_url = "https://pl-flash-data.s3.amazonaws.com/imdb.zip"
    download_data(archive_url, "./data/")

    input_column = "review"
    target_column = "sentiment"
    datamodule = TextClassificationData.from_csv(
        input_column,
        target_column,
        train_file="data/imdb/train.csv",
        val_file="data/imdb/valid.csv",
        backbone=backbone,
        batch_size=batch_size,
        num_workers=num_workers,
        **preprocess_kwargs,
    )
    return datamodule
Пример #8
0
def from_toxic(
    backbone: str = "unitary/toxic-bert",
    val_split: float = 0.1,
    batch_size: int = 4,
    num_workers: int = 0,
    **preprocess_kwargs,
) -> TextClassificationData:
    """Download the Jigsaw toxic comments archive and build a data module from it.

    Args:
        backbone: Passed through to ``TextClassificationData.from_csv``.
        val_split: Passed through to ``TextClassificationData.from_csv``.
        batch_size: Passed through to ``TextClassificationData.from_csv``.
        num_workers: Passed through to ``TextClassificationData.from_csv``.
        **preprocess_kwargs: Extra keyword arguments forwarded unchanged to
            ``TextClassificationData.from_csv``.

    Returns:
        The data module created from the training CSV, using ``"comment_text"``
        as input column and the six toxicity columns as targets.
    """
    archive_url = "https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip"
    download_data(archive_url, "./data")

    target_columns = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
    datamodule = TextClassificationData.from_csv(
        "comment_text",
        target_columns,
        train_file="data/jigsaw_toxic_comments/train.csv",
        backbone=backbone,
        val_split=val_split,
        batch_size=batch_size,
        num_workers=num_workers,
        **preprocess_kwargs,
    )
    return datamodule
Пример #9
0
# `Trainer` is needed for the file-based prediction in step 2b.
from flash import Trainer
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Download the data from the Kaggle Toxic Comment Classification Challenge:
# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
download_data(
    "https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip",
    "data/")

# 2. Load the multi-label model from a checkpoint
model = TextClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/text_classification_multi_label_model.pt"
)

# 2a. Classify a few comments directly from a list of strings
predictions = model.predict([
    "No, he is an arrogant, self serving, immature idiot. Get it right.",
    "U SUCK HANNAH MONTANA",
    "Would you care to vote? Thx.",
])
print(predictions)

# 2b. Or generate predictions from a whole file!
# Only the input column is given: the data module is built from a predict file alone.
datamodule = TextClassificationData.from_csv(
    "comment_text",
    predict_file="data/jigsaw_toxic_comments/predict.csv",
)
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
Пример #10
0
# `Trainer` is needed for the file-based prediction in step 2b.
from flash import Trainer
from flash.core.classification import Labels
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")

# 2. Load the model from a checkpoint
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")

# Replace the model's serializer with `Labels` before predicting.
model.serializer = Labels()

# 2a. Classify a few sentences! How was the movie?
predictions = model.predict([
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
    "Very, very afraid.",
    "This guy has done a great job with this movie!",
])
print(predictions)

# 2b. Or generate predictions from a CSV file!
# Only the input column is given: the data module is built from a predict file alone.
datamodule = TextClassificationData.from_csv(
    "review",
    predict_file="data/imdb/predict.csv",
)
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Fetch the data of the Kaggle Toxic Comment Classification Challenge:
# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
archive_url = "https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip"
download_data(archive_url, "data/")

# 2. Build the data module; the same backbone name is reused for the model below
backbone_name = "unitary/toxic-bert"
toxicity_columns = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
datamodule = TextClassificationData.from_csv(
    "comment_text",
    toxicity_columns,
    train_file="data/jigsaw_toxic_comments/train.csv",
    test_file="data/jigsaw_toxic_comments/test.csv",
    predict_file="data/jigsaw_toxic_comments/predict.csv",
    batch_size=16,
    val_split=0.1,
    backbone=backbone_name,
)

# 3. Build the multi-label model
# NOTE(review): `F1` is not imported in the visible part of this script — confirm
# where it is expected to come from before running.
num_classes = datamodule.num_classes
model = TextClassifier(
    num_classes=num_classes,
    multi_label=True,
    metrics=F1(num_classes=num_classes),
    backbone=backbone_name,
)

# 4. Create the trainer (fast_dev_run=True)
trainer = flash.Trainer(fast_dev_run=True)
Пример #12
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flash
from flash.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Fetch the IMDB archive
archive_url = "https://pl-flash-data.s3.amazonaws.com/imdb.zip"
download_data(archive_url, "data/")

# 2. Build the data module from the train / validation / test CSV files
datamodule = TextClassificationData.from_csv(
    train_file="data/imdb/train.csv",
    val_file="data/imdb/valid.csv",
    test_file="data/imdb/test.csv",
    input_fields="review",
    target_fields="sentiment",
    batch_size=16,
)

# 3. Build the model, sized by the number of classes found in the data
num_classes = datamodule.num_classes
model = TextClassifier(num_classes=num_classes)

# 4. Create the trainer (fast_dev_run=True)
trainer = flash.Trainer(fast_dev_run=True)

# 5. Fine-tune the model with the "freeze" strategy
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Run the test loop on the model
trainer.test(model)
# `torch` is required for the GPU count passed to the trainer in step 3.
import torch

import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Create the DataModule
# Data from the Kaggle Toxic Comment Classification Challenge:
# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
download_data(
    "https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip",
    "./data")

datamodule = TextClassificationData.from_csv(
    "comment_text",
    ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"],
    train_file="data/jigsaw_toxic_comments/train.csv",
    val_split=0.1,
    batch_size=4,
)

# 2. Build the task; labels and the multi-label flag are taken from the data module
model = TextClassifier(
    backbone="unitary/toxic-bert",
    labels=datamodule.labels,
    multi_label=datamodule.multi_label,
)

# 3. Create the trainer and finetune the model
# `gpus` is set to the number of CUDA devices torch reports.
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
Пример #14
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Fetch the IMDB archive and create the DataModule from its CSV files
archive_url = "https://pl-flash-data.s3.amazonaws.com/imdb.zip"
download_data(archive_url, "./data/")

input_column, target_column = "review", "sentiment"
datamodule = TextClassificationData.from_csv(
    input_column,
    target_column,
    train_file="data/imdb/train.csv",
    val_file="data/imdb/valid.csv",
    batch_size=4,
)

# 2. Build the task; the labels are taken from the data module
model = TextClassifier(
    backbone="prajjwal1/bert-medium",
    labels=datamodule.labels,
)

# 3. Create the trainer and finetune the model with the "freeze" strategy
# `gpus` is set to the number of CUDA devices torch reports.
gpu_count = torch.cuda.device_count()
trainer = flash.Trainer(max_epochs=3, gpus=gpu_count)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Classify a few sentences! How was the movie?
datamodule = TextClassificationData.from_lists(
    predict_data=[