def test_classification_serializers():
    """Check the single-label serializers on a fixed 3-class logit vector."""
    raw = torch.tensor([-0.1, 0.2, 0.3])  # logits for 3 classes
    class_names = ['class_1', 'class_2', 'class_3']

    # Logits pass through unchanged.
    logits_out = torch.tensor(Logits().serialize(raw))
    assert torch.allclose(logits_out, raw)

    # Single-label probabilities come from a softmax over the last dim.
    probs_out = torch.tensor(Probabilities().serialize(raw))
    assert torch.allclose(probs_out, torch.softmax(raw, -1))

    # Argmax lands on index 2, which maps to the last label name.
    assert Classes().serialize(raw) == 2
    assert Labels(class_names).serialize(raw) == 'class_3'
def test_multilabel(tmpdir):
    """Fine-tune a multi-label image classifier for one fast-dev step and
    sanity-check the shape and range of its probability predictions."""
    n_classes = 4
    dataset = DummyMultiLabelDataset(n_classes)
    model = ImageClassifier(n_classes, multi_label=True, serializer=Probabilities(multi_label=True))
    loader = torch.utils.data.DataLoader(dataset, batch_size=2)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.finetune(model, loader, strategy="freeze_unfreeze")

    image = dataset[0][DefaultDataKeys.INPUT]
    label = dataset[0][DefaultDataKeys.TARGET]
    predictions = model.predict([{DefaultDataKeys.INPUT: image}])

    # Probabilities must lie in [0, 1].
    pred_tensor = torch.tensor(predictions)
    assert not (pred_tensor > 1).any()
    assert not (pred_tensor < 0).any()

    # One probability per class, and a binary (0/1-valued) target vector.
    assert len(predictions[0]) == n_classes == len(label)
    assert len(torch.unique(label)) <= 2
def test_classification_serializers_multi_label():
    """Check the multi-label serializers on a fixed 3-class logit vector."""
    raw = torch.tensor([-0.1, 0.2, 0.3])  # logits for 3 classes
    class_names = ['class_1', 'class_2', 'class_3']

    # Logits pass through unchanged.
    logits_out = torch.tensor(Logits(multi_label=True).serialize(raw))
    assert torch.allclose(logits_out, raw)

    # Multi-label probabilities are element-wise sigmoids, not a softmax.
    sigmoid_out = torch.tensor(Probabilities(multi_label=True).serialize(raw))
    assert torch.allclose(sigmoid_out, torch.sigmoid(raw))

    # Indices 1 and 2 have positive logits, so both classes are predicted.
    assert Classes(multi_label=True).serialize(raw) == [1, 2]
    assert Labels(class_names, multi_label=True).serialize(raw) == ['class_2', 'class_3']
# Example #4
    def __init__(
        self,
        num_features: int,
        num_classes: int,
        embedding_sizes: List[Tuple[int, int]] = None,
        loss_fn: Callable = F.cross_entropy,
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
        scheduler_kwargs: Optional[Dict[str, Any]] = None,
        metrics: Union[Metric, Callable, Mapping, Sequence, None] = None,
        learning_rate: float = 1e-2,
        multi_label: bool = False,
        serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
        **tabnet_kwargs,
    ):
        """Build a TabNet-backed tabular classification task.

        Args:
            num_features: Number of input columns fed to TabNet.
            num_classes: Number of target classes (TabNet's output dim).
            embedding_sizes: ``(cardinality, embedding_dim)`` pair per
                categorical column; ``None`` means no categorical columns.
            loss_fn: Loss used for training.
            optimizer: Optimizer class to instantiate.
            optimizer_kwargs: Extra keyword arguments for the optimizer.
            scheduler: Optional LR scheduler (class, name, or instance).
            scheduler_kwargs: Extra keyword arguments for the scheduler.
            metrics: Metric(s) tracked during training/validation.
            learning_rate: Initial learning rate.
            multi_label: Whether targets may contain multiple classes.
            serializer: Prediction serializer; defaults to ``Probabilities()``.
            **tabnet_kwargs: Forwarded verbatim to ``TabNet``.
        """
        # Capture the constructor arguments before any local mutation below.
        self.save_hyperparameters()

        # Treat a missing ``embedding_sizes`` as "no categorical columns".
        # The previous code passed None straight to len(), which raised
        # TypeError whenever the default was used.
        embedding_sizes = embedding_sizes or []
        cat_dims, cat_emb_dim = zip(*embedding_sizes) if embedding_sizes else ([], [])
        model = TabNet(
            input_dim=num_features,
            output_dim=num_classes,
            cat_idxs=list(range(len(embedding_sizes))),
            cat_dims=list(cat_dims),
            cat_emb_dim=list(cat_emb_dim),
            **tabnet_kwargs,
        )

        super().__init__(
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            optimizer_kwargs=optimizer_kwargs,
            scheduler=scheduler,
            scheduler_kwargs=scheduler_kwargs,
            metrics=metrics,
            learning_rate=learning_rate,
            multi_label=multi_label,
            serializer=serializer or Probabilities(),
        )
        # NOTE: the duplicate trailing ``self.save_hyperparameters()`` call
        # was removed — the hyperparameters are already saved above.
# Example #5
    def __init__(
        self,
        num_features: int,
        num_classes: int,
        embedding_sizes: List[Tuple] = None,
        loss_fn: Callable = F.cross_entropy,
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
        metrics: List[Metric] = None,
        learning_rate: float = 1e-2,
        multi_label: bool = False,
        serializer: Optional[Union[Serializer, Mapping[str,
                                                       Serializer]]] = None,
        **tabnet_kwargs,
    ):
        """Build a TabNet-backed tabular classification task.

        Args:
            num_features: Number of input columns fed to TabNet.
            num_classes: Number of target classes (TabNet's output dim).
            embedding_sizes: ``(cardinality, embedding_dim)`` pair per
                categorical column; ``None`` means no categorical columns.
            loss_fn: Loss used for training.
            optimizer: Optimizer class to instantiate.
            metrics: Metrics tracked during training/validation.
            learning_rate: Initial learning rate.
            multi_label: Whether targets may contain multiple classes.
            serializer: Prediction serializer; defaults to ``Probabilities()``.
            **tabnet_kwargs: Forwarded verbatim to ``TabNet``.

        Raises:
            ModuleNotFoundError: If the tabular extra is not installed.
        """
        if not _TABULAR_AVAILABLE:
            raise ModuleNotFoundError(
                "Please, pip install 'lightning-flash[tabular]'")

        # Capture the constructor arguments before any local mutation below.
        self.save_hyperparameters()

        # Treat a missing ``embedding_sizes`` as "no categorical columns".
        # The previous code called len(None) when the default was used,
        # which raised TypeError.
        embedding_sizes = embedding_sizes or []
        cat_dims, cat_emb_dim = zip(
            *embedding_sizes) if len(embedding_sizes) else ([], [])
        model = TabNet(input_dim=num_features,
                       output_dim=num_classes,
                       cat_idxs=list(range(len(embedding_sizes))),
                       cat_dims=list(cat_dims),
                       cat_emb_dim=list(cat_emb_dim),
                       **tabnet_kwargs)

        super().__init__(
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            metrics=metrics,
            learning_rate=learning_rate,
            multi_label=multi_label,
            serializer=serializer or Probabilities(),
        )
        # NOTE: the duplicate trailing ``self.save_hyperparameters()`` call
        # was removed — the hyperparameters are already saved above.
# Example #6
# limitations under the License.
from flash import Trainer
from flash.core.classification import Probabilities
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip",
              "data/")

# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

# 3a. Predict what's on a few images! ants or bees?
# Emit class probabilities instead of the default serializer output.
model.serializer = Probabilities()
sample_images = [
    "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
    "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
    "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
]
predictions = model.predict(sample_images)
print(predictions)

# 3b. Or generate predictions with a whole folder!
datamodule = ImageClassificationData.from_folders(
    predict_folder="data/hymenoptera_data/predict/")
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)