def test_from_csv_multilabel(tmpdir):
    """Check multi-label CSV loading across the train, val, test and predict splits.

    The same CSV file backs every split. Supervised splits must yield targets
    made only of 0/1 entries and string inputs; the predict split is only
    checked for string inputs.
    """
    path = csv_data(tmpdir, multilabel=True)
    dm = TextClassificationData.from_csv(
        "sentence",
        ["lab1", "lab2"],
        train_file=path,
        val_file=path,
        test_file=path,
        predict_file=path,
        batch_size=1,
    )
    assert dm.multi_label

    # The supervised splits share identical expectations, so check them in turn.
    for split in ("train", "val", "test"):
        loader = getattr(dm, f"{split}_dataloader")()
        batch = next(iter(loader))
        assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0])
        assert isinstance(batch[DataKeys.INPUT][0], str)

    # The predict split is only checked for its inputs.
    batch = next(iter(dm.predict_dataloader()))
    assert isinstance(batch[DataKeys.INPUT][0], str)
def test_from_csv(tmpdir):
    """Check single-label CSV loading across the train, val, test and predict splits.

    The same CSV file backs every split. Supervised splits must yield a scalar
    target equal to 0 or 1 and string inputs; the predict split is only checked
    for string inputs.
    """
    path = csv_data(tmpdir, multilabel=False)
    dm = TextClassificationData.from_csv(
        "sentence",
        "label",
        train_file=path,
        val_file=path,
        test_file=path,
        predict_file=path,
        batch_size=1,
    )

    # The supervised splits share identical expectations, so check them in turn.
    for split in ("train", "val", "test"):
        loader = getattr(dm, f"{split}_dataloader")()
        batch = next(iter(loader))
        assert batch[DataKeys.TARGET].item() in [0, 1]
        assert isinstance(batch[DataKeys.INPUT][0], str)

    # The predict split is only checked for its inputs.
    batch = next(iter(dm.predict_dataloader()))
    assert isinstance(batch[DataKeys.INPUT][0], str)
def test_from_csv(tmpdir):
    """Check that a train-only CSV data module yields a 0/1 label and ``input_ids``."""
    train_csv = csv_data(tmpdir)
    dm = TextClassificationData.from_csv(
        "sentence",
        "label",
        backbone=TEST_BACKBONE,
        train_file=train_csv,
        batch_size=1,
    )

    train_loader = dm.train_dataloader()
    first_batch = next(iter(train_loader))

    assert first_batch["labels"].item() in [0, 1]
    assert "input_ids" in first_batch
def from_imdb(
    batch_size: int = 4,
    **data_module_kwargs,
) -> TextClassificationData:
    """Download the IMDB sentiment classification data set and load it from CSV.

    Args:
        batch_size: Forwarded to ``TextClassificationData.from_csv``.
        **data_module_kwargs: Extra keyword arguments forwarded unchanged to
            ``TextClassificationData.from_csv``.

    Returns:
        The data module built from the IMDB train and validation CSV files.
    """
    download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")

    input_field, target_field = "review", "sentiment"
    datamodule = TextClassificationData.from_csv(
        input_field,
        target_field,
        train_file="data/imdb/train.csv",
        val_file="data/imdb/valid.csv",
        batch_size=batch_size,
        **data_module_kwargs,
    )
    return datamodule
def test_classification(tmpdir):
    """Smoke-test fitting a ``TextClassifier`` on CSV data with ``fast_dev_run`` enabled."""
    train_csv = csv_data(tmpdir)
    datamodule = TextClassificationData.from_csv(
        "sentence",
        "label",
        train_file=train_csv,
        num_workers=0,
        batch_size=2,
    )
    classifier = TextClassifier(2, TEST_BACKBONE)

    # Trainer output is rooted in the temporary directory.
    Trainer(default_root_dir=tmpdir, fast_dev_run=True).fit(classifier, datamodule=datamodule)
def from_toxic(
    val_split: float = 0.1,
    batch_size: int = 4,
    **data_module_kwargs,
) -> TextClassificationData:
    """Download the Jigsaw toxic comments data set and load it from CSV.

    Args:
        val_split: Forwarded to ``TextClassificationData.from_csv``.
        batch_size: Forwarded to ``TextClassificationData.from_csv``.
        **data_module_kwargs: Extra keyword arguments forwarded unchanged to
            ``TextClassificationData.from_csv``.

    Returns:
        The data module built from the toxic comments training CSV, using the
        six toxicity columns as targets.
    """
    download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "./data")

    toxicity_columns = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
    datamodule = TextClassificationData.from_csv(
        "comment_text",
        toxicity_columns,
        train_file="data/jigsaw_toxic_comments/train.csv",
        val_split=val_split,
        batch_size=batch_size,
        **data_module_kwargs,
    )
    return datamodule
def from_imdb(
    backbone: str = "prajjwal1/bert-medium",
    batch_size: int = 4,
    num_workers: int = 0,
    **preprocess_kwargs,
) -> TextClassificationData:
    """Download the IMDB sentiment classification data set and load it from CSV.

    Args:
        backbone: Forwarded to ``TextClassificationData.from_csv``.
        batch_size: Forwarded to ``TextClassificationData.from_csv``.
        num_workers: Forwarded to ``TextClassificationData.from_csv``.
        **preprocess_kwargs: Extra keyword arguments forwarded unchanged to
            ``TextClassificationData.from_csv``.

    Returns:
        The data module built from the IMDB train and validation CSV files.
    """
    download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")

    input_field, target_field = "review", "sentiment"
    datamodule = TextClassificationData.from_csv(
        input_field,
        target_field,
        train_file="data/imdb/train.csv",
        val_file="data/imdb/valid.csv",
        backbone=backbone,
        batch_size=batch_size,
        num_workers=num_workers,
        **preprocess_kwargs,
    )
    return datamodule
def from_toxic(
    backbone: str = "unitary/toxic-bert",
    val_split: float = 0.1,
    batch_size: int = 4,
    num_workers: int = 0,
    **preprocess_kwargs,
) -> TextClassificationData:
    """Download the Jigsaw toxic comments data set and load it from CSV.

    Args:
        backbone: Forwarded to ``TextClassificationData.from_csv``.
        val_split: Forwarded to ``TextClassificationData.from_csv``.
        batch_size: Forwarded to ``TextClassificationData.from_csv``.
        num_workers: Forwarded to ``TextClassificationData.from_csv``.
        **preprocess_kwargs: Extra keyword arguments forwarded unchanged to
            ``TextClassificationData.from_csv``.

    Returns:
        The data module built from the toxic comments training CSV, using the
        six toxicity columns as targets.
    """
    download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "./data")

    toxicity_columns = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
    datamodule = TextClassificationData.from_csv(
        "comment_text",
        toxicity_columns,
        train_file="data/jigsaw_toxic_comments/train.csv",
        backbone=backbone,
        val_split=val_split,
        batch_size=batch_size,
        num_workers=num_workers,
        **preprocess_kwargs,
    )
    return datamodule
# Example script: run a pre-trained multi-label toxic-comment classifier on a
# few raw sentences and on a CSV file of comments.
from flash import Trainer
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Download the data from the Kaggle Toxic Comment Classification Challenge:
# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "data/")

# 2. Load the model from a checkpoint
model = TextClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/text_classification_multi_label_model.pt"
)

# 2a. Classify a few comments! Are they toxic?
predictions = model.predict([
    "No, he is an arrogant, self serving, immature idiot. Get it right.",
    "U SUCK HANNAH MONTANA",
    "Would you care to vote? Thx.",
])
print(predictions)

# 2b. Or generate predictions from a whole file!
datamodule = TextClassificationData.from_csv(
    "comment_text",
    predict_file="data/jigsaw_toxic_comments/predict.csv",
)
# `Trainer` must be imported explicitly; without the import above this line
# raises NameError.
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
# Example script: run a pre-trained IMDB sentiment classifier on a few raw
# sentences and on a CSV file of reviews.
from flash import Trainer
from flash.core.classification import Labels
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")

# 2. Load the model from a checkpoint
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")

# Replace the model's serializer with a `Labels` instance before predicting.
model.serializer = Labels()

# 2a. Classify a few sentences! How was the movie?
predictions = model.predict([
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
    "Very, very afraid.",
    "This guy has done a great job with this movie!",
])
print(predictions)

# 2b. Or generate predictions from a CSV file!
datamodule = TextClassificationData.from_csv(
    "review",
    predict_file="data/imdb/predict.csv",
)
# `Trainer` must be imported explicitly; without the import above this line
# raises NameError.
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
# Example script: set up multi-label toxic-comment classification (data module,
# model and trainer).
import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Download the data from the Kaggle Toxic Comment Classification Challenge:
# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "data/")

# 2. Load the data: one text column and six binary target columns
toxicity_columns = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
datamodule = TextClassificationData.from_csv(
    "comment_text",
    toxicity_columns,
    train_file="data/jigsaw_toxic_comments/train.csv",
    test_file="data/jigsaw_toxic_comments/test.csv",
    predict_file="data/jigsaw_toxic_comments/predict.csv",
    batch_size=16,
    val_split=0.1,
    backbone="unitary/toxic-bert",
)

# 3. Build the model
num_classes = datamodule.num_classes
# NOTE(review): `F1` is not imported in the visible part of this script —
# confirm it is brought into scope elsewhere in the file.
f1_metric = F1(num_classes=num_classes)
model = TextClassifier(
    num_classes=num_classes,
    multi_label=True,
    metrics=f1_metric,
    backbone="unitary/toxic-bert",
)

# 4. Create the trainer
trainer = flash.Trainer(fast_dev_run=True)
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flash
from flash.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# Example script: fine-tune and test a text classifier on the IMDB reviews.

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "data/")

# 2. Load the data from the train / validation / test CSV files
imdb_splits = {
    "train_file": "data/imdb/train.csv",
    "val_file": "data/imdb/valid.csv",
    "test_file": "data/imdb/test.csv",
}
datamodule = TextClassificationData.from_csv(
    **imdb_splits,
    input_fields="review",
    target_fields="sentiment",
    batch_size=16,
)

# 3. Build the model, sized from the number of classes the data module reports
num_classes = datamodule.num_classes
model = TextClassifier(num_classes=num_classes)

# 4. Create the trainer
trainer = flash.Trainer(fast_dev_run=True)

# 5. Fine-tune the model using the "freeze" strategy
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Test model
trainer.test(model)
# Example script: fine-tune a multi-label toxic-comment classifier.
import torch

import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Create the DataModule
# Data from the Kaggle Toxic Comment Classification Challenge:
# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "./data")

datamodule = TextClassificationData.from_csv(
    "comment_text",
    ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"],
    train_file="data/jigsaw_toxic_comments/train.csv",
    val_split=0.1,
    batch_size=4,
)

# 2. Build the task
model = TextClassifier(
    backbone="unitary/toxic-bert",
    labels=datamodule.labels,
    multi_label=datamodule.multi_label,
)

# 3. Create the trainer and finetune the model
# `torch` must be imported for the device count below; without the import above
# this line raises NameError.
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch import flash from flash.core.data.utils import download_data from flash.text import TextClassificationData, TextClassifier # 1. Create the DataModule download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/") datamodule = TextClassificationData.from_csv( "review", "sentiment", train_file="data/imdb/train.csv", val_file="data/imdb/valid.csv", batch_size=4, ) # 2. Build the task model = TextClassifier(backbone="prajjwal1/bert-medium", labels=datamodule.labels) # 3. Create the trainer and finetune the model trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Classify a few sentences! How was the movie? datamodule = TextClassificationData.from_lists( predict_data=[