コード例 #1
0
ファイル: tensor_util_test.py プロジェクト: imdone/tensorflow
    def _assert_with_shape(self, tensor, expected_value, expected_shape,
                           unexpected_shapes):
        """Checks `tensor_util.with_shape` / `with_same_shape` on `tensor`.

        For every shape in `unexpected_shapes`, asserts that:
          * a static shape mismatch makes `tensor_util.with_shape` raise
            `ValueError`, and
          * run-time mismatches (via a constant shape tensor or a fed
            placeholder) raise `errors_impl.OpError` whose message matches the
            "Wrong shape for ..." pattern.
        It then asserts that `expected_shape` is accepted both statically
        (the very same tensor object is returned) and at run time (evaluation
        yields `expected_value`).

        Args:
          tensor: The tensor under test.
          expected_value: Value `tensor` is expected to evaluate to.
          expected_shape: Sequence of dims `tensor` is expected to have.
          unexpected_shapes: Iterable of shapes `tensor` must be rejected for;
            may be empty.
        """
        # Created once before the loop because the success checks at the end
        # also need it; defining it inside the loop raised NameError whenever
        # `unexpected_shapes` was empty.
        expected_placeholder = array_ops.placeholder(dtypes.float32)
        expected_shape_str = " ".join(str(dim) for dim in expected_shape)
        for unexpected_shape in unexpected_shapes:
            self.assertRaises(ValueError, tensor_util.with_shape,
                              unexpected_shape, tensor)
            pattern = re.compile(
                r"\[Wrong shape for %s \[expected\] \[actual\].\] \[%s\] \[%s\]"
                % (tensor.name,
                   " ".join(str(dim) for dim in unexpected_shape),
                   expected_shape_str))
            self.assertRaisesRegexp(
                errors_impl.OpError, pattern,
                tensor_util.with_shape(constant_op.constant(unexpected_shape),
                                       tensor).eval)
            self.assertRaisesRegexp(
                errors_impl.OpError, pattern,
                tensor_util.with_same_shape(expected_placeholder, tensor).eval,
                {expected_placeholder: np.ones(unexpected_shape)})

        self.assertIs(tensor, tensor_util.with_shape(expected_shape, tensor))
        self.assertIs(
            tensor,
            tensor_util.with_same_shape(
                constant_op.constant(1, shape=expected_shape), tensor))
        tensor_with_shape = tensor_util.with_shape(
            constant_op.constant(expected_shape), tensor)
        np.testing.assert_array_equal(expected_value, tensor_with_shape.eval())
        tensor_with_same_shape = tensor_util.with_same_shape(
            expected_placeholder, tensor)
        np.testing.assert_array_equal(
            expected_value,
            tensor_with_same_shape.eval(
                {expected_placeholder: np.ones(expected_shape)}))
コード例 #2
0
 def with_same_shape(old, new):
     """Check and set new tensor's shape.

     When the ``xla_compile`` environment variable equals ``"true"``, no check
     is performed and ``new`` is returned unchanged. Otherwise, if both
     ``old`` and ``new`` are ``tf.Tensor`` objects, the result of
     ``tensor_util.with_same_shape(old, new)`` is returned. In every other
     case, ``new`` is returned as-is.

     Args:
       old: Reference value whose shape ``new`` is checked against.
       new: Value to check.

     Returns:
       ``new``, or ``tensor_util.with_same_shape(old, new)`` when the check runs.
     """
     # An unset variable means "not compiling with XLA" rather than KeyError.
     xla_compile = os.environ.get("xla_compile") == "true"
     if (not xla_compile and isinstance(old, tf.Tensor)
             and isinstance(new, tf.Tensor)):
         return tensor_util.with_same_shape(old, new)
     return new
コード例 #3
0
        def with_same_shape(old, new):
            """Return `new`, shape-checked against `old` if both are tensors.

            When `old` and `new` are both `tf.Tensor` instances, the result of
            `tensor_util.with_same_shape(old, new)` is returned; any other
            combination passes `new` through untouched.
            """
            both_tensors = (isinstance(old, tf.Tensor)
                            and isinstance(new, tf.Tensor))
            if not both_tensors:
                return new
            return tensor_util.with_same_shape(old, new)
コード例 #4
0
def assert_state_is_compatible(expected_state, state):
    """Asserts that `state` is compatible with `expected_state`.

    First, the two (possibly nested) states must have the same structure
    according to `nest.assert_same_structure`. Then, every flattened leaf of
    `expected_state` that is a tensor is passed to `with_same_shape` together
    with the leaf of `state` at the same position.

    Args:
      expected_state: The reference state.
      state: The state that must be compatible with `expected_state`.

    Raises:
      ValueError: if the states are incompatible. NOTE(review):
        `nest.assert_same_structure` may also raise TypeError for mismatched
        sequence types — confirm.
    """
    nest.assert_same_structure(expected_state, state)

    reference_leaves = nest.flatten(expected_state)
    candidate_leaves = nest.flatten(state)
    # Lazily pair up leaves so each is_tensor test is immediately followed by
    # its shape check, exactly as in a plain loop.
    tensor_pairs = (
        (reference, candidate)
        for reference, candidate in zip(reference_leaves, candidate_leaves)
        if tensor_util.is_tensor(reference))
    for reference, candidate in tensor_pairs:
        with_same_shape(reference, candidate)
コード例 #5
0
 def with_same_shape(old, new):
     """Check and set new tensor's shape.

     Non-tensor inputs are returned unchanged. In graph mode, the check is
     delegated to `tensor_util.with_same_shape`. When executing eagerly, the
     two static shapes are compared as Python lists.

     Returns:
       `new`, or `tensor_util.with_same_shape(old, new)` in graph mode.

     Raises:
       ValueError: when executing eagerly and the shapes differ.
     """
     if not (isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor)):
         return new
     if not context.executing_eagerly():
         return tensor_util.with_same_shape(old, new)
     old_dims = old.shape.as_list()
     new_dims = new.shape.as_list()
     if old_dims != new_dims:
         raise ValueError(
             "The shape of the AttentionWrapperState is "
             "expected to be same as the one to clone. "
             "self.shape: %s, input.shape: %s" % (old.shape, new.shape))
     return new
コード例 #6
0
  def _assert_with_shape(self, tensor, expected_value, expected_shape,
                         unexpected_shapes):
    """Checks `tensor_util.with_shape` / `with_same_shape` on `tensor`.

    For every shape in `unexpected_shapes`, asserts that:
      * a static shape mismatch makes `tensor_util.with_shape` raise
        `ValueError`, and
      * run-time mismatches (via a constant shape tensor or a fed placeholder)
        raise `errors_impl.OpError` whose message matches the
        "Wrong shape for ..." pattern.
    It then asserts that `expected_shape` is accepted both statically (the
    very same tensor object is returned) and at run time (evaluation yields
    `expected_value`).

    Args:
      tensor: The tensor under test.
      expected_value: Value `tensor` is expected to evaluate to.
      expected_shape: Sequence of dims `tensor` is expected to have.
      unexpected_shapes: Iterable of shapes `tensor` must be rejected for;
        may be empty.
    """
    # Created once before the loop because the success checks at the end also
    # need it; defining it inside the loop raised NameError whenever
    # `unexpected_shapes` was empty.
    expected_placeholder = array_ops.placeholder(dtypes.float32)
    expected_shape_str = " ".join(str(dim) for dim in expected_shape)
    for unexpected_shape in unexpected_shapes:
      self.assertRaises(ValueError, tensor_util.with_shape, unexpected_shape,
                        tensor)
      pattern = re.compile(
          r"\[Wrong shape for %s \[expected\] \[actual\].\] \[%s\] \[%s\]" %
          (tensor.name, " ".join(str(dim) for dim in unexpected_shape),
           expected_shape_str))
      self.assertRaisesRegexp(
          errors_impl.OpError, pattern,
          tensor_util.with_shape(constant_op.constant(unexpected_shape),
                                 tensor).eval)
      self.assertRaisesRegexp(
          errors_impl.OpError, pattern,
          tensor_util.with_same_shape(expected_placeholder, tensor).eval,
          {expected_placeholder: np.ones(unexpected_shape)})

    self.assertIs(tensor, tensor_util.with_shape(expected_shape, tensor))
    self.assertIs(
        tensor,
        tensor_util.with_same_shape(
            constant_op.constant(1, shape=expected_shape), tensor))
    tensor_with_shape = tensor_util.with_shape(
        constant_op.constant(expected_shape), tensor)
    np.testing.assert_array_equal(expected_value, tensor_with_shape.eval())
    tensor_with_same_shape = tensor_util.with_same_shape(expected_placeholder,
                                                         tensor)
    np.testing.assert_array_equal(
        expected_value,
        tensor_with_same_shape.eval({
            expected_placeholder: np.ones(expected_shape)
        }))
コード例 #7
0
ファイル: copynet.py プロジェクト: pteixei/CopyNet
 def with_same_shape(old, new):
     """Shape-check `new` against `old`, but only when both are `tf.Tensor`s.

     Returns:
       `tensor_util.with_same_shape(old, new)` if both arguments are
       `tf.Tensor` instances, otherwise `new` unchanged.
     """
     if not isinstance(old, tf.Tensor) or not isinstance(new, tf.Tensor):
         return new
     return tensor_util.with_same_shape(old, new)