Пример #1
0
 def test_triplet_diffs_toy(self):
     """Check the bases generated from a collinear toy dataset.

     All four points lie on the same line, so the only possible basis is
     the normalized vector along that line, (1/sqrt(2), 1/sqrt(2)). The
     test requests 10 bases and expects every row to equal that vector.
     """
     wanted = 10
     estimator = SCML_Supervised(n_basis=wanted)
     points = np.array([[0, 0], [1, 1], [2, 2], [3, 3]])
     triplet_idx = np.array([
         [0, 1, 2], [0, 1, 3], [1, 0, 2], [1, 0, 3],
         [2, 3, 1], [2, 3, 0], [3, 2, 1], [3, 2, 0],
     ])
     bases, got = estimator._generate_bases_dist_diff(triplet_idx, points)

     # One identical unit row per requested basis.
     unit_diag = np.full((wanted, 2), 1 / np.sqrt(2))
     assert got == wanted
     np.testing.assert_allclose(bases, unit_diag)
Пример #2
0
  def test_triplet_diffs(self, n_samples, n_features, n_classes):
    """Check the size of the basis set built from kNN triplet differences.

    Using default ``SCML_Supervised`` settings on a standardized synthetic
    classification dataset, the test expects ``80 * n_features`` bases,
    each with ``n_features`` components.
    """
    data, labels = make_classification(
        n_samples=n_samples, n_classes=n_classes, n_features=n_features,
        n_informative=n_features, n_redundant=0, n_repeated=0)
    data = StandardScaler().fit_transform(data)

    scml = SCML_Supervised()
    triplets = Constraints(labels).generate_knntriplets(
        data, scml.k_genuine, scml.k_impostor)
    bases, n_bases = scml._generate_bases_dist_diff(triplets, data)

    target = 80 * n_features
    assert n_bases == target
    assert bases.shape == (target, n_features)
Пример #3
0
  def test_triplet_diffs(self, n_samples, n_features, n_classes):
    """
    Test that, with ``n_basis=None``, ``_generate_bases_dist_diff`` emits a
    UserWarning saying no ``n_basis`` was selected, and produces
    ``80 * n_features`` bases of dimension ``n_features`` from kNN triplet
    constraints on a standardized synthetic classification dataset.
    """
    X, y = make_classification(n_samples=n_samples, n_classes=n_classes,
                               n_features=n_features, n_informative=n_features,
                               n_redundant=0, n_repeated=0)
    X = StandardScaler().fit_transform(X)
    model = SCML_Supervised(n_basis=None)  # Explicit n_basis=None
    constraints = Constraints(y)
    triplets = constraints.generate_knntriplets(X, model.k_genuine,
                                                model.k_impostor)

    msg = "As no value for `n_basis` was selected, "
    with pytest.warns(UserWarning) as raised_warning:
      basis, n_basis = model._generate_bases_dist_diff(triplets, X)
    # pytest.warns records every warning raised in the block, of any
    # category. Look through all of them instead of only the first, so an
    # unrelated warning recorded earlier cannot cause a false failure.
    assert any(issubclass(w.category, UserWarning) and msg in str(w.message)
               for w in raised_warning), (
        "expected a UserWarning containing {!r}".format(msg))

    expected_n_basis = n_features * 80
    assert n_basis == expected_n_basis
    assert basis.shape == (expected_n_basis, n_features)