Exemplo n.º 1
0
    def test_invalid_axes(self):
        """Check that malformed ``axes`` arguments are rejected.

        Statically known bad axes must raise ``ValueError``. An
        out-of-range axis index must raise ``IndexError`` while the graph is
        built. Bad axes fed through a placeholder must raise
        ``tf.errors.InvalidArgumentError`` when the graph is run.
        """
        # pylint: disable=not-context-manager
        with tf.compat.v1.Graph().as_default():
            lhs = [[1, 2], [3, 4]]
            rhs = [[1, 2], [3, 4]]

            # Invalid static axes: rejected at graph-construction time.
            bad_static_axes = (-1, 3, [1], [[1]], [[1], [0, 1]])
            for bad_axes in bad_static_axes:
                with self.assertRaises(ValueError):
                    tensordot2.tensordot(tf, lhs, rhs, bad_axes)

            # Axis index 7 is beyond the rank of the 2-D operands.
            with self.assertRaises(IndexError):
                tensordot2.tensordot(tf, lhs, rhs, [[0], [7]])

            # Invalid dynamic axes: only detected when the graph is run.
            lhs_ph = tf.compat.v1.placeholder(tf.float32)
            rhs_ph = tf.compat.v1.placeholder(tf.float32)
            axes_ph = tf.compat.v1.placeholder(tf.int32)
            contracted = tensordot2.tensordot(tf, lhs_ph, rhs_ph, axes_ph)
            # Note: scalar Tensor values for axes are not supported, so the
            # scalar ``1`` is expected to fail as well.
            bad_dynamic_axes = (1, [1], [0, 1], [[1]], [[0, 1]], [[0], [7]])
            for bad_axes in bad_dynamic_axes:
                feed = {lhs_ph: lhs, rhs_ph: rhs, axes_ph: bad_axes}
                with self.cached_session() as sess:
                    with self.assertRaises(tf.errors.InvalidArgumentError):
                        _ = sess.run([contracted], feed_dict=feed)
Exemplo n.º 2
0
 def test_partial_shape_inference(self):
     """Check static output-shape inference for partially known inputs.

     The test runs once for each ``axes`` form: explicit per-operand lists
     and the scalar ``1``. Inputs of unknown rank must yield an output of
     unknown rank. Inputs with partially known shapes must carry their
     known dimensions, and ``None`` for the unknown ones, to the output.
     """

     def assert_static_shape(tensor, expected):
         # Verify the rank before indexing into the dimension list.
         static_shape = tensor.get_shape()
         self.assertEqual(static_shape.ndims, 2)
         dims = static_shape.as_list()
         self.assertEqual(dims[0], expected[0])
         self.assertEqual(dims[1], expected[1])

     # pylint: disable=not-context-manager
     with tf.compat.v1.Graph().as_default():
         for axes in ([1], [0]), 1:
             # Unknown-rank inputs -> unknown-rank output.
             lhs = tf.compat.v1.placeholder(tf.float32)
             rhs = tf.compat.v1.placeholder(tf.float32)
             result = tensordot2.tensordot(tf, lhs, rhs, axes)
             self.assertEqual(result.get_shape().ndims, None)

             # Leading dimension of the first operand is unknown.
             lhs.set_shape([None, 2])
             rhs.set_shape([2, 3])
             assert_static_shape(tensordot2.tensordot(tf, lhs, rhs, axes),
                                 (None, 3))

             # Trailing dimension of the second operand is unknown.
             lhs = tf.compat.v1.placeholder(tf.float32)
             rhs = tf.compat.v1.placeholder(tf.float32)
             lhs.set_shape([2, 2])
             rhs.set_shape([2, None])
             assert_static_shape(tensordot2.tensordot(tf, lhs, rhs, axes),
                                 (2, None))
Exemplo n.º 3
0
    def test_valid_axis(self):
        """Compare ``tensordot2.tensordot`` against ``np.tensordot``.

        Covers several ``axes`` forms: a flat list, a nested list, an empty
        nested list and the integer ``0`` (no contracted axes). Both the
        result shape and its values are checked.
        """
        # The NumPy operands do not depend on ``axes``; build them once.
        np_lhs = np.ones((3, 3))
        np_rhs = np.array([2, 3, 1])[None, None]
        axes_cases = ([1, 2], [[1], [2]], [[], []], 0)
        for axes_value in axes_cases:
            expected = np.tensordot(np_lhs, np_rhs, axes_value)
            with self.cached_session():
                tf_lhs = tf.ones((3, 3), dtype=tf.float32)
                tf_rhs = tf.constant([2, 3, 1], dtype=tf.float32)[None, None]
                actual = tensordot2.tensordot(tf, tf_lhs, tf_rhs, axes_value)
                self.assertAllEqual(actual.shape, expected.shape)
                self.assertAllEqual(actual, expected)
Exemplo n.º 4
0
 def test_invalid_shape(self):
     """Check that contracting dimensions of different sizes is rejected.

     ``a`` is 2x2 and ``b`` is 3x2. Contracting axis 1 of ``a`` (size 2)
     with axis 0 of ``b`` (size 3) is therefore invalid. With constant
     inputs the error is raised by the call itself. With placeholder inputs
     it is raised when the graph is run.
     """
     a = [[1, 2], [3, 4]]
     b = [[1, 2], [3, 4], [5, 6]]
     a_axes = [1]
     b_axes = [0]
     # Invalid static shapes.
     with self.assertRaises(tf.errors.InvalidArgumentError):
         tensordot2.tensordot(tf, a, b, (a_axes, b_axes))
     # Invalid dynamic shapes.
     # pylint: disable=not-context-manager
     with tf.compat.v1.Graph().as_default():
         with self.cached_session() as sess:
             # Build the graph outside the assertion so that only the run
             # step, where the mismatch is detected, is expected to fail.
             a_ph = tf.compat.v1.placeholder(tf.float32)
             b_ph = tf.compat.v1.placeholder(tf.float32)
             axes_ph = tf.compat.v1.placeholder(tf.int32)
             output = tensordot2.tensordot(tf, a_ph, b_ph, axes_ph)
             # assertRaisesRegex replaces the deprecated assertRaisesRegexp
             # alias, which was removed from unittest in Python 3.12.
             with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                         "Matrix size-incompatible"):
                 _ = sess.run([output],
                              feed_dict={
                                  a_ph: a,
                                  b_ph: b,
                                  axes_ph: (a_axes, b_axes)
                              })
Exemplo n.º 5
0
def test_tensordot(dtype_, rank_a_, rank_b_, num_dims_):
    """Compare tensordot2.tensordot with np.tensordot on random operands.

    Operands and contraction axes come from
    ``_generate_random_tensors_and_dims``. Each trial checks the values,
    within a dtype-dependent tolerance, and the result shape. Parameter
    combinations that contract more axes than the smaller rank are skipped.

    NOTE(review): the trial count is ``num_dims_ ** 2`` capped at 30, so no
    trials run when ``num_dims_`` is 0.
    """
    if num_dims_ > min(rank_a_, rank_b_):
        pytest.skip("Not a test")

    # Lower-precision dtypes get a looser tolerance.
    if dtype_ == np.float16:
        tol = 0.05
    elif dtype_ in (np.float32, np.complex64):
        tol = 1e-5
    else:
        tol = 1e-12

    trials = min(30, num_dims_ * num_dims_)
    for _ in range(trials):
        a_np, b_np, a_dims_np, b_dims_np = _generate_random_tensors_and_dims(
            dtype_, rank_a_, rank_b_, num_dims_)
        contraction = (a_dims_np, b_dims_np)
        expected = np.tensordot(a_np, b_np, axes=contraction)
        actual = tensordot2.tensordot(tf, a_np, b_np, contraction)
        np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol)
        assert actual.shape == expected.shape
Exemplo n.º 6
0
def test_tensordot_scalar_axes(dtype_, rank_a_, rank_b_, num_dims_):
    """Compare tensordot2.tensordot with np.tensordot for integer ``axes``.

    Both operands are random tensors of shape ``[5] * num_dims_``. The test
    tries the integer axes 0 and 1, plus ``ndim - 1`` when the rank exceeds
    2. ``rank_a_`` and ``rank_b_`` are used only to decide whether to skip.
    """
    if num_dims_ > min(rank_a_, rank_b_):
        pytest.skip("Not a test")

    # Lower-precision dtypes get a looser tolerance.
    if dtype_ == np.float16:
        tol = 0.05
    elif dtype_ in (np.float32, np.complex64):
        tol = 1e-5
    else:
        tol = 1e-12

    shape = [5] * num_dims_

    def random_operand():
        flat = np.random.uniform(low=-1.0, high=1.0, size=np.prod(shape))
        return flat.reshape(shape).astype(dtype_)

    # Draw ``a`` before ``b`` so the random stream is consumed in order.
    a_np = random_operand()
    b_np = random_operand()

    candidate_axes = [0, 1]
    if a_np.ndim > 2:
        candidate_axes.append(a_np.ndim - 1)
    for axes in candidate_axes:
        expected = np.tensordot(a_np, b_np, axes=axes)
        actual = tensordot2.tensordot(tf, a_np, b_np, axes=axes)
        np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol)
        assert actual.shape == expected.shape
 def tensordot(self, a: Tensor, b: Tensor, axes: Sequence[Sequence[int]]):
     """Contract ``a`` and ``b`` over ``axes`` using ``tensordot2``.

     Args:
       a: First operand.
       b: Second operand.
       axes: Axes to contract, forwarded unchanged to
         ``tensordot2.tensordot``.

     Returns:
       The tensor returned by ``tensordot2.tensordot``.
     """
     contracted = tensordot2.tensordot(tf, a, b, axes=axes)
     return contracted
 def outer_product(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
     """Return the outer product of ``tensor1`` and ``tensor2``.

     Implemented as a ``tensordot2.tensordot`` call with ``axes=0``, so no
     axes are contracted.
     """
     no_contraction = 0
     return tensordot2.tensordot(tf, tensor1, tensor2, axes=no_contraction)
Exemplo n.º 9
0
 def tensordot(self, a: Tensor, b: Tensor,
               axes: Union[int, Sequence[Sequence[int]]]) -> Tensor:
     """Contract ``a`` and ``b`` over ``axes`` using ``tensordot2``.

     Args:
       a: First operand.
       b: Second operand.
       axes: Either an int or a sequence of per-operand axis sequences,
         forwarded unchanged to ``tensordot2.tensordot``.

     Returns:
       The tensor returned by ``tensordot2.tensordot``.
     """
     result = tensordot2.tensordot(tf, a, b, axes=axes)
     return result