Example #1
0
 def test_prototypes1d_init_with_int_data(self):
     """Build ``Prototypes1D`` from integer-valued data.

     Nothing is asserted explicitly; the test passes as long as the
     constructor does not raise.
     """
     samples = [[1], [0]]
     labels = [1, 0]
     prototypes.Prototypes1D(
         num_classes=2,
         prototypes_per_class=1,
         prototype_initializer="stratified_mean",
         data=[samples, labels],
     )
Example #2
0
 def test_prototypes1d_init_one_hot_without_data(self):
     """Build ``Prototypes1D`` with ``one_hot_labels=True`` and no data.

     Nothing is asserted explicitly; the test passes as long as the
     constructor does not raise.
     """
     settings = {
         'input_dim': 1,
         'nclasses': 2,
         'prototypes_per_class': 1,
         'prototype_initializer': 'stratified_mean',
         'data': None,
         'one_hot_labels': True,
     }
     prototypes.Prototypes1D(**settings)
Example #3
0
 def test_prototypes1d_inputdim_with_data(self):
     """Expect ``ValueError`` when ``input_dim`` disagrees with the data.

     ``input_dim=2`` is passed together with samples that hold a single
     value each.
     """
     samples = [[1.], [0.]]
     labels = [1, 0]
     self.assertRaises(
         ValueError,
         prototypes.Prototypes1D,
         input_dim=2,
         nclasses=2,
         prototypes_per_class=1,
         prototype_initializer='stratified_mean',
         data=[samples, labels],
     )
Example #4
0
 def test_prototypes1d_init_with_int_dtype(self):
     """Expect ``RuntimeError`` when requesting ``dtype=torch.int32``.

     The module is built from integer-valued data with the
     'stratified_mean' initializer and an integer dtype.
     """
     samples = [[1], [0]]
     labels = [1, 0]
     settings = {
         'nclasses': 2,
         'prototypes_per_class': 1,
         'prototype_initializer': 'stratified_mean',
         'data': [samples, labels],
         'dtype': torch.int32,
     }
     with self.assertRaises(RuntimeError):
         prototypes.Prototypes1D(**settings)
Example #5
0
 def test_prototypes1d_proto_init_without_data(self):
     """Expect a ``UserWarning`` for 'stratified_mean' with ``data=None``."""
     self.assertWarns(
         UserWarning,
         prototypes.Prototypes1D,
         input_dim=3,
         nclasses=2,
         prototypes_per_class=1,
         prototype_initializer='stratified_mean',
         data=None,
     )
Example #6
0
 def test_prototypes1d_forward(self):
     """Check that calling the module returns all-ones prototypes.

     The module is built from the fixture data ``self.x`` / ``self.y``
     with otherwise default settings. The first element returned by the
     call is compared against a ``(2, 3)`` array of ones, to five
     decimals.
     """
     p1 = prototypes.Prototypes1D(data=[self.x, self.y])
     protos, _ = p1()
     actual = protos.detach().numpy()
     desired = torch.ones(2, 3).numpy()
     # assert_array_almost_equal raises AssertionError on a mismatch and
     # returns None otherwise, so its result needs no further assertion.
     np.testing.assert_array_almost_equal(actual, desired, decimal=5)
Example #7
0
 def test_prototypes1d_init_with_pdist(self):
     """Check the prototypes built from an explicit distribution.

     With ``prototype_distribution=[6, 9]`` and the 'zeros' initializer,
     ``p1.prototypes`` is compared against a ``(15, 3)`` array of zeros,
     to five decimals.
     """
     p1 = prototypes.Prototypes1D(data=[self.x, self.y],
                                  prototype_distribution=[6, 9],
                                  prototype_initializer='zeros')
     actual = p1.prototypes.detach().numpy()
     desired = torch.zeros(15, 3).numpy()
     # assert_array_almost_equal raises AssertionError on a mismatch and
     # returns None otherwise, so its result needs no further assertion.
     np.testing.assert_array_almost_equal(actual, desired, decimal=5)
Example #8
0
 def test_prototypes1d_nclasses_with_data(self):
     """Expect ``ValueError`` when ``nclasses`` contradicts the data.

     ``nclasses=1`` is passed while the labels in ``data`` contain two
     distinct values.
     """
     samples = [[1.], [2.]]
     labels = [1, 2]
     settings = {
         'input_dim': 1,
         'nclasses': 1,
         'prototypes_per_class': 1,
         'prototype_initializer': 'stratified_mean',
         'data': [samples, labels],
     }
     with self.assertRaises(ValueError):
         prototypes.Prototypes1D(**settings)
Example #9
0
 def test_prototypes1d_init_with_ppc(self):
     """Check the prototypes built with two prototypes per class.

     With ``prototypes_per_class=2`` and the "zeros" initializer on the
     fixture data, ``p1.prototypes`` is compared against a ``(4, 3)``
     array of zeros, to five decimals.
     """
     p1 = prototypes.Prototypes1D(data=[self.x, self.y],
                                  prototypes_per_class=2,
                                  prototype_initializer="zeros")
     actual = p1.prototypes.detach().numpy()
     desired = torch.zeros(4, 3).numpy()
     # assert_array_almost_equal raises AssertionError on a mismatch and
     # returns None otherwise, so its result needs no further assertion.
     np.testing.assert_array_almost_equal(actual, desired, decimal=5)
Example #10
0
 def test_prototypes1d_init_1d_y_data_one_hot_labels_true(self):
     """Expect ``ValueError`` for flat labels with ``one_hot_labels=True``.

     The labels in ``data`` are a plain 1-D list of integers rather than
     one-hot encoded rows.
     """
     samples = [[0.], [1.]]
     flat_labels = [0, 1]
     self.assertRaises(
         ValueError,
         prototypes.Prototypes1D,
         input_dim=1,
         nclasses=2,
         prototypes_per_class=1,
         prototype_initializer='stratified_mean',
         data=(samples, flat_labels),
         one_hot_labels=True,
     )
Example #11
0
 def test_prototypes1d_init_one_hot_labels_false(self):
     """Expect ``ValueError`` for one-hot labels with ``one_hot_labels=False``.

     The labels in ``data`` are one-hot encoded rows although the flag
     says they are not.
     """
     samples = [[0.], [1.]]
     one_hot = [[0, 1], [1, 0]]
     settings = {
         'input_dim': 1,
         'nclasses': 2,
         'prototypes_per_class': 1,
         'prototype_initializer': 'stratified_mean',
         'data': (samples, one_hot),
         'one_hot_labels': False,
     }
     with self.assertRaises(ValueError):
         prototypes.Prototypes1D(**settings)
Example #12
0
 def test_prototypes1d_init_torch_pdist(self):
     """Check the prototypes built from a tensor-valued distribution.

     With ``prototype_distribution=torch.tensor([2, 2])``,
     ``input_dim=3`` and the 'zeros' initializer, ``p1.prototypes`` is
     compared against a ``(4, 3)`` array of zeros, to five decimals.
     """
     pdist = torch.tensor([2, 2])
     p1 = prototypes.Prototypes1D(input_dim=3,
                                  prototype_distribution=pdist,
                                  prototype_initializer='zeros')
     actual = p1.prototypes.detach().numpy()
     desired = torch.zeros(4, 3).numpy()
     # assert_array_almost_equal raises AssertionError on a mismatch and
     # returns None otherwise, so its result needs no further assertion.
     np.testing.assert_array_almost_equal(actual, desired, decimal=5)
Example #13
0
 def test_prototypes1d_init_without_pdist(self):
     """Check the prototypes built from ``nclasses`` and a per-class count.

     With ``nclasses=2``, ``prototypes_per_class=4``, ``input_dim=6`` and
     the 'ones' initializer, ``p1.prototypes`` is compared against an
     ``(8, 6)`` array of ones, to five decimals.
     """
     p1 = prototypes.Prototypes1D(input_dim=6,
                                  nclasses=2,
                                  prototypes_per_class=4,
                                  prototype_initializer='ones')
     actual = p1.prototypes.detach().numpy()
     desired = torch.ones(8, 6).numpy()
     # assert_array_almost_equal raises AssertionError on a mismatch and
     # returns None otherwise, so its result needs no further assertion.
     np.testing.assert_array_almost_equal(actual, desired, decimal=5)
Example #14
0
 def test_prototypes1d_init_one_hot_labels_true(self):
     """Expect ``ValueError`` for 2-D labels that are not one-hot.

     ``one_hot_labels=True`` is passed, and the targets in ``data`` are
     two-dimensional, but each row holds a single class index instead of
     a one-hot vector.
     """
     samples = [[0.0], [1.0]]
     column_labels = [[0], [1]]
     settings = {
         "input_dim": 1,
         "num_classes": 2,
         "prototypes_per_class": 1,
         "prototype_initializer": "stratified_mean",
         "data": (samples, column_labels),
         "one_hot_labels": True,
     }
     with self.assertRaises(ValueError):
         prototypes.Prototypes1D(**settings)
Example #15
0
    def test_prototypes1d_func_initializer(self):
        """Check that a callable can be used as ``prototype_initializer``.

        The custom initializer ignores its arguments and returns a
        ``(2, 99)`` tensor filled with 99.0 plus the labels ``[0, 1]``.
        ``p1.prototypes`` is then compared against a ``(2, 99)`` array of
        99s, to five decimals.
        """
        def my_initializer(*args, **kwargs):
            return torch.full((2, 99), 99.0), torch.tensor([0, 1])

        p1 = prototypes.Prototypes1D(input_dim=99,
                                     nclasses=2,
                                     prototypes_per_class=1,
                                     prototype_initializer=my_initializer)
        actual = p1.prototypes.detach().numpy()
        desired = (99 * torch.ones(2, 99)).numpy()
        # assert_array_almost_equal raises AssertionError on a mismatch and
        # returns None otherwise, so its result needs no further assertion.
        np.testing.assert_array_almost_equal(actual, desired, decimal=5)
Example #16
0
 def test_prototypes1d_init_with_int_data(self):
     """Build ``Prototypes1D`` from a single integer-valued sample.

     Nothing is asserted explicitly; the test passes as long as the
     constructor does not raise.
     """
     samples = [[1]]
     labels = [1]
     prototypes.Prototypes1D(
         nclasses=1,
         prototypes_per_class=1,
         prototype_initializer='stratified_mean',
         data=[samples, labels],
     )
Example #17
0
 def test_prototypes1d_inputndim_with_data(self):
     """Expect ``ValueError`` when the inputs in ``data`` are a flat list.

     The samples are given as ``[1.]`` (one level of nesting) rather than
     as a list of per-sample lists.
     """
     flat_samples = [1.]
     labels = [1]
     settings = {
         'input_dim': 1,
         'nclasses': 1,
         'prototypes_per_class': 1,
         'data': [flat_samples, labels],
     }
     with self.assertRaises(ValueError):
         prototypes.Prototypes1D(**settings)
Example #18
0
 def test_prototypes1d_init_without_inputdim_with_data(self):
     """Build ``Prototypes1D`` from data without passing ``input_dim``.

     Nothing is asserted explicitly; the test passes as long as the
     constructor does not raise.
     """
     samples = [[1.], [0.]]
     labels = [1, 0]
     prototypes.Prototypes1D(
         nclasses=2,
         prototypes_per_class=1,
         prototype_initializer='stratified_mean',
         data=[samples, labels],
     )
Example #19
0
 def test_prototypes1d_init_with_nclasses_1(self):
     """Expect a ``UserWarning`` when constructing with ``nclasses=1``."""
     self.assertWarns(
         UserWarning,
         prototypes.Prototypes1D,
         nclasses=1,
         input_dim=1,
     )
Example #20
0
 def test_prototypes1d_validate_extra_repr_not_empty(self):
     """Check that ``extra_repr()`` returns something other than ''."""
     module = prototypes.Prototypes1D(
         input_dim=0,
         prototype_distribution=[0],
     )
     description = module.extra_repr()
     self.assertNotEqual(description, '')
Example #21
0
 def test_prototypes1d_dist_validate(self):
     """Expect a ``UserWarning`` from ``_validate_prototype_distribution``.

     The module is constructed first, outside the warning check, with
     ``input_dim=0`` and ``prototype_distribution=[0]``; only the explicit
     validation call is required to warn.
     """
     module = prototypes.Prototypes1D(
         input_dim=0,
         prototype_distribution=[0],
     )
     self.assertWarns(UserWarning, module._validate_prototype_distribution)
Example #22
0
 def test_prototypes1d_init_without_nclasses(self):
     """Expect ``NameError`` when only ``input_dim`` is supplied."""
     self.assertRaises(NameError, prototypes.Prototypes1D, input_dim=1)