示例#1
0
def test_compare_random_converges(target, source):
    """Check that repeated random comparisons reach the matrix comparison result.

    One matcher is solved once with ``compare_features_matrix``. A second
    matcher is refined with ``compare_features_random`` for at most 500 rounds,
    stopping early as soon as its indices agree everywhere with the first.
    """
    exhaustive = FeatureMatcher(target, source)
    exhaustive.compare_features_matrix(split=2)

    randomized = FeatureMatcher(target, source)
    for _ in range(500):
        randomized.compare_features_random(split=2)
        target_diff = (
            exhaustive.repro_target.indices != randomized.repro_target.indices
        ).sum()
        sources_diff = (
            exhaustive.repro_sources.indices != randomized.repro_sources.indices
        ).sum()
        # Both mappings agree everywhere: no need for further random rounds.
        if target_diff + sources_diff == 0:
            break

    # Target mapping: at most two differing index entries, scores identical
    # up to an absolute tolerance of 1e-6.
    target_mismatches = (
        exhaustive.repro_target.indices != randomized.repro_target.indices
    ).sum()
    assert target_mismatches <= 2
    target_distance = torch.dist(
        exhaustive.repro_target.scores, randomized.repro_target.scores
    )
    assert pytest.approx(0.0, abs=1e-6) == target_distance

    # Source mapping: same checks, plus the smallest index must not be -1.
    assert randomized.repro_sources.indices.min() != -1
    sources_mismatches = (
        exhaustive.repro_sources.indices != randomized.repro_sources.indices
    ).sum()
    assert sources_mismatches <= 2
    sources_distance = torch.dist(
        exhaustive.repro_sources.scores, randomized.repro_sources.scores
    )
    assert pytest.approx(0.0, abs=1e-6) == sources_distance
示例#2
0
def test_scores_zero(content, style):
    """Scores must be zero when the inputs vary on different dimensions.

    Index 0 along dim 1 of ``content`` and index 1 along dim 1 of ``style``
    are zeroed in place before matching.
    NOTE(review): presumably both fixtures have exactly two entries along
    dim 1, so the remaining ones do not overlap -- confirm against the fixtures.
    """
    content[:, 0] = 0.0
    style[:, 1] = 0.0

    matcher = FeatureMatcher(content, style)
    matcher.compare_features_matrix(split=2)

    # If the largest score is zero, no score is above zero in either mapping.
    for mapping in (matcher.repro_target, matcher.repro_sources):
        assert pytest.approx(0.0) == mapping.scores.max()
示例#3
0
def test_scores_one(content, style):
    """Scores must be one when the inputs only vary on one dimension.

    Index 0 along dim 1 is zeroed in place in both ``content`` and ``style``
    before matching.
    NOTE(review): presumably both fixtures have exactly two entries along
    dim 1, leaving a single varying one -- confirm against the fixtures.
    """
    content[:, 0] = 0.0
    style[:, 0] = 0.0

    matcher = FeatureMatcher(content, style)
    matcher.compare_features_matrix(split=2)

    # If the smallest score is one, no score is below one in either mapping.
    for mapping in (matcher.repro_target, matcher.repro_sources):
        assert pytest.approx(1.0) == mapping.scores.min()
示例#4
0
def test_scores_range_matrix(target, source):
    """Check that scores from the matrix comparison lie within [0, 1].

    The bounds are verified for both the target and the source mapping.
    """
    matcher = FeatureMatcher(target, source)
    matcher.compare_features_matrix(split=2)

    for mapping in (matcher.repro_target, matcher.repro_sources):
        assert mapping.scores.min() >= 0.0
        assert mapping.scores.max() <= 1.0
示例#5
0
def test_scores_identity(array):
    """Matching an array against itself must give a score of one everywhere.

    Both mappings are initialised with ``from_linear`` using the array's shape
    before the matrix comparison is run.
    """
    matcher = FeatureMatcher(array, array)
    for mapping in (matcher.repro_target, matcher.repro_sources):
        mapping.from_linear(array.shape)
    matcher.compare_features_matrix(split=2)

    # If the smallest score is one, no location scores below one.
    lowest_target = matcher.repro_target.scores.min()
    assert pytest.approx(1.0) == lowest_target
    lowest_sources = matcher.repro_sources.scores.min()
    assert pytest.approx(1.0) == lowest_sources
示例#6
0
def test_indices_symmetry_matrix(content, style):
    """Swapping the two inputs must swap the two index mappings.

    The target indices of one matcher are compared with the source indices of
    the matcher built with swapped arguments, and vice versa. At most two
    differing index entries are tolerated per comparison.
    """
    forward = FeatureMatcher(content, style)
    backward = FeatureMatcher(style, content)

    forward.compare_features_matrix(split=2)
    backward.compare_features_matrix(split=2)

    target_vs_sources = (
        forward.repro_target.indices != backward.repro_sources.indices
    ).sum()
    assert target_vs_sources <= 2

    sources_vs_target = (
        forward.repro_sources.indices != backward.repro_target.indices
    ).sum()
    assert sources_vs_target <= 2
示例#7
0
def test_scores_target_bias_matrix(array):
    """Biases on the source mapping must steer which part of the target is chosen.

    The target is ``array`` concatenated with itself along dim 2. Setting the
    bias to one on one side of the boundary must make every source index
    (coordinate 0) fall on that side.
    NOTE(review): the boundary 11 presumably relates to the fixture's size
    along dim 2 -- confirm against the ``array`` fixture.
    """
    boundary = 11
    doubled = torch.cat([array, array], dim=2)
    matcher = FeatureMatcher(doubled, array)

    # Favour positions at or beyond the boundary.
    matcher.repro_sources.biases[:, :, boundary:] = 1.0
    matcher.repro_sources.scores.zero_()
    matcher.compare_features_matrix(split=2)
    assert (matcher.repro_sources.indices[:, 0] >= boundary).all()

    # Flip the bias so that positions before the boundary are favoured instead.
    matcher.repro_sources.biases[:, :, boundary:] = 0.0
    matcher.repro_sources.biases[:, :, :boundary] = 1.0
    matcher.repro_sources.scores.zero_()
    matcher.compare_features_matrix(split=2)
    assert (matcher.repro_sources.indices[:, 0] < boundary).all()
示例#8
0
def test_scores_source_bias_matrix(array):
    """Biases on the target mapping must steer which part of the source is chosen.

    The source is ``array`` concatenated with itself along dim 2. Setting the
    bias to one on one side of the boundary must make every target index
    (coordinate 0) fall on that side.
    NOTE(review): the boundary 9 presumably relates to the fixture's size
    along dim 2 -- confirm against the ``array`` fixture.
    """
    boundary = 9
    doubled = torch.cat([array, array], dim=2)
    matcher = FeatureMatcher(array, doubled)

    # Favour positions at or beyond the boundary.
    matcher.repro_target.biases[:, :, boundary:] = 1.0
    matcher.repro_target.scores.zero_()
    matcher.compare_features_matrix(split=2)
    assert (matcher.repro_target.indices[:, 0] >= boundary).all()

    # Flip the bias so that positions before the boundary are favoured instead.
    matcher.repro_target.biases[:, :, boundary:] = 0.0
    matcher.repro_target.biases[:, :, :boundary] = 1.0
    matcher.repro_target.scores.zero_()
    matcher.compare_features_matrix(split=2)
    assert (matcher.repro_target.indices[:, 0] < boundary).all()
示例#9
0
def test_indices_same_rotate(content, style):
    """Transposing the source must swap the two coordinates of the target indices.

    The same content is matched against ``style`` and against ``style`` with
    its last two axes swapped. Coordinate 0 of the first result is compared
    with coordinate 1 of the second, and coordinate 1 with coordinate 0. At
    most one differing entry is tolerated per comparison.
    """
    matcher1 = FeatureMatcher(content, style)
    matcher1.compare_features_matrix(split=2)

    # permute(0, 1, 3, 2) swaps the last two axes of the source.
    matcher2 = FeatureMatcher(content, style.permute(0, 1, 3, 2))
    matcher2.compare_features_matrix(split=2)

    original = matcher1.repro_target.indices
    transposed = matcher2.repro_target.indices

    assert (original[:, 0] != transposed[:, 1]).sum() <= 1
    # The original second assertion repeated the first with its operands
    # swapped; check the other coordinate pair so both axes are covered.
    assert (original[:, 1] != transposed[:, 0]).sum() <= 1
示例#10
0
def test_indices_same_split(content, style):
    """The matrix comparison must give the same indices for every split value.

    Indices obtained with ``split=1`` serve as the baseline. The comparison is
    then repeated with splits 2, 4 and 8, calling ``update_target`` first each
    time. At most two differing index entries are tolerated per mapping.
    """
    matcher = FeatureMatcher(content, style)
    matcher.compare_features_matrix(split=1)

    # Clone the baseline so later comparisons cannot overwrite it.
    baseline_target = matcher.repro_target.indices.clone()
    baseline_sources = matcher.repro_sources.indices.clone()

    for pieces in (2, 4, 8):
        matcher.update_target(content)
        matcher.compare_features_matrix(split=pieces)

        target_changes = (baseline_target != matcher.repro_target.indices).sum()
        assert target_changes <= 2
        sources_changes = (baseline_sources != matcher.repro_sources.indices).sum()
        assert sources_changes <= 2
示例#11
0
def test_scores_reconstruct(target, source):
    """Mean matcher scores must equal the mean similarity of the reconstructions.

    After a matrix comparison, the target and the source are reconstructed and
    compared with the originals through ``cosine_similarity_vector_1d``. The
    mean of that similarity must match the mean stored score to within 1e-6.
    """
    matcher = FeatureMatcher(target, source)
    matcher.compare_features_matrix()

    # Target side.
    rebuilt_target = matcher.reconstruct_target()
    similarity = cosine_similarity_vector_1d(target, rebuilt_target)
    target_gap = abs(similarity.mean() - matcher.repro_target.scores.mean())
    assert pytest.approx(0.0, abs=1e-6) == target_gap

    # Source side.
    rebuilt_source = matcher.reconstruct_source()
    similarity = cosine_similarity_vector_1d(source, rebuilt_source)
    source_gap = abs(similarity.mean() - matcher.repro_sources.scores.mean())
    assert pytest.approx(0.0, abs=1e-6) == source_gap
示例#12
0
def test_nearest_neighbor_vs_matcher(content, style):
    """The matrix matcher must agree exactly with ``nearest_neighbors_1d``.

    The matcher's two index coordinates are converted to a linear index and
    compared with the indices that ``nearest_neighbors_1d`` returns for the
    inputs flattened from dim 2 onwards. No mismatch is tolerated.
    """
    matcher = FeatureMatcher(content, style)
    matcher.compare_features_matrix(split=1)

    ida, idb = nearest_neighbors_1d(content.flatten(2),
                                    style.flatten(2),
                                    split=1)

    # flatten(2) is row-major, so the linear index of (row, column) is
    # row * width + column, and the width is shape[3]. The original used
    # shape[2], which only gives the right stride for square inputs.
    # NOTE(review): assumes indices[:, 0] addresses dim 2 and indices[:, 1]
    # addresses dim 3 of the matched array -- confirm against FeatureMatcher.
    ima = (matcher.repro_target.indices[:, 0] * style.shape[3] +
           matcher.repro_target.indices[:, 1])
    assert (ima.flatten(1) != ida).sum() == 0

    imb = (matcher.repro_sources.indices[:, 0] * content.shape[3] +
           matcher.repro_sources.indices[:, 1])
    assert (imb.flatten(1) != idb).sum() == 0