Exemplo n.º 1
0
def test_compare_random_converges(target, source):
    """Check that repeated random search converges to the exhaustive matrix search.

    ``matcher1`` is filled by ``compare_features_matrix`` while ``matcher2``
    runs ``compare_features_random`` for up to 500 rounds, stopping early once
    both its target and source indices agree exactly with ``matcher1``.

    Afterwards the scores must match to within 1e-6 (Euclidean distance), and
    at most two indices per direction may still differ.  Presumably the small
    index slack allows equally-scored candidates to resolve differently.
    """
    matcher1 = FeatureMatcher(target, source)
    matcher1.compare_features_matrix(split=2)

    matcher2 = FeatureMatcher(target, source)

    def mismatches():
        # Count of differing (target, sources) indices between the two matchers.
        return (
            (matcher1.repro_target.indices != matcher2.repro_target.indices).sum(),
            (matcher1.repro_sources.indices != matcher2.repro_sources.indices).sum(),
        )

    for _ in range(500):
        matcher2.compare_features_random(split=2)
        target_missing, sources_missing = mismatches()
        if target_missing + sources_missing == 0:
            break

    target_missing, sources_missing = mismatches()

    assert target_missing <= 2
    # `.item()` yields a plain float so `pytest.approx` never has to convert
    # the tensor itself (which fails for CUDA tensors or tensors requiring grad).
    target_dist = torch.dist(
        matcher1.repro_target.scores, matcher2.repro_target.scores
    ).item()
    assert target_dist == pytest.approx(0.0, abs=1e-6)

    # Every source patch must have received a match (-1 marks "no match").
    assert matcher2.repro_sources.indices.min() != -1
    assert sources_missing <= 2
    sources_dist = torch.dist(
        matcher1.repro_sources.scores, matcher2.repro_sources.scores
    ).item()
    assert sources_dist == pytest.approx(0.0, abs=1e-6)
Exemplo n.º 2
0
def test_scores_improve(content, style):
    """A round of random search must never lower the total target score.

    Matches are first initialised with ``compare_features_identity``.  One
    round of ``compare_features_random`` is then run, and the summed target
    scores before and after are compared.  The ``event`` call (presumably
    ``hypothesis.event``) records whether the total stayed exactly equal, so
    the test statistics show how often random search actually improved it.
    """
    matcher = FeatureMatcher(content, style)
    matcher.compare_features_identity()
    before = matcher.repro_target.scores.sum()

    matcher.compare_features_random(times=1)
    after = matcher.repro_target.scores.sum()

    event(f"equal? {int(after == before)}")
    assert after >= before
Exemplo n.º 3
0
def test_scores_range_random(target, source):
    """Determine that the scores of random patches are in correct range.

    After one round of ``compare_features_random``, every source patch must
    have a match (no ``-1`` index).  The target and source scores are both
    expected to lie in [0, 1].
    """
    matcher = FeatureMatcher(target, source)
    matcher.compare_features_random(split=1)

    # -1 marks an unmatched patch; none may remain on the source side.
    assert matcher.repro_sources.indices.min() != -1

    assert matcher.repro_target.scores.min() >= 0.0
    assert matcher.repro_target.scores.max() <= 1.0

    # Source scores are held to the same range as the target scores.  The
    # previous `max() >= 0.0` check did not bound the range at all.
    assert matcher.repro_sources.scores.min() >= 0.0
    assert matcher.repro_sources.scores.max() <= 1.0
Exemplo n.º 4
0
def test_scores_target_bias_random(array):
    """Biases on ``repro_sources`` must steer where random search places matches.

    The matcher's target is ``array`` concatenated with itself along dim 2,
    and its source is ``array``.

    Phase 1: ``biases[:, :, 12:]`` is set to 1.0.  After ten rounds of random
    search, every ``indices[:, 0]`` must be at least 12.

    Phase 2: the bias is moved to ``[:, :, :12]``.  After ten more rounds,
    every ``indices[:, 0]`` must be below 12.

    Scores are reset to -1.0 before each phase, presumably so that matches
    found earlier can be replaced.
    """
    matcher = FeatureMatcher(torch.cat([array, array], dim=2), array)
    boundary = 12

    def search_with_fresh_scores():
        # Wipe scores, then run the random search for a fixed number of rounds.
        matcher.repro_sources.scores.fill_(-1.0)
        for _round in range(10):
            matcher.compare_features_random(split=2)

    # Phase 1: favour the region at/after the boundary.
    matcher.repro_sources.biases[:, :, boundary:] = 1.0
    search_with_fresh_scores()
    assert (matcher.repro_sources.indices[:, 0] >= boundary).all()

    # Phase 2: move the bias to the region before the boundary.
    matcher.repro_sources.biases[:, :, boundary:] = 0.0
    matcher.repro_sources.biases[:, :, :boundary] = 1.0
    search_with_fresh_scores()
    assert (matcher.repro_sources.indices[:, 0] < boundary).all()
Exemplo n.º 5
0
def test_scores_source_bias_random(array):
    """Biases on ``repro_target`` must steer where random search places matches.

    The matcher's target is ``array``, and its source is ``array``
    concatenated with itself along dim 2.

    Phase 1: ``biases[:, :, 8:]`` is set to 1.0.  After ten rounds of random
    search, every ``indices[:, 0]`` must be at least 8.

    Phase 2: the bias is moved to ``[:, :, :8]``.  After ten more rounds,
    every ``indices[:, 0]`` must be below 8.

    Scores are reset to -1.0 before each phase, presumably so that matches
    found earlier can be replaced.
    """
    matcher = FeatureMatcher(array, torch.cat([array, array], dim=2))
    boundary = 8

    def search_with_fresh_scores():
        # Wipe scores, then run the random search for a fixed number of rounds.
        matcher.repro_target.scores.fill_(-1.0)
        for _round in range(10):
            matcher.compare_features_random(split=2)

    # Phase 1: favour the region at/after the boundary.
    matcher.repro_target.biases[:, :, boundary:] = 1.0
    search_with_fresh_scores()
    assert (matcher.repro_target.indices[:, 0] >= boundary).all()

    # Phase 2: move the bias to the region before the boundary.
    matcher.repro_target.biases[:, :, boundary:] = 0.0
    matcher.repro_target.biases[:, :, :boundary] = 1.0
    search_with_fresh_scores()
    assert (matcher.repro_target.indices[:, 0] < boundary).all()
Exemplo n.º 6
0
def test_indices_symmetry_random(content, style):
    """The indices of the symmetrical operation must be the same.

    Matching ``content`` against ``style`` and ``style`` against ``content``
    should give mirrored results: one matcher's target indices should equal
    the other's source indices, and vice versa.

    Random search is repeated for up to 25 rounds, stopping early once both
    directions agree exactly.  At most two mismatching indices per direction
    are tolerated afterwards.
    """
    matcher1 = FeatureMatcher(content, style)
    matcher2 = FeatureMatcher(style, content)

    def mismatches():
        # Disagreements in each mirrored direction between the two matchers.
        return (
            (matcher1.repro_target.indices != matcher2.repro_sources.indices).sum(),
            (matcher1.repro_sources.indices != matcher2.repro_target.indices).sum(),
        )

    for _ in range(25):
        matcher1.compare_features_random()
        matcher2.compare_features_random()

        forward, backward = mismatches()
        if forward + backward == 0:
            break

    forward, backward = mismatches()
    assert forward <= 2
    assert backward <= 2