Esempio n. 1
0
    def test_unsqueeze(self):
        """Verify that syft.unsqueeze adds a size-one axis at each existing index.

        A 3x4x5 tensor is built. For every index in ``range(ndim)``, the shape
        of the unsqueezed result must equal the original shape with a ``1``
        inserted at that index.
        """
        tensor = TensorBase(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
        rank = len(tensor.data.shape)
        for axis in range(rank):
            result = syft.unsqueeze(tensor, axis)
            # Re-read the shape after the call so the expectation reflects
            # the tensor's current state.
            shape_now = list(tensor.data.shape)
            wanted = shape_now[:axis] + [1] + shape_now[axis:]

            self.assertTrue(np.array_equal(result.data.shape, wanted))
Esempio n. 2
0
 def unsqueeze(self, dim):
     """
     Returns expanded Tensor. An additional dimension of size one is added
     at index 'dim'.

     Parameters
     ----------
     dim : int
         Index at which the new size-one axis is inserted.

     Returns
     -------
     The object returned by ``syft.unsqueeze`` for this tensor.
     """
     # Hand syft.unsqueeze the tensor itself rather than ``self.data``.
     # syft.unsqueeze is used with a tensor argument, and its result exposes
     # ``.data``. Passing the raw underlying array bypasses that contract.
     # NOTE(review): confirm syft.unsqueeze does not also accept raw arrays.
     return syft.unsqueeze(self, dim)