def test_from_lists_multilabel():
    """Exercise ``TextClassificationData.from_lists`` with the multi-label target fixture.

    The same list fixtures feed the train, val, test and predict splits with a
    batch size of one. The datamodule must report ``multi_label``; the first
    batch of each labelled split must hold only 0/1 entries in its first
    target and a ``str`` as its first input. For the predict split only the
    input type is checked.
    """
    dm = TextClassificationData.from_lists(
        train_data=TEST_LIST_DATA,
        train_targets=TEST_LIST_TARGETS_MULTILABEL,
        val_data=TEST_LIST_DATA,
        val_targets=TEST_LIST_TARGETS_MULTILABEL,
        test_data=TEST_LIST_DATA,
        test_targets=TEST_LIST_TARGETS_MULTILABEL,
        predict_data=TEST_LIST_DATA,
        batch_size=1,
    )

    assert dm.multi_label

    # The labelled splits share the same checks; visit them in train -> val -> test order.
    for split in ("train", "val", "test"):
        loader = getattr(dm, f"{split}_dataloader")()
        first_batch = next(iter(loader))
        assert all(label in [0, 1] for label in first_batch[DataKeys.TARGET][0])
        assert isinstance(first_batch[DataKeys.INPUT][0], str)

    # Only the input type is verified for the predict split.
    predict_batch = next(iter(dm.predict_dataloader()))
    assert isinstance(predict_batch[DataKeys.INPUT][0], str)
def test_from_lists():
    """Exercise ``TextClassificationData.from_lists`` with the single-label target fixture.

    The same list fixtures feed the train, val, test and predict splits with a
    batch size of one. The first batch of each labelled split must carry a
    target whose scalar value is 0 or 1 and a ``str`` as its first input. For
    the predict split only the input type is checked.
    """
    dm = TextClassificationData.from_lists(
        train_data=TEST_LIST_DATA,
        train_targets=TEST_LIST_TARGETS,
        val_data=TEST_LIST_DATA,
        val_targets=TEST_LIST_TARGETS,
        test_data=TEST_LIST_DATA,
        test_targets=TEST_LIST_TARGETS,
        predict_data=TEST_LIST_DATA,
        batch_size=1,
    )

    # The labelled splits share the same checks; visit them in train -> val -> test order.
    for split in ("train", "val", "test"):
        loader = getattr(dm, f"{split}_dataloader")()
        first_batch = next(iter(loader))
        target_value = first_batch[DataKeys.TARGET].item()
        assert target_value in [0, 1]
        assert isinstance(first_batch[DataKeys.INPUT][0], str)

    # Only the input type is verified for the predict split.
    predict_batch = next(iter(dm.predict_dataloader()))
    assert isinstance(predict_batch[DataKeys.INPUT][0], str)
# Example #3
# 0
def test_predict(tmpdir):
    """Run ``TextEmbedder`` prediction and check the size of each returned tensor.

    A predict-only datamodule is built from ``predict_data`` with a batch size
    of four, a ``TextEmbedder`` is created on ``TEST_BACKBONE``, and the first
    prediction batch is expected to hold exactly three tensors of size 384.

    Args:
        tmpdir: Accepted by the signature but not used in the body.
    """
    datamodule = TextClassificationData.from_lists(
        predict_data=predict_data,
        batch_size=4,
    )
    embedder = TextEmbedder(backbone=TEST_BACKBONE)

    # Use every GPU that torch reports (zero when none are present).
    gpu_count = torch.cuda.device_count()
    trainer = flash.Trainer(gpus=gpu_count)
    predictions = trainer.predict(embedder, datamodule=datamodule)

    # NOTE(review): 384 is presumably the embedding width of TEST_BACKBONE - confirm.
    expected_sizes = [torch.Size([384])] * 3
    actual_sizes = [embedding.size() for embedding in predictions[0]]
    assert actual_sizes == expected_sizes
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.text import TextClassificationData, TextEmbedder

# 1. Create the DataModule from the sentences to embed
sentences = [
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
]
datamodule = TextClassificationData.from_lists(predict_data=sentences, batch_size=4)

# 2. Build a TextEmbedder on the chosen backbone
backbone_name = "sentence-transformers/all-MiniLM-L6-v2"
model = TextEmbedder(backbone=backbone_name)

# 3. Generate embeddings for the 3 sentences above, using every GPU torch reports
available_gpus = torch.cuda.device_count()
trainer = flash.Trainer(gpus=available_gpus)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
    train_file="data/jigsaw_toxic_comments/train.csv",
    val_split=0.1,
    batch_size=4,
)

# 2. Build the task, taking the labels and multi-label flag from the datamodule
model = TextClassifier(
    backbone="unitary/toxic-bert",
    labels=datamodule.labels,
    multi_label=datamodule.multi_label,
)

# 3. Create the trainer (one epoch, every GPU torch reports) and finetune the model
gpu_count = torch.cuda.device_count()
trainer = flash.Trainer(max_epochs=1, gpus=gpu_count)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Generate predictions for a few comments!
comments = [
    "No, he is an arrogant, self serving, immature idiot. Get it right.",
    "U SUCK HANNAH MONTANA",
    "Would you care to vote? Thx.",
]
# The training datamodule is replaced by a predict-only one built from the comments.
datamodule = TextClassificationData.from_lists(predict_data=comments, batch_size=4)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)

# 5. Save the model!
checkpoint_path = "text_classification_multi_label_model.pt"
trainer.save_checkpoint(checkpoint_path)