def test_expand_l_shape_runtime_error(test_case):
    """Expanding a (2, 2) tensor to the shape (2, 0) must raise.

    The raised exception's message is checked for the substring
    "The expanded size of the tensor".
    """
    with test_case.assertRaises(Exception) as context:
        x1 = flow.ones((2, 2), dtype=flow.float32, requires_grad=True)
        x2 = flow.ones((2, 0), dtype=flow.float32, requires_grad=True)
        # Only the raise matters; the result is never used.
        flow.expand(x1, x2.shape)
    # assertIn (rather than assertTrue(... in ...)) prints the actual
    # exception text when the substring is missing.
    test_case.assertIn(
        "The expanded size of the tensor", str(context.exception)
    )
def test_expand_dim_runtime_error(test_case):
    """Expanding a 2-D (2, 1) tensor to a 1-D shape (2,) must raise.

    The raised exception's message is checked for the substring
    "be greater or equal to the number of dimensions in the tensor".
    """
    with test_case.assertRaises(Exception) as context:
        x1 = flow.ones((2, 1), dtype=flow.float32, requires_grad=True)
        # The original `(2)` was the int 2, not a tuple. `(2,)` states the
        # intended 1-D shape explicitly.
        x2 = flow.ones((2,), dtype=flow.float32, requires_grad=True)
        # Only the raise matters; the result is never used.
        flow.expand(x1, x2.shape)
    # assertIn prints the actual exception text when the substring is missing.
    test_case.assertIn(
        "be greater or equal to the number of dimensions in the tensor",
        str(context.exception),
    )
def _expand_as(input, other):
    """Call ``flow.expand`` on ``input`` with the sizes from ``other.size()``.

    Each element of ``other.size()`` is passed as a separate positional
    size argument to ``flow.expand``.
    """
    target_sizes = other.size()
    return flow.expand(input, *target_sizes)
def _expand(self, *size):
    """Forward ``self`` and every positional size argument to ``flow.expand``."""
    requested = size
    return flow.expand(self, *requested)