def test_compare_random_converges(target, source):
    """Check that repeated random comparison converges to the matrix result.

    A reference matcher is resolved with one ``compare_features_matrix`` pass.
    A second matcher is refined with up to 500 ``compare_features_random``
    passes, stopping early as soon as its indices agree with the reference in
    both directions.  Afterwards at most two index entries may differ per
    direction, and the ``torch.dist`` distance between the score tensors must
    be approximately zero (``abs=1e-6``).
    """

    def count_mismatches(expected, actual):
        # Number of index entries on which the two mappings disagree.
        return (expected.indices != actual.indices).sum()

    reference = FeatureMatcher(target, source)
    reference.compare_features_matrix(split=2)

    candidate = FeatureMatcher(target, source)
    for _ in range(500):
        candidate.compare_features_random(split=2)
        target_missing = count_mismatches(
            reference.repro_target, candidate.repro_target
        )
        sources_missing = count_mismatches(
            reference.repro_sources, candidate.repro_sources
        )
        missing = target_missing + sources_missing
        if missing == 0:
            # Fully converged, further random passes are unnecessary.
            break

    # Target direction: indices nearly identical, scores equal.
    assert count_mismatches(reference.repro_target, candidate.repro_target) <= 2
    assert pytest.approx(0.0, abs=1e-6) == torch.dist(
        reference.repro_target.scores, candidate.repro_target.scores
    )

    # Sources direction: same checks, plus no index may be left at -1
    # (presumably the "unassigned" marker -- TODO confirm against FeatureMatcher).
    assert candidate.repro_sources.indices.min() != -1
    assert count_mismatches(reference.repro_sources, candidate.repro_sources) <= 2
    assert pytest.approx(0.0, abs=1e-6) == torch.dist(
        reference.repro_sources.scores, candidate.repro_sources.scores
    )
def test_scores_zero(content, style):
    """Scores must be zero when the inputs vary on different dimensions.

    Slice 0 along the second axis of ``content`` and slice 1 along the second
    axis of ``style`` are zeroed in place before matching, and the largest
    score in both directions is then expected to be approximately zero.
    """
    # Zero out a different slice in each input (mutates the arguments).
    content[:, 0] = 0.0
    style[:, 1] = 0.0

    matcher = FeatureMatcher(content, style)
    matcher.compare_features_matrix(split=2)

    # If even the maximum is zero, every score is zero.
    for repro in (matcher.repro_target, matcher.repro_sources):
        assert pytest.approx(0.0) == repro.scores.max()
def test_scores_one(content, style):
    """Scores must be one when the inputs only vary on one dimension.

    Slice 0 along the second axis is zeroed in place in both ``content`` and
    ``style`` before matching, and the smallest score in both directions is
    then expected to be approximately one.
    """
    # Zero out the same slice in both inputs (mutates the arguments).
    content[:, 0] = 0.0
    style[:, 0] = 0.0

    matcher = FeatureMatcher(content, style)
    matcher.compare_features_matrix(split=2)

    # If even the minimum is one, every score is one.
    for repro in (matcher.repro_target, matcher.repro_sources):
        assert pytest.approx(1.0) == repro.scores.min()
def test_scores_range_matrix(target, source):
    """Scores produced by the matrix comparison must lie within [0, 1].

    Both the target and the sources mapping are checked after a single
    ``compare_features_matrix`` pass.
    """
    matcher = FeatureMatcher(target, source)
    matcher.compare_features_matrix(split=2)

    for repro in (matcher.repro_target, matcher.repro_sources):
        assert repro.scores.min() >= 0.0
        assert repro.scores.max() <= 1.0
def test_scores_identity(array):
    """Matching an array against itself from linear indices must score one.

    Both mappings are initialised with ``from_linear(array.shape)`` before the
    matrix comparison runs, and the smallest score in each direction is
    expected to be approximately one.
    """
    matcher = FeatureMatcher(array, array)

    # Seed both directions with linear indices before comparing.
    for repro in (matcher.repro_target, matcher.repro_sources):
        repro.from_linear(array.shape)
    matcher.compare_features_matrix(split=2)

    for repro in (matcher.repro_target, matcher.repro_sources):
        assert pytest.approx(1.0) == repro.scores.min()
def test_indices_symmetry_matrix(content, style):
    """The indices of the symmetrical operation must be equal.

    Swapping the two inputs of the matcher must swap the roles of the target
    and sources mappings; at most two index entries may differ per mapping.
    """
    forward = FeatureMatcher(content, style)
    backward = FeatureMatcher(style, content)
    forward.compare_features_matrix(split=2)
    backward.compare_features_matrix(split=2)

    # Forward target mapping vs. backward sources mapping.
    target_diff = forward.repro_target.indices != backward.repro_sources.indices
    assert target_diff.sum() <= 2

    # Forward sources mapping vs. backward target mapping.
    sources_diff = forward.repro_sources.indices != backward.repro_target.indices
    assert sources_diff.sum() <= 2
def test_scores_target_bias_matrix(array):
    """Biases on the sources mapping must steer which target rows are chosen.

    The target is ``array`` concatenated with itself along dim 2, so (given
    identical halves) only the bias can favour one half over the other.  With
    a bias of 1.0 on positions ``11:`` every first index component must be
    ``>= 11``; after flipping the bias to positions ``:11`` every one must be
    ``< 11``.
    """
    # NOTE(review): 11 is hard-coded; presumably the fixture's extent along
    # dim 2 -- confirm against the fixture definition.
    matcher = FeatureMatcher(torch.cat([array, array], dim=2), array)

    def rerun_matching():
        # Reset the scores in place, then run the comparison again.
        matcher.repro_sources.scores.zero_()
        matcher.compare_features_matrix(split=2)

    # Favour the upper index range.
    matcher.repro_sources.biases[:, :, 11:] = 1.0
    rerun_matching()
    assert (matcher.repro_sources.indices[:, 0] >= 11).all()

    # Flip the bias to the lower index range.
    matcher.repro_sources.biases[:, :, 11:] = 0.0
    matcher.repro_sources.biases[:, :, :11] = 1.0
    rerun_matching()
    assert (matcher.repro_sources.indices[:, 0] < 11).all()
def test_scores_source_bias_matrix(array):
    """Biases on the target mapping must steer which source rows are chosen.

    The source is ``array`` concatenated with itself along dim 2, so (given
    identical halves) only the bias can favour one half over the other.  With
    a bias of 1.0 on positions ``9:`` every first index component must be
    ``>= 9``; after flipping the bias to positions ``:9`` every one must be
    ``< 9``.
    """
    # NOTE(review): 9 is hard-coded; presumably the fixture's extent along
    # dim 2 -- confirm against the fixture definition.
    matcher = FeatureMatcher(array, torch.cat([array, array], dim=2))

    def rerun_matching():
        # Reset the scores in place, then run the comparison again.
        matcher.repro_target.scores.zero_()
        matcher.compare_features_matrix(split=2)

    # Favour the upper index range.
    matcher.repro_target.biases[:, :, 9:] = 1.0
    rerun_matching()
    assert (matcher.repro_target.indices[:, 0] >= 9).all()

    # Flip the bias to the lower index range.
    matcher.repro_target.biases[:, :, 9:] = 0.0
    matcher.repro_target.biases[:, :, :9] = 1.0
    rerun_matching()
    assert (matcher.repro_target.indices[:, 0] < 9).all()
def test_indices_same_rotate(content, style):
    """Matching against a transposed source must yield transposed indices.

    The second matcher receives ``style`` with its last two axes swapped, so
    component 0 of the indices found by the first matcher must correspond to
    component 1 found by the second, and vice versa.  At most one entry may
    differ per component.
    """
    matcher1 = FeatureMatcher(content, style)
    matcher1.compare_features_matrix(split=2)

    matcher2 = FeatureMatcher(content, style.permute(0, 1, 3, 2))
    matcher2.compare_features_matrix(split=2)

    plain = matcher1.repro_target.indices
    swapped = matcher2.repro_target.indices

    # Both coordinate pairings are checked.  The previous second assertion was
    # the first one with its operands swapped, so the other component was
    # never compared.
    assert (plain[:, 0] != swapped[:, 1]).sum() <= 1
    assert (plain[:, 1] != swapped[:, 0]).sum() <= 1
def test_indices_same_split(content, style):
    """The matrix comparison must give the same indices for any ``split``.

    The indices obtained with ``split=1`` are kept as a baseline.  For each of
    the splits 2, 4 and 8 the target is refreshed via ``update_target`` and the
    comparison is rerun; at most two index entries per mapping may differ from
    the baseline.
    """
    matcher = FeatureMatcher(content, style)
    matcher.compare_features_matrix(split=1)

    # Clone so that later passes cannot modify the baseline.
    baseline_target = matcher.repro_target.indices.clone()
    baseline_sources = matcher.repro_sources.indices.clone()

    for split in (2, 4, 8):
        matcher.update_target(content)
        matcher.compare_features_matrix(split=split)

        target_changed = baseline_target != matcher.repro_target.indices
        assert target_changed.sum() <= 2
        sources_changed = baseline_sources != matcher.repro_sources.indices
        assert sources_changed.sum() <= 2
def test_scores_reconstruct(target, source):
    """Stored scores must match the similarity of the reconstructions.

    After the matrix comparison, the mean of ``cosine_similarity_vector_1d``
    between each input and its reconstruction must equal the mean of the
    corresponding stored scores within ``1e-6``.
    """
    matcher = FeatureMatcher(target, source)
    matcher.compare_features_matrix()

    # Target direction.
    rebuilt_target = matcher.reconstruct_target()
    target_similarity = cosine_similarity_vector_1d(target, rebuilt_target)
    target_gap = abs(
        target_similarity.mean() - matcher.repro_target.scores.mean()
    )
    assert pytest.approx(0.0, abs=1e-6) == target_gap

    # Sources direction.
    rebuilt_source = matcher.reconstruct_source()
    source_similarity = cosine_similarity_vector_1d(source, rebuilt_source)
    source_gap = abs(
        source_similarity.mean() - matcher.repro_sources.scores.mean()
    )
    assert pytest.approx(0.0, abs=1e-6) == source_gap
def test_nearest_neighbor_vs_matcher(content, style):
    """The matcher must agree exactly with ``nearest_neighbors_1d``.

    The two-component indices stored by the matcher are folded into linear
    indices and compared against the result of running
    ``nearest_neighbors_1d`` on the inputs flattened from dim 2 onwards; no
    entry may differ in either direction.
    """
    matcher = FeatureMatcher(content, style)
    matcher.compare_features_matrix(split=1)

    ida, idb = nearest_neighbors_1d(content.flatten(2), style.flatten(2), split=1)

    # NOTE(review): the stride applied to component 0 is ``shape[2]``; that
    # equals the row length of the flattened layout only when the last two
    # axes have the same size -- confirm the fixtures are square.
    target_idx = matcher.repro_target.indices
    ima = target_idx[:, 0] * style.shape[2] + target_idx[:, 1]
    assert (ima.flatten(1) != ida).sum() == 0

    sources_idx = matcher.repro_sources.indices
    imb = sources_idx[:, 0] * content.shape[2] + sources_idx[:, 1]
    assert (imb.flatten(1) != idb).sum() == 0