def test_bmm(self):
    """Check that syft.bmm multiplies each batch of matrices independently.

    Batch 0: [3, 1] @ [[1], [3]] -> [[6]]
    Batch 1: [1, 2] @ [[4], [8]] -> [[20]]
    """
    lhs = TensorBase(np.array([[[3, 1]], [[1, 2]]]))
    rhs = TensorBase(np.array([[[1], [3]], [[4], [8]]]))
    expected = np.array([[[6]], [[20]]])

    product = syft.bmm(lhs, rhs)

    self.assertTrue(np.array_equal(product.data, expected))
def test_bmm_for_correct_size_output(self):
    """Check that syft.bmm of (b, n, m) and (b, m, p) batches has shape (b, n, p).

    The inputs are random because only the output shape is asserted.
    """
    batch, rows, inner, cols = 4, 3, 2, 1
    # Build lhs before rhs so the RNG is consumed in the same order as before.
    lhs = TensorBase(np.random.rand(batch, rows, inner))
    rhs = TensorBase(np.random.rand(batch, inner, cols))

    product = syft.bmm(lhs, rhs)

    self.assertTupleEqual(product.shape(), (batch, rows, cols))