Example #1
0
 def test_tensordot_recurring_dim_runtime_error(test_case):
     """tensordot must reject a dims list that names the same dim twice."""
     a = flow.randn(1, 2, 3)
     b = flow.randn(1, 2, 3)
     # Only the tensordot call is expected to raise; building the inputs
     # outside the guard keeps an unrelated failure from being masked.
     with test_case.assertRaises(Exception) as context:
         flow.tensordot(a, b, dims=[[1, 1], [1, 1]])
     test_case.assertIn("dim 1 appears multiple times in the list of dims",
                        str(context.exception))
Example #2
0
 def test_tensordot_dims_different_length_runtime_error(test_case):
     """tensordot must reject dims lists of unequal length."""
     a = flow.randn(1, 2, 3)
     b = flow.randn(1, 2, 3)
     # Only the tensordot call is expected to raise; building the inputs
     # outside the guard keeps an unrelated failure from being masked.
     with test_case.assertRaises(Exception) as context:
         flow.tensordot(a, b, dims=[[1], [1, 2]])
     test_case.assertIn("both dimension lists should have same length",
                        str(context.exception))
Example #3
0
 def test_tensordot_neg_dims_runtime_error(test_case):
     """tensordot must reject a negative integer dims argument."""
     a = flow.randn(1, 2, 3)
     b = flow.randn(1, 2, 3)
     # Only the tensordot call is expected to raise; building the inputs
     # outside the guard keeps an unrelated failure from being masked.
     with test_case.assertRaises(Exception) as context:
         flow.tensordot(a, b, dims=-1)
     test_case.assertIn("tensordot expects dims >= 0, but got dims=-1",
                        str(context.exception))
Example #4
0
 def test_tensordot_unmatch_dims_runtime_error(test_case):
     """tensordot must reject contracting dims whose sizes differ."""
     a = flow.randn(1, 2, 3)
     b = flow.randn(1, 2, 3)
     # Only the tensordot call is expected to raise; building the inputs
     # outside the guard keeps an unrelated failure from being masked.
     # Dim 1 of `a` has size 2 while dim 2 of `b` has size 3.
     with test_case.assertRaises(Exception) as context:
         flow.tensordot(a, b, dims=[[1], [2]])
     test_case.assertIn(
         "contracted dimensions need to match, but first has size 2 in dim 1 and second has size 3 in dim 2",
         str(context.exception))
Example #5
0
 def test_tensordot_out_of_range_dims_runtime_error(test_case):
     """tensordot must reject a dim index outside the operand's rank."""
     a = flow.randn(1, 2, 3)
     b = flow.randn(1, 2, 3)
     # Only the tensordot call is expected to raise; building the inputs
     # outside the guard keeps an unrelated failure from being masked.
     # `a` is 3-D, so dim 3 is out of range.
     with test_case.assertRaises(Exception) as context:
         flow.tensordot(a, b, dims=[[3], [2]])
     test_case.assertIn(
         "Dimension out of range (expected to be in range of [-3, 2], but got 3)",
         str(context.exception))
Example #6
0
 def test_tensordot_too_large_int_dims_runtime_error(test_case):
     """tensordot must reject an integer dims larger than the first operand's ndim."""
     a = flow.randn(1, 2, 3)
     b = flow.randn(1, 2, 3)
     # Only the tensordot call is expected to raise; building the inputs
     # outside the guard keeps an unrelated failure from being masked.
     with test_case.assertRaises(Exception) as context:
         flow.tensordot(a, b, dims=100)
     test_case.assertIn(
         "tensordot expects dims <= a.ndim which is 3, but got 100",
         str(context.exception))
Example #7
0
        def _test_tensor_dim(test_case, device):
            """Check flow.tensordot against torch.tensordot when ``dims`` is a tensor.

            Both frameworks receive the same random float input and the same
            dims tensor, placed on ``device``; their results must agree to
            within rtol/atol of 1e-4.
            """
            # Contract dims 1, 2, 3 of the first operand with dims 1, 2, 3 of
            # the second. `np.int` was deprecated in NumPy 1.20 and removed in
            # 1.24, so spell the integer dtype explicitly.
            np_dim = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.int64)
            flow_dim = flow.tensor(np_dim).to(device)
            torch_dim = torch.tensor(np_dim).to(device)

            np_random_array = np.random.randn(2, 3, 4, 5)
            flow_tensor = flow.tensor(np_random_array).to(device)
            torch_tensor = torch.tensor(np_random_array).to(device)

            flow_result = flow.tensordot(flow_tensor,
                                         flow_tensor,
                                         dims=flow_dim)
            torch_result = torch.tensordot(torch_tensor,
                                           torch_tensor,
                                           dims=torch_dim)
            # torch results must be moved to host before .numpy().
            test_case.assertTrue(
                np.allclose(
                    flow_result.numpy(),
                    torch_result.cpu().numpy(),
                    rtol=0.0001,
                    atol=0.0001,
                ))