def _assert_with_shape(self, tensor, expected_value, expected_shape,
                       unexpected_shapes):
    """Check `tensor_util.with_shape` / `with_same_shape` against `tensor`.

    For every shape in `unexpected_shapes` this asserts that:
      * `tensor_util.with_shape` with a static mismatching shape raises
        `ValueError` immediately;
      * a mismatch supplied dynamically (as a constant shape tensor, or by
        feeding a placeholder of the wrong shape) raises
        `errors_impl.OpError` matching the "Wrong shape" message on `.eval`.

    It then asserts that `expected_shape` is accepted: statically the very
    same `tensor` object is returned, and dynamically evaluation yields
    `expected_value`.

    Args:
      self: The test case (provides `assertRaises`, `assertRaisesRegexp`,
        `assertIs`).
      tensor: Tensor under test.
      expected_value: Value `tensor` is expected to evaluate to.
      expected_shape: Shape `tensor` is expected to have.
      unexpected_shapes: Iterable of shapes `tensor` must be rejected
        against. May be empty.
    """
    # Created once, up front. The positive checks after the loop feed this
    # placeholder too, so it must exist even when `unexpected_shapes` is
    # empty. Creating it inside the loop would leave it unbound in that case.
    expected_placeholder = array_ops.placeholder(dtypes.float32)

    for unexpected_shape in unexpected_shapes:
        # Static mismatch: rejected without evaluating anything.
        self.assertRaises(ValueError, tensor_util.with_shape, unexpected_shape,
                          tensor)
        pattern = (
            r"\[Wrong shape for %s \[expected\] \[actual\].\] \[%s\] \[%s\]" %
            (tensor.name, " ".join([str(dim) for dim in unexpected_shape]),
             " ".join([str(dim) for dim in expected_shape])))
        # Dynamic mismatch via a constant shape tensor: fails on evaluation.
        self.assertRaisesRegexp(
            errors_impl.OpError, re.compile(pattern),
            tensor_util.with_shape(
                constant_op.constant(unexpected_shape), tensor).eval)
        # Dynamic mismatch via a placeholder fed with the wrong shape.
        self.assertRaisesRegexp(
            errors_impl.OpError, re.compile(pattern),
            tensor_util.with_same_shape(expected_placeholder, tensor).eval,
            {expected_placeholder: np.ones(unexpected_shape)})

    # Matching static shapes: the input tensor itself is returned.
    self.assertIs(tensor, tensor_util.with_shape(expected_shape, tensor))
    self.assertIs(
        tensor,
        tensor_util.with_same_shape(
            constant_op.constant(1, shape=expected_shape), tensor))

    # Matching dynamic shapes: evaluation succeeds and yields the value.
    tensor_with_shape = tensor_util.with_shape(
        constant_op.constant(expected_shape), tensor)
    np.testing.assert_array_equal(expected_value, tensor_with_shape.eval())
    tensor_with_same_shape = tensor_util.with_same_shape(
        expected_placeholder, tensor)
    np.testing.assert_array_equal(
        expected_value,
        tensor_with_same_shape.eval(
            {expected_placeholder: np.ones(expected_shape)}))
def with_same_shape(old, new):
    """Check and set new tensor's shape.

    When the `xla_compile` environment variable equals `"true"`, no check
    is performed and `new` is returned unchanged. Otherwise, if both `old`
    and `new` are `tf.Tensor` instances, the result of
    `tensor_util.with_same_shape(old, new)` is returned. Any other input
    passes through as `new`.

    Args:
      old: Reference value whose shape `new` should match.
      new: Value to check and return.

    Returns:
      `new`, or the shape-checked tensor produced by
      `tensor_util.with_same_shape`.
    """
    # `.get` with a default treats an unset variable as "not compiling with
    # XLA" instead of raising KeyError.
    xla_compile = os.environ.get("xla_compile", "") == "true"
    if xla_compile:
        return new
    if isinstance(old, tf.Tensor) and isinstance(new, tf.Tensor):
        return tensor_util.with_same_shape(old, new)
    return new
def with_same_shape(old, new):
    """Return `new`, shape-checked against `old` when both are tensors.

    If `old` and `new` are both `tf.Tensor` objects, the result of
    `tensor_util.with_same_shape(old, new)` is returned. For any other
    combination of inputs, `new` is returned untouched.

    Args:
      old: Reference value whose shape `new` should match.
      new: Value to check and return.
    """
    both_tensors = isinstance(old, tf.Tensor) and isinstance(new, tf.Tensor)
    if not both_tensors:
        return new
    return tensor_util.with_same_shape(old, new)
def assert_state_is_compatible(expected_state, state):
    """Assert that `state` is compatible with `expected_state`.

    Two checks are made. First, the nested structures must match, as checked
    by `nest.assert_same_structure`. Second, for each leaf of `expected_state`
    that `tensor_util.is_tensor` accepts, the corresponding leaf of `state`
    is passed through `with_same_shape`.

    Args:
      expected_state: The reference state.
      state: The state that must be compatible with :obj:`expected_state`.

    Raises:
      ValueError: if the states are incompatible.
    """
    # Structural check first, so the flattened leaves below pair up.
    nest.assert_same_structure(expected_state, state)

    reference_leaves = nest.flatten(expected_state)
    candidate_leaves = nest.flatten(state)
    for reference, candidate in zip(reference_leaves, candidate_leaves):
        if not tensor_util.is_tensor(reference):
            continue
        # NOTE(review): the return value is discarded, so presumably only the
        # checks `with_same_shape` performs at call time matter here. Confirm.
        with_same_shape(reference, candidate)
def with_same_shape(old, new):
    """Check and set new tensor's shape.

    Only applies when both `old` and `new` are `ops.Tensor` objects:
      * outside eager execution, the result of
        `tensor_util.with_same_shape(old, new)` is returned;
      * under eager execution, the `shape.as_list()` of both tensors are
        compared, and `ValueError` is raised if they differ.
    In every other case `new` is returned unchanged.

    Args:
      old: Reference value (e.g. the state being cloned).
      new: Value to check and return.

    Raises:
      ValueError: in eager mode, if the two tensors' shapes differ.
    """
    if not (isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor)):
        return new
    if not context.executing_eagerly():
        return tensor_util.with_same_shape(old, new)

    # Eager mode: compare the dimension lists directly.
    old_dims = old.shape.as_list()
    new_dims = new.shape.as_list()
    if old_dims != new_dims:
        raise ValueError(
            "The shape of the AttentionWrapperState is "
            "expected to be same as the one to clone. "
            "self.shape: %s, input.shape: %s" % (old.shape, new.shape))
    return new
def _assert_with_shape(self, tensor, expected_value, expected_shape,
                       unexpected_shapes):
    """Exercise `tensor_util.with_shape` and `with_same_shape` on `tensor`.

    Each shape in `unexpected_shapes` must be rejected. Statically,
    `with_shape` raises `ValueError` right away. Dynamically (constant shape
    tensor, or a placeholder fed with that shape), evaluation raises
    `errors_impl.OpError` whose message matches the "Wrong shape" pattern.
    `expected_shape` must then be accepted both statically (the same
    `tensor` object comes back) and dynamically (evaluation equals
    `expected_value`).

    Args:
      self: The test case providing the assertion helpers.
      tensor: Tensor under test.
      expected_value: Value `tensor` is expected to evaluate to.
      expected_shape: Shape `tensor` is expected to have.
      unexpected_shapes: Iterable of shapes that must be rejected; may be
        empty.
    """
    # Built before the loop. The positive checks at the end also feed it,
    # so it has to be bound even when `unexpected_shapes` is empty.
    expected_placeholder = array_ops.placeholder(dtypes.float32)

    for unexpected_shape in unexpected_shapes:
        # Static mismatch: raises without any evaluation.
        self.assertRaises(ValueError, tensor_util.with_shape, unexpected_shape,
                          tensor)
        pattern = (
            r"\[Wrong shape for %s \[expected\] \[actual\].\] \[%s\] \[%s\]" %
            (tensor.name, " ".join([str(dim) for dim in unexpected_shape]),
             " ".join([str(dim) for dim in expected_shape])))
        # Dynamic mismatch through a constant shape tensor.
        self.assertRaisesRegexp(
            errors_impl.OpError, re.compile(pattern),
            tensor_util.with_shape(
                constant_op.constant(unexpected_shape), tensor).eval)
        # Dynamic mismatch through a placeholder fed with the wrong shape.
        self.assertRaisesRegexp(
            errors_impl.OpError, re.compile(pattern),
            tensor_util.with_same_shape(expected_placeholder, tensor).eval,
            {expected_placeholder: np.ones(unexpected_shape)})

    # Static acceptance: the input tensor is returned unchanged.
    self.assertIs(tensor, tensor_util.with_shape(expected_shape, tensor))
    self.assertIs(
        tensor,
        tensor_util.with_same_shape(
            constant_op.constant(1, shape=expected_shape), tensor))

    # Dynamic acceptance: evaluation succeeds with the expected value.
    tensor_with_shape = tensor_util.with_shape(
        constant_op.constant(expected_shape), tensor)
    np.testing.assert_array_equal(expected_value, tensor_with_shape.eval())
    tensor_with_same_shape = tensor_util.with_same_shape(
        expected_placeholder, tensor)
    np.testing.assert_array_equal(
        expected_value,
        tensor_with_same_shape.eval(
            {expected_placeholder: np.ones(expected_shape)}))
def with_same_shape(old, new):
    """Shape-check `new` against `old` if both are `tf.Tensor` objects.

    Args:
      old: Reference value.
      new: Value to check and return.

    Returns:
      `tensor_util.with_same_shape(old, new)` when both arguments are
      tensors; otherwise `new` as-is.
    """
    if not isinstance(old, tf.Tensor) or not isinstance(new, tf.Tensor):
        return new
    return tensor_util.with_same_shape(old, new)