Beispiel #1
0
def test_compute_feats(device: str) -> None:
    """Smoke-test ``FID.compute_feats`` with an ``InceptionV3`` extractor.

    Only checks that feature computation runs without raising on ``device``
    (presumably supplied by a fixture/parametrization not visible here).
    """
    data_loader = torch.utils.data.DataLoader(TestDataset(), batch_size=3, num_workers=2)
    metric = FID()
    extractor = InceptionV3()
    metric.compute_feats(data_loader, extractor, device=device)
Beispiel #2
0
def test_inception_input_range(input_range, normalize_input, expectation) -> None:
    """Run FID feature extraction on data drawn from ``input_range``.

    The whole pipeline executes inside ``expectation``, a context manager
    (e.g. ``pytest.raises`` or a no-op) that is presumably supplied by a
    parametrize decorator not visible here. ``normalize_input`` is forwarded
    to ``InceptionV3``.
    """
    with expectation:
        data_loader = torch.utils.data.DataLoader(
            TestDataset(input_range), batch_size=3, num_workers=2,
        )
        metric = FID()
        extractor = InceptionV3(normalize_input=normalize_input)
        metric.compute_feats(data_loader, extractor, device='cpu')
Beispiel #3
0
def test_compute_feats_cuda() -> None:
    """Check that ``MSID.compute_feats`` runs end-to-end on a CUDA device.

    Skipped (rather than failed) when no CUDA device is available. Any
    exception raised by the pipeline is left to propagate, so pytest reports
    it as a failure with the full traceback instead of a flattened message.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    dataset = TestDataset()
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=3,
        num_workers=2,
    )
    metric = MSID()
    model = InceptionV3()
    metric.compute_feats(loader, model, device='cuda')
Beispiel #4
0
    def _compute_feats(self,
                       loader: torch.utils.data.DataLoader,
                       feature_extractor: torch.nn.Module = None,
                       device: str = 'cuda') -> torch.Tensor:
        r"""Generate low-dimensional image descriptors.

        Args:
            loader: Iterable of batches. Each batch must be a mapping with key
                ``images`` holding a tensor whose first dimension is the batch size.
            feature_extractor: Model used to generate image features. If None, the
                default ``InceptionV3`` model is used. The model must return a
                sequence containing exactly one tensor (features from one network layer).
            device: Device on which to compute inference of the model.

        Returns:
            2-D tensor of shape ``(total_images, num_features)``: each image's
            features are flattened, and all batches are concatenated along dim 0.
            Computed under ``torch.no_grad()``, so the result does not require grad.

        Raises:
            AssertionError: If ``feature_extractor`` is not a ``torch.nn.Module``,
                or if it does not return exactly one feature tensor.
            RuntimeError: If ``loader`` yields no batches.
        """
        if feature_extractor is None:
            print(
                'WARNING: default feature extractor (InceptionNet V3) is used.'
            )
            feature_extractor = InceptionV3()
        else:
            assert isinstance(feature_extractor, torch.nn.Module), \
                f"Feature extractor must be PyTorch module. Got {type(feature_extractor)}"
        feature_extractor.to(device)
        feature_extractor.eval()

        total_feats = []
        # Pure inference: without no_grad every batch would build an autograd
        # graph that the stored features keep alive until the function returns.
        with torch.no_grad():
            for batch in loader:
                images = batch['images']
                N = images.shape[0]
                images = images.float().to(device)

                # Get features
                features = feature_extractor(images)
                # TODO(jamil 26.03.20): Add support for more than one feature map
                assert len(features) == 1, \
                    f"feature_encoder must return list with features from one layer. Got {len(features)}"
                # reshape (unlike view) also works for non-contiguous feature maps.
                total_feats.append(features[0].reshape(N, -1))

        if not total_feats:
            # torch.cat([]) would raise an opaque RuntimeError; make the cause explicit.
            raise RuntimeError("loader yielded no batches; cannot compute features")

        return torch.cat(total_feats, dim=0)
def test_content_loss_supports_custom_extractor(x, y, device: str) -> None:
    """ContentLoss accepts InceptionV3 blocks as a custom extractor.

    Uses explicit layers '0' and '1' with equal weights and only checks that
    the loss can be evaluated on ``x`` and ``y``.
    """
    extractor = InceptionV3().blocks
    chosen_layers = ['0', '1']
    layer_weights = [0.5, 0.5]
    criterion = ContentLoss(
        feature_extractor=extractor,
        layers=chosen_layers,
        weights=layer_weights,
    )
    criterion(x, y)
def test_content_loss_raises_if_wrong_reduction(x, y) -> None:
    """Check ContentLoss's handling of the ``reduction`` argument.

    Valid reduction modes must evaluate cleanly. For invalid ones, either
    constructing or calling ContentLoss must raise ValueError.
    """
    accepted = ('mean', 'sum', 'none')
    rejected = (None, 'n', 2)

    for reduction in accepted:
        loss_fn = ContentLoss(reduction=reduction)
        loss_fn(x, y)

    for reduction in rejected:
        with pytest.raises(ValueError):
            loss_fn = ContentLoss(reduction=reduction)
            loss_fn(x, y)


@pytest.mark.parametrize(
    ("model", "expectation"),
    [
        ('vgg16', raise_nothing()),
        ('vgg19', raise_nothing()),
        (InceptionV3(), raise_nothing()),
        (None, pytest.raises(ValueError)),
        ('random_encoder', pytest.raises(ValueError)),
    ],
)
def test_content_loss_raises_if_wrong_extractor(
        x, y, model: Union[str, Callable], expectation: Any) -> None:
    """Check which ``feature_extractor`` values ContentLoss accepts.

    Per the cases above, 'vgg16', 'vgg19' and an InceptionV3 instance are
    accepted. None and an unknown name raise ValueError at construction.
    """
    with expectation:
        ContentLoss(feature_extractor=model)


@pytest.mark.parametrize(
    "model",
    ['vgg16', InceptionV3()],
)
Beispiel #7
0
def test_content_loss_supports_custom_extractor(prediction: torch.Tensor,
                                                target: torch.Tensor,
                                                device: str) -> None:
    """ContentLoss evaluates with InceptionV3 blocks on layers '0' and '1'.

    Only checks that computing the loss on ``prediction``/``target`` does not raise.
    """
    extractor = InceptionV3().blocks
    criterion = ContentLoss(feature_extractor=extractor, layers=['0', '1'])
    criterion(prediction, target)