def test_from_lists_multilabel():
    """``from_lists`` with multi-hot targets must build a multi-label data module.

    Checks that every train/val/test batch carries a binary (0/1) target vector and
    the raw text input, and that predict batches carry only the text input.
    """
    dm = TextClassificationData.from_lists(
        train_data=TEST_LIST_DATA,
        train_targets=TEST_LIST_TARGETS_MULTILABEL,
        val_data=TEST_LIST_DATA,
        val_targets=TEST_LIST_TARGETS_MULTILABEL,
        test_data=TEST_LIST_DATA,
        test_targets=TEST_LIST_TARGETS_MULTILABEL,
        predict_data=TEST_LIST_DATA,
        batch_size=1,
    )

    # Multi-hot targets should flip the data module into multi-label mode.
    assert dm.multi_label

    # The three labelled stages share the same batch contract; assert it once in a loop.
    for dataloader in (dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()):
        batch = next(iter(dataloader))
        assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0])
        assert isinstance(batch[DataKeys.INPUT][0], str)

    # Predict batches have no targets — only the raw text input.
    batch = next(iter(dm.predict_dataloader()))
    assert isinstance(batch[DataKeys.INPUT][0], str)
def test_from_lists():
    """``from_lists`` with single-label targets must yield scalar 0/1 targets.

    Checks that every train/val/test batch carries a scalar binary target and the
    raw text input, and that predict batches carry only the text input.
    """
    dm = TextClassificationData.from_lists(
        train_data=TEST_LIST_DATA,
        train_targets=TEST_LIST_TARGETS,
        val_data=TEST_LIST_DATA,
        val_targets=TEST_LIST_TARGETS,
        test_data=TEST_LIST_DATA,
        test_targets=TEST_LIST_TARGETS,
        predict_data=TEST_LIST_DATA,
        batch_size=1,
    )

    # The three labelled stages share the same batch contract; assert it once in a loop.
    for dataloader in (dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()):
        batch = next(iter(dataloader))
        assert batch[DataKeys.TARGET].item() in [0, 1]
        assert isinstance(batch[DataKeys.INPUT][0], str)

    # Predict batches have no targets — only the raw text input.
    batch = next(iter(dm.predict_dataloader()))
    assert isinstance(batch[DataKeys.INPUT][0], str)
def test_predict(tmpdir):
    """A pretrained ``TextEmbedder`` should emit one 384-dim embedding per sample."""
    datamodule = TextClassificationData.from_lists(predict_data=predict_data, batch_size=4)
    model = TextEmbedder(backbone=TEST_BACKBONE)

    trainer = flash.Trainer(gpus=torch.cuda.device_count())
    embeddings = trainer.predict(model, datamodule=datamodule)

    # The first batch holds three samples, each embedded into a 384-element vector.
    expected = torch.Size([384])
    assert [embedding.size() for embedding in embeddings[0]] == [expected, expected, expected]
# # http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.text import TextClassificationData, TextEmbedder

# 1. Build a DataModule around a handful of raw sentences to embed.
predict_sentences = [
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
]
datamodule = TextClassificationData.from_lists(
    predict_data=predict_sentences,
    batch_size=4,
)

# 2. Load a previously trained TextEmbedder
model = TextEmbedder(backbone="sentence-transformers/all-MiniLM-L6-v2")

# 3. Generate embeddings for the sentences above
trainer = flash.Trainer(gpus=torch.cuda.device_count())
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
train_file="data/jigsaw_toxic_comments/train.csv", val_split=0.1, batch_size=4, ) # 2. Build the task model = TextClassifier( backbone="unitary/toxic-bert", labels=datamodule.labels, multi_label=datamodule.multi_label, ) # 3. Create the trainer and finetune the model trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count()) trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Generate predictions for a few comments! datamodule = TextClassificationData.from_lists( predict_data=[ "No, he is an arrogant, self serving, immature idiot. Get it right.", "U SUCK HANNAH MONTANA", "Would you care to vote? Thx.", ], batch_size=4, ) predictions = trainer.predict(model, datamodule=datamodule, output="labels") print(predictions) # 5. Save the model! trainer.save_checkpoint("text_classification_multi_label_model.pt")