def test_dtype(self): oh_ary = onehot(ary=np.array([0, 1, 2]), dtype=np.int32) expect = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) self.assertTrue(np.array_equal(oh_ary, expect)) self.assertTrue(oh_ary.dtype == np.int32)
def test_skiplabel(self): oh_ary = onehot(ary=np.array([0, 1, 2, 3, 5])) expect = np.array([[1., 0., 0., 0., 0., 0.], [0., 1., 0., 0., 0., 0.], [0., 0., 1., 0., 0., 0.], [0., 0., 0., 1., 0., 0.], [0., 0., 0., 0., 0., 1.]]) self.assertTrue(np.array_equal(oh_ary, expect))
def test_defaults(self): oh_ary = onehot(ary=np.array([0, 1, 2, 3])) expect = np.array([[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]]) self.assertTrue(np.array_equal(oh_ary, expect)) self.assertTrue(oh_ary.dtype == np.float32)
def test_n_classes(self): oh_ary = onehot(ary=np.array([0, 1, 2]), n_classes=5) expect = np.array([[1., 0., 0., 0., 0.], [0., 1., 0., 0., 0.], [0., 0., 1., 0., 0.]]) self.assertTrue(np.array_equal(oh_ary, expect))