def test_invalid_axes(self):
    """Rejects malformed `axes`, whether given statically or fed at run time.

    Statically known bad axes must raise ``ValueError`` when the op is
    built. An out-of-range axis index (7 for a 2-D operand) must raise
    ``IndexError``. Axes fed through an int32 placeholder are only checked
    when the graph runs, and must surface as
    ``tf.errors.InvalidArgumentError``.
    """
    lhs = [[1, 2], [3, 4]]
    rhs = [[1, 2], [3, 4]]
    bad_static_axes = (-1, 3, [1], [[1]], [[1], [0, 1]])
    # Note: We don't support scalar Tensor values for axes, hence the
    # leading plain `1` in the dynamic cases.
    bad_dynamic_axes = (1, [1], [0, 1], [[1]], [[0, 1]], [[0], [7]])

    # pylint: disable=not-context-manager
    with tf.compat.v1.Graph().as_default():
        # Invalid static axes.
        for bad_axes in bad_static_axes:
            with self.assertRaises(ValueError):
                tensordot2.tensordot(tf, lhs, rhs, bad_axes)
        with self.assertRaises(IndexError):
            tensordot2.tensordot(tf, lhs, rhs, [[0], [7]])

        # Invalid dynamic axes: build the op once, then feed each bad value.
        lhs_ph = tf.compat.v1.placeholder(tf.float32)
        rhs_ph = tf.compat.v1.placeholder(tf.float32)
        axes_ph = tf.compat.v1.placeholder(tf.int32)
        contracted = tensordot2.tensordot(tf, lhs_ph, rhs_ph, axes_ph)
        for bad_axes in bad_dynamic_axes:
            feeds = {lhs_ph: lhs, rhs_ph: rhs, axes_ph: bad_axes}
            with self.cached_session() as sess, \
                    self.assertRaises(tf.errors.InvalidArgumentError):
                sess.run([contracted], feed_dict=feeds)
def test_partial_shape_inference(self):
    """Infers the static output shape from partially known input shapes.

    Runs once with an explicit pair of axis lists and once with an integer
    `axes`. For each, it checks that:
      * fully unknown inputs give an output of unknown rank;
      * an unknown leading dimension of the first operand is carried into
        the result;
      * an unknown trailing dimension of the second operand is carried
        into the result.
    """
    # pylint: disable=not-context-manager
    with tf.compat.v1.Graph().as_default():
        for axes in (([1], [0]), 1):
            x = tf.compat.v1.placeholder(tf.float32)
            y = tf.compat.v1.placeholder(tf.float32)
            # No shape information at all: the output rank is unknown.
            self.assertIsNone(
                tensordot2.tensordot(tf, x, y, axes).get_shape().ndims)

            # Refine the same placeholders: [None, 2] . [2, 3].
            x.set_shape([None, 2])
            y.set_shape([2, 3])
            inferred = tensordot2.tensordot(tf, x, y, axes).get_shape()
            self.assertEqual(inferred.ndims, 2)
            self.assertEqual(inferred.as_list(), [None, 3])

            # Fresh placeholders: [2, 2] . [2, None].
            x = tf.compat.v1.placeholder(tf.float32)
            y = tf.compat.v1.placeholder(tf.float32)
            x.set_shape([2, 2])
            y.set_shape([2, None])
            inferred = tensordot2.tensordot(tf, x, y, axes).get_shape()
            self.assertEqual(inferred.ndims, 2)
            self.assertEqual(inferred.as_list(), [2, None])
def test_valid_axis(self):
    """Matches ``np.tensordot`` in shape and value for several valid `axes`.

    Cases covered:
      * a flat axis pair;
      * a nested axis pair;
      * empty axis lists (no contraction);
      * the integer 0.
    The operands are a 3x3 matrix of ones and a (1, 1, 3) tensor.
    """
    # The numpy reference operands do not depend on the case, so build them once.
    np_a = np.ones((3, 3))
    np_b = np.array([2, 3, 1])[None, None]
    axes_cases = ([1, 2], [[1], [2]], [[], []], 0)
    for axes_value in axes_cases:
        with self.cached_session():
            expected = np.tensordot(np_a, np_b, axes_value)
            # TF operands are created inside the session context, as before.
            tf_a = tf.ones((3, 3), dtype=tf.float32)
            tf_b = tf.constant([2, 3, 1], dtype=tf.float32)[None, None]
            actual = tensordot2.tensordot(tf, tf_a, tf_b, axes_value)
            self.assertAllEqual(actual.shape, expected.shape)
            self.assertAllEqual(actual, expected)
def test_invalid_shape(self):
    """Raises InvalidArgumentError for incompatible contraction dimensions.

    `a` is 2x2 and `b` is 3x2. Contracting axis 1 of `a` (size 2) with
    axis 0 of `b` (size 3) is therefore invalid. This must fail both when
    the operands are passed directly and when they are fed through
    placeholders and the graph is run.
    """
    a = [[1, 2], [3, 4]]
    b = [[1, 2], [3, 4], [5, 6]]
    a_axes = [1]
    b_axes = [0]
    # Invalid static shapes.
    with self.assertRaises(tf.errors.InvalidArgumentError):
        tensordot2.tensordot(tf, a, b, (a_axes, b_axes))
    # Invalid dynamic shapes.
    # pylint: disable=not-context-manager
    with tf.compat.v1.Graph().as_default():
        with self.cached_session() as sess:
            # Build the graph outside the assertion, so that only the run
            # step is expected to raise.
            a_ph = tf.compat.v1.placeholder(tf.float32)
            b_ph = tf.compat.v1.placeholder(tf.float32)
            axes_ph = tf.compat.v1.placeholder(tf.int32)
            output = tensordot2.tensordot(tf, a_ph, b_ph, axes_ph)
            # Use assertRaisesRegex: the deprecated assertRaisesRegexp alias
            # was removed from unittest in Python 3.12.
            with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                        "Matrix size-incompatible"):
                sess.run([output], feed_dict={
                    a_ph: a,
                    b_ph: b,
                    axes_ph: (a_axes, b_axes)
                })
def test_tensordot(dtype_, rank_a_, rank_b_, num_dims_):
    """Checks that explicit-axes tensordot on random tensors matches ``np.tensordot``.

    Combinations where more axes are contracted than either operand has
    are skipped. Up to 30 random trials are run (``num_dims_ ** 2`` when
    that is smaller). The tolerance depends on the dtype's precision.
    """
    if num_dims_ > min(rank_a_, rank_b_):
        pytest.skip("Not a test")

    tol = (0.05 if dtype_ == np.float16
           else 1e-5 if dtype_ in (np.float32, np.complex64)
           else 1e-12)

    for _ in range(min(30, num_dims_ * num_dims_)):
        a_np, b_np, a_dims_np, b_dims_np = _generate_random_tensors_and_dims(
            dtype_, rank_a_, rank_b_, num_dims_)
        axes_pair = (a_dims_np, b_dims_np)
        expected = np.tensordot(a_np, b_np, axes=axes_pair)
        actual = tensordot2.tensordot(tf, a_np, b_np, axes_pair)
        np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol)
        assert actual.shape == expected.shape
def test_tensordot_scalar_axes(dtype_, rank_a_, rank_b_, num_dims_):
    """Checks that integer `axes` values give the same result as ``np.tensordot``.

    Both operands are random tensors of shape ``[5] * num_dims_``, drawn
    uniformly from [-1, 1). Integer axes 0 and 1 are always checked.
    ``ndim - 1`` is also checked when the operands have more than two
    dimensions.
    """
    if num_dims_ > min(rank_a_, rank_b_):
        pytest.skip("Not a test")

    tol = (0.05 if dtype_ == np.float16
           else 1e-5 if dtype_ in (np.float32, np.complex64)
           else 1e-12)

    shape = [5] * num_dims_

    def _random_operand():
        # Draw a flat sample, then reshape and cast. The first call makes a, the second makes b.
        flat = np.random.uniform(low=-1.0, high=1.0, size=np.prod(shape))
        return flat.reshape(shape).astype(dtype_)

    a_np = _random_operand()
    b_np = _random_operand()

    extra_axes = [a_np.ndim - 1] if a_np.ndim > 2 else []
    for axes in [0, 1] + extra_axes:
        expected = np.tensordot(a_np, b_np, axes=axes)
        actual = tensordot2.tensordot(tf, a_np, b_np, axes=axes)
        np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol)
        assert actual.shape == expected.shape
def tensordot(self, a: Tensor, b: Tensor,
              axes: Sequence[Sequence[int]]) -> Tensor:
    """Contracts `a` with `b` along `axes`, using the `tensordot2` implementation.

    Args:
      a: First operand.
      b: Second operand.
      axes: Pair of axis sequences, ``(axes_of_a, axes_of_b)``, naming the
        dimensions to contract. The argument is passed through unchanged
        to ``tensordot2.tensordot``. That function also accepts a plain
        integer N here (contract the last N axes of `a` with the first N
        of `b`), following the ``numpy.tensordot`` convention.

    Returns:
      The contracted tensor.
    """
    return tensordot2.tensordot(tf, a, b, axes)
def outer_product(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
    """Returns the outer (tensor) product of `tensor1` and `tensor2`.

    This is a tensordot that contracts zero axes, so every index of both
    operands is kept in the result.
    """
    return tensordot2.tensordot(tf, tensor1, tensor2, axes=0)
def tensordot(self, a: Tensor, b: Tensor,
              axes: Union[int, Sequence[Sequence[int]]]) -> Tensor:
    """Contracts `a` with `b` along `axes`.

    `axes` is either an integer N, meaning the last N axes of `a` against
    the first N of `b`, or a pair of axis sequences
    ``(axes_of_a, axes_of_b)``, as in ``numpy.tensordot``. The work is
    delegated to ``tensordot2.tensordot``.
    """
    return tensordot2.tensordot(tf, a, b, axes=axes)