示例#1
0
def test_content_transform_grayscale_image(subtests, content_image,
                                           impl_params, instance_norm):
    """Check that a grayscale content image comes out with three channels.

    The grayscale image is fed through the paper's content transform using a
    small ``edge_size``. The output is compared against a hand-built
    reference: ``F.resize`` when ``impl_params`` is set, ``CenterCrop``
    otherwise. In both cases the single channel is repeated three times.
    """
    edge_size = 16
    grayscale_image = F.rgb_to_grayscale(content_image)

    params = paper.hyper_parameters(impl_params=impl_params,
                                    instance_norm=instance_norm)
    params.content_transform.edge_size = edge_size

    transform = paper.content_transform(
        impl_params=impl_params,
        instance_norm=instance_norm,
        hyper_parameters=params,
    )
    if instance_norm:
        utils.make_reproducible()
    actual = transform(grayscale_image)

    if impl_params and instance_norm:
        # This configuration includes a random component that cannot be
        # controlled, so there is no deterministic reference to compare to.
        return

    if impl_params:
        reference = F.resize(grayscale_image, edge_size)
    else:
        reference = transforms.CenterCrop(edge_size)(grayscale_image)

    ptu.assert_allclose(actual, reference.repeat(1, 3, 1, 1))
示例#2
0
def test_style_loss(subtests, impl_params, instance_norm):
    """Check that the style loss is assembled from the style hyper parameters.

    Verifies the container type and that every child is a ``GramLoss``.
    Also checks that the encoder layers, per-layer score weights and the
    overall score weight match the ``style_loss`` hyper parameters.
    """
    expected = paper.hyper_parameters(impl_params=impl_params,
                                      instance_norm=instance_norm).style_loss

    style_loss = paper.style_loss(impl_params=impl_params,
                                  instance_norm=instance_norm)
    assert isinstance(style_loss, pystiche.loss.MultiLayerEncodingLoss)

    children = list(style_loss.children())

    with subtests.test("losses"):
        assert all(isinstance(child, paper.GramLoss) for child in children)

    # zip(*...) transposes [(layer, weight), ...] into (layers, weights).
    pairs = [(child.encoder.layer, child.score_weight) for child in children]
    layers, layer_weights = zip(*pairs)

    with subtests.test("layers"):
        assert layers == expected.layers

    with subtests.test("layer_weights"):
        assert layer_weights == pytest.approx(expected.layer_weights)

    with subtests.test("score_weight"):
        expected_score_weight = pytest.approx(expected.score_weight)
        assert style_loss.score_weight == expected_score_weight
示例#3
0
def test_dataset(subtests, mocker, impl_params, instance_norm):
    """Check the dataset returned by ``paper.dataset``.

    Image-file collection is mocked to return no files, so no data on disk is
    needed. The dataset must be an ``IterableDataset`` wrapping an
    ``ImageFolderDataset``. Its transform type must match the content
    transform for the same configuration, and its ``min_size`` must equal
    the content ``edge_size``. Its length must be
    ``num_batches * batch_size``.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params,
                                              instance_norm=instance_norm)
    mocker.patch(
        "pystiche_papers.ulyanov_et_al_2016._data.ImageFolderDataset._collect_image_files",
        return_value=[],
    )
    dataset = paper.dataset("root",
                            impl_params=impl_params,
                            instance_norm=instance_norm)

    assert isinstance(dataset, torch.utils.data.IterableDataset)

    with subtests.test("dataset"):
        assert isinstance(dataset.dataset, ImageFolderDataset)

    with subtests.test("content_transform"):
        # Build the reference transform for the same configuration as the
        # dataset instead of the default one, so every parametrization
        # (impl_params / instance_norm) is actually checked.
        expected_transform = paper.content_transform(
            impl_params=impl_params, instance_norm=instance_norm)
        assert isinstance(dataset.transform, type(expected_transform))

    with subtests.test("min_size"):
        assert dataset.min_size == hyper_parameters.content_transform.edge_size

    with subtests.test("length"):
        assert (len(dataset) == hyper_parameters.num_batches *
                hyper_parameters.batch_size)
示例#4
0
def test_stylization_smoke(stylization, postprocessor_mocks, input_image):
    """Smoke-test the ``stylization`` fixture with mocked postprocessing.

    The third element of the stylization result must equal the input resized
    to a square of the default content ``edge_size``, shifted by ``0.5``.
    """
    edge_size = paper.hyper_parameters().content_transform.edge_size
    _, _, output_image = stylization(input_image)
    expected = F.resize(input_image, (edge_size, edge_size)) + 0.5
    ptu.assert_allclose(output_image, expected, rtol=1e-6)


def test_hyper_parameters_lr_scheduler(subtests, impl_params, instance_norm):
    """Check the ``lr_scheduler`` section of the hyper parameters.

    The learning-rate decay must be ``0.8`` when ``impl_params`` is set and
    ``0.7`` otherwise.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params,
                                              instance_norm=instance_norm)

    sub_params = "lr_scheduler"
    assert sub_params in hyper_parameters
    hyper_parameters = getattr(hyper_parameters, sub_params)

    with subtests.test("lr_decay"):
        # The conditional must be parenthesized. Without the parentheses the
        # statement parses as ``assert (lr_decay == 0.8) if impl_params else
        # 0.7``, whose else branch asserts a truthy constant and never fails.
        assert hyper_parameters.lr_decay == pytest.approx(
            0.8 if impl_params else 0.7)


def test_hyper_parameters_optimizer(subtests, impl_params, instance_norm):
    """Check the ``optimizer`` section of the hyper parameters.

    The learning rate must be ``1e-3`` when both ``impl_params`` and
    ``instance_norm`` are set, and ``1e-1`` otherwise.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params,
                                              instance_norm=instance_norm)

    sub_params = "optimizer"
    assert sub_params in hyper_parameters
    hyper_parameters = getattr(hyper_parameters, sub_params)

    with subtests.test("lr"):
        # Parenthesize the conditional. Otherwise the ``else 1e-1`` branch
        # asserts a truthy constant instead of comparing it to ``lr``.
        assert hyper_parameters.lr == pytest.approx(
            1e-3 if impl_params and instance_norm else 1e-1)


def test_hyper_parameters_content_transform(subtests, impl_params,
                                            instance_norm):
    """Check that the content transform section uses an edge size of 256."""
    all_params = paper.hyper_parameters(impl_params=impl_params,
                                        instance_norm=instance_norm)

    section = "content_transform"
    assert section in all_params
    transform_params = getattr(all_params, section)

    with subtests.test("edge_size"):
        assert transform_params.edge_size == 256


def test_hyper_parameters_num_images(impl_params, instance_norm):
    """Check the total ``num_batches * num_epochs`` of the configuration.

    Expected totals:
    - 50 000 with both ``impl_params`` and ``instance_norm``;
    - 3 000 with ``impl_params`` only;
    - 2 000 whenever ``impl_params`` is unset.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params,
                                              instance_norm=instance_norm)

    num_batches_per_epoch = hyper_parameters.num_batches
    num_epochs = hyper_parameters.num_epochs

    num_images = num_batches_per_epoch * num_epochs

    # Compute the expectation separately. Writing
    # ``assert x == a if impl_params else 2_000`` parses as
    # ``assert (x == a) if impl_params else 2_000``, which always passes when
    # ``impl_params`` is False.
    if impl_params:
        expected_num_images = 50_000 if instance_norm else 3_000
    else:
        expected_num_images = 2_000

    assert num_images == expected_num_images
示例#9
0
def test_style_transform(subtests, impl_params, instance_norm):
    """Check that the style transform mirrors its hyper parameters.

    The ``edge_size`` and ``interpolation`` of ``paper.style_transform`` must
    equal the ``style_transform`` hyper parameters for the same
    configuration.
    """
    config = dict(impl_params=impl_params, instance_norm=instance_norm)
    expected = paper.hyper_parameters(**config).style_transform

    style_transform = paper.style_transform(**config)

    with subtests.test("edge_size"):
        assert style_transform.edge_size == expected.edge_size

    with subtests.test("interpolation"):
        assert style_transform.interpolation == expected.interpolation
示例#10
0
def test_image_loader(subtests, impl_params, instance_norm):
    """Check that ``paper.image_loader`` builds a correctly sized ``DataLoader``.

    An empty tuple stands in for the dataset. The loader's batch size must
    match the ``batch_size`` hyper parameter.
    """
    config = dict(impl_params=impl_params, instance_norm=instance_norm)
    params = paper.hyper_parameters(**config)

    image_loader = paper.image_loader((), **config)
    assert isinstance(image_loader, DataLoader)

    with subtests.test("batch_size"):
        assert image_loader.batch_size == params.batch_size
示例#11
0
def test_hyper_parameters_content_loss(subtests, impl_params, instance_norm):
    """Check the ``content_loss`` hyper parameters.

    The content layer is always ``relu4_2``. The score weight is ``0.6`` when
    ``impl_params`` is set without ``instance_norm``, and ``1.0`` otherwise.
    """
    all_params = paper.hyper_parameters(impl_params=impl_params,
                                        instance_norm=instance_norm)

    section = "content_loss"
    assert section in all_params
    loss_params = getattr(all_params, section)

    with subtests.test("layer"):
        assert loss_params.layer == "relu4_2"

    with subtests.test("score_weight"):
        if impl_params and not instance_norm:
            expected_score_weight = 6e-1
        else:
            expected_score_weight = 1e0
        assert loss_params.score_weight == pytest.approx(expected_score_weight)
示例#12
0
def test_hyper_parameters_lr_decay_delay(impl_params, instance_norm):
    """Check the learning-rate decay schedule derived from the hyper parameters.

    The test derives two quantities:
    - batches before the first decay: ``num_batches * (delay + 1)``;
    - batches between consecutive decays: ``num_batches``.

    Both are compared against per-configuration constants.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params,
                                              instance_norm=instance_norm)

    num_batches = hyper_parameters.num_batches
    delay = hyper_parameters.lr_scheduler.delay

    num_batches_before_first_decay = num_batches * (delay + 1)
    num_batches_between_decays = num_batches

    # Compute expectations up front. Writing
    # ``assert (x == a if impl_params else b)`` parses as
    # ``assert ((x == a) if impl_params else b)``, which asserts the truthy
    # constant ``b`` and always passes when ``impl_params`` is False.
    if impl_params:
        expected_before_first_decay = 2_000 if instance_norm else 300
        expected_between_decays = 2_000 if instance_norm else 300
    else:
        expected_before_first_decay = 1000
        expected_between_decays = 200

    assert num_batches_before_first_decay == expected_before_first_decay
    assert num_batches_between_decays == expected_between_decays
示例#13
0
def test_hyper_parameters_style_transform(subtests, impl_params,
                                          instance_norm):
    """Check the ``style_transform`` hyper parameters.

    The edge size is always 256. Interpolation is bicubic only when both
    ``impl_params`` and ``instance_norm`` are set, and bilinear otherwise.
    """
    all_params = paper.hyper_parameters(impl_params=impl_params,
                                        instance_norm=instance_norm)

    section = "style_transform"
    assert section in all_params
    transform_params = getattr(all_params, section)

    with subtests.test("edge_size"):
        assert transform_params.edge_size == 256

    with subtests.test("interpolation"):
        if impl_params and instance_norm:
            expected_mode = InterpolationMode.BICUBIC
        else:
            expected_mode = InterpolationMode.BILINEAR
        assert transform_params.interpolation == expected_mode
示例#14
0
def test_content_loss(subtests, impl_params, instance_norm):
    """Check that the content loss is built from the content hyper parameters.

    The loss must be a ``FeatureReconstructionLoss``. Its encoder layer and
    score weight must match the ``content_loss`` hyper parameters.
    """
    config = dict(impl_params=impl_params, instance_norm=instance_norm)
    expected = paper.hyper_parameters(**config).content_loss

    content_loss = paper.content_loss(**config)
    assert isinstance(content_loss, pystiche.loss.FeatureReconstructionLoss)

    with subtests.test("layer"):
        assert content_loss.encoder.layer == expected.layer

    with subtests.test("score_weight"):
        actual_score_weight = content_loss.score_weight
        assert actual_score_weight == pytest.approx(expected.score_weight)
示例#15
0
def test_ulyanov_et_al_2016_lr_scheduler(subtests, impl_params, instance_norm):
    """Check the scheduler returned by ``paper.lr_scheduler``.

    A 1x1 convolution supplies parameters for a default ``paper.optimizer``.
    The scheduler must be a ``DelayedExponentialLR`` whose ``gamma`` and
    ``delay`` match the ``lr_scheduler`` hyper parameters.
    """
    optimizer = paper.optimizer(nn.Conv2d(3, 3, 1))

    config = dict(impl_params=impl_params, instance_norm=instance_norm)
    expected = paper.hyper_parameters(**config).lr_scheduler

    lr_scheduler = paper.lr_scheduler(optimizer, **config)
    assert isinstance(lr_scheduler, paper.DelayedExponentialLR)

    with subtests.test("lr_decay"):
        assert lr_scheduler.gamma == expected.lr_decay

    with subtests.test("delay"):
        assert lr_scheduler.delay == expected.delay
示例#16
0
def test_hyper_parameters_style_loss(subtests, impl_params, instance_norm):
    """Check the ``style_loss`` hyper parameters.

    - Layers: with both ``impl_params`` and ``instance_norm`` set, the style
      layers end at ``relu4_1``; otherwise ``relu5_1`` is included as well.
    - Layer weights: ``1e3`` each when ``impl_params`` is set without
      ``instance_norm``, and ``1.0`` otherwise.
    - Score weight: always ``1.0``.
    """
    all_params = paper.hyper_parameters(impl_params=impl_params,
                                        instance_norm=instance_norm)

    section = "style_loss"
    assert section in all_params
    loss_params = getattr(all_params, section)

    with subtests.test("layer"):
        expected_layers = ("relu1_1", "relu2_1", "relu3_1", "relu4_1")
        if not (impl_params and instance_norm):
            expected_layers += ("relu5_1",)
        assert loss_params.layers == expected_layers

    with subtests.test("layer_weights"):
        weight = 1e3 if impl_params and not instance_norm else 1e0
        expected_weights = [weight] * len(loss_params.layers)
        assert loss_params.layer_weights == pytest.approx(expected_weights)

    with subtests.test("score_weight"):
        assert loss_params.score_weight == pytest.approx(1e0)
示例#17
0
def test_ulyanov_et_al_2016_optimizer(subtests, impl_params, instance_norm):
    """Check the optimizer returned by ``paper.optimizer``.

    It must be an ``optim.Adam`` with a single parameter group. That group
    must hold exactly the transformer's parameters (same objects, same
    order) and use the learning rate from the ``optimizer`` hyper
    parameters.
    """
    transformer = nn.Conv2d(3, 3, 1)
    expected_params = tuple(transformer.parameters())

    config = dict(impl_params=impl_params, instance_norm=instance_norm)
    expected = paper.hyper_parameters(**config).optimizer

    optimizer = paper.optimizer(transformer, **config)
    assert isinstance(optimizer, optim.Adam)
    assert len(optimizer.param_groups) == 1

    group = optimizer.param_groups[0]

    with subtests.test(msg="optimization params"):
        group_params = group["params"]
        assert len(group_params) == len(expected_params)
        assert all(actual is desired
                   for actual, desired in zip(group_params, expected_params))

    with subtests.test(msg="optimizer properties"):
        assert group["lr"] == ptu.approx(expected.lr)
示例#18
0
def test_hyper_parameters_num_batches(impl_params, instance_norm):
    """Check the number of batches per epoch.

    Expected values:
    - 2000 with both ``impl_params`` and ``instance_norm``;
    - 300 with ``impl_params`` only;
    - 200 whenever ``impl_params`` is unset.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params,
                                              instance_norm=instance_norm)

    if not impl_params:
        expected_num_batches = 200
    elif instance_norm:
        expected_num_batches = 2000
    else:
        expected_num_batches = 300

    assert hyper_parameters.num_batches == expected_num_batches
示例#19
0
def test_hyper_parameters_batch_size(impl_params, instance_norm):
    """Check the training batch size.

    Expected values:
    - 1 with both ``impl_params`` and ``instance_norm``;
    - 4 with ``impl_params`` only;
    - 16 whenever ``impl_params`` is unset.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params,
                                              instance_norm=instance_norm)

    if not impl_params:
        expected_batch_size = 16
    elif instance_norm:
        expected_batch_size = 1
    else:
        expected_batch_size = 4

    assert hyper_parameters.batch_size == expected_batch_size
示例#20
0
def test_hyper_parameters_smoke():
    """Smoke-test that the default hyper parameters are ``HyperParameters``."""
    default_params = paper.hyper_parameters()
    assert isinstance(default_params, HyperParameters)