예제 #1
0
def test_hyper_parameters_guided_style_loss(subtests, impl_params):
    """Check the ``guided_style_loss`` hyper-parameters.

    The expected layer names depend on ``impl_params``: ``relu*_1`` layers when
    it is truthy, ``conv*_1`` layers otherwise. Each expected layer weight is
    the inverse square of the corresponding channel count.
    """
    all_params = paper.hyper_parameters(impl_params=impl_params)

    key = "guided_style_loss"
    assert key in all_params
    guided_params = getattr(all_params, key)

    if impl_params:
        expected_layers = ("relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1")
    else:
        expected_layers = ("conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv5_1")
    # Channel count paired with each entry of ``expected_layers``, same order.
    channel_counts = (64, 128, 256, 512, 512)
    expected_layer_weights = [1 / channels**2 for channels in channel_counts]

    with subtests.test("layers"):
        assert guided_params.layers == expected_layers

    with subtests.test("layer_weights"):
        assert guided_params.layer_weights == pytest.approx(expected_layer_weights)

    with subtests.test("region_weights"):
        assert guided_params.region_weights == "sum"

    with subtests.test("score_weight"):
        assert guided_params.score_weight == pytest.approx(1e3)
예제 #2
0
def test_content_loss(subtests):
    """Check that ``paper.content_loss()`` matches the default hyper-parameters.

    The loss must be a ``FeatureReconstructionLoss`` whose encoder layer and
    score weight equal the ``content_loss`` hyper-parameters.
    """
    loss = paper.content_loss()
    assert isinstance(loss, pystiche.loss.FeatureReconstructionLoss)

    expected = paper.hyper_parameters().content_loss

    with subtests.test("layer"):
        assert loss.encoder.layer == expected.layer

    with subtests.test("score_weight"):
        expected_score_weight = pytest.approx(expected.score_weight)
        assert loss.score_weight == expected_score_weight
예제 #3
0
def test_hyper_parameters_content_loss(subtests, impl_params):
    """Check the ``content_loss`` hyper-parameters.

    The expected layer is ``relu4_2`` when ``impl_params`` is truthy and
    ``conv4_2`` otherwise. The score weight is expected to be ``1e0``.
    """
    hyper_parameters = paper.hyper_parameters(impl_params=impl_params)

    sub_params = "content_loss"
    assert sub_params in hyper_parameters
    parameters = getattr(hyper_parameters, sub_params)

    with subtests.test("layer"):
        # Compute the expected value first. Written inline without parentheses,
        # ``assert a == "relu4_2" if impl_params else "conv4_2"`` parses as
        # ``assert (a == "relu4_2") if impl_params else "conv4_2"``. For a falsy
        # ``impl_params`` it would then only assert a non-empty (truthy) string
        # and never check the layer.
        expected_layer = "relu4_2" if impl_params else "conv4_2"
        assert parameters.layer == expected_layer

    with subtests.test("score_weight"):
        assert parameters.score_weight == pytest.approx(1e0)
예제 #4
0
def test_image_pyramid(subtests):
    """Check that ``paper.image_pyramid()`` matches the default hyper-parameters.

    The pyramid must be an ``ImagePyramid`` with exactly two levels. The
    per-level edge sizes and step counts must equal the ``image_pyramid``
    hyper-parameters.
    """
    pyr = paper.image_pyramid()
    assert isinstance(pyr, pyramid.ImagePyramid)

    pyramid_levels = tuple(pyr)
    assert len(pyramid_levels) == 2

    # Safe to collect per-attribute: the length check above guarantees two levels.
    edge_sizes = tuple(level.edge_size for level in pyramid_levels)
    num_steps = tuple(level.num_steps for level in pyramid_levels)
    expected = paper.hyper_parameters().image_pyramid

    with subtests.test("edge_sizes"):
        assert edge_sizes == expected.edge_sizes

    with subtests.test("num_steps"):
        assert num_steps == expected.num_steps
예제 #5
0
def test_hyper_parameters_image_pyramid(subtests, impl_params):
    """Check the ``image_pyramid`` hyper-parameters.

    Edge sizes are ``(512, 1024)`` when ``impl_params`` is truthy and
    ``(500, 1024)`` otherwise. With ``impl_params`` the step counts must be
    exactly ``(500, 200)``. Without it, only the number of levels (two) and
    the ratio of the step counts (2.5) are checked.
    """
    all_params = paper.hyper_parameters(impl_params=impl_params)

    key = "image_pyramid"
    assert key in all_params
    pyramid_params = getattr(all_params, key)

    with subtests.test("edge_sizes"):
        first_edge_size = 512 if impl_params else 500
        assert pyramid_params.edge_sizes == (first_edge_size, 1024)

    with subtests.test("num_steps"):
        if not impl_params:
            assert len(pyramid_params.num_steps) == 2
            first_steps = pyramid_params.num_steps[0]
            second_steps = pyramid_params.num_steps[1]
            assert first_steps / second_steps == pytest.approx(2.5)
        else:
            assert pyramid_params.num_steps == (500, 200)
예제 #6
0
def test_style_loss(subtests):
    """Check that ``paper.style_loss()`` matches the default hyper-parameters.

    The loss must be a ``MultiLayerEncodingLoss`` whose children are all
    ``GramLoss`` instances. The children's encoder layers and score weights,
    and the overall score weight, must equal the ``style_loss``
    hyper-parameters.
    """
    loss = paper.style_loss()
    assert isinstance(loss, pystiche.loss.MultiLayerEncodingLoss)

    expected = paper.hyper_parameters().style_loss

    with subtests.test("losses"):
        for child in loss.children():
            assert isinstance(child, pystiche.loss.GramLoss)

    # Transpose the per-child (layer, score_weight) pairs into two tuples.
    layers, layer_weights = zip(
        *((child.encoder.layer, child.score_weight) for child in loss.children())
    )

    with subtests.test("layers"):
        assert layers == expected.layers

    with subtests.test("layer_weights"):
        assert layer_weights == pytest.approx(expected.layer_weights)

    with subtests.test("score_weight"):
        assert loss.score_weight == pytest.approx(expected.score_weight)
예제 #7
0
def test_hyper_parameters_smoke(impl_params):
    """Smoke test: ``paper.hyper_parameters`` returns a ``HyperParameters``."""
    result = paper.hyper_parameters(impl_params=impl_params)
    assert isinstance(result, HyperParameters)