def concatenate_context_input(context_input, sequence_input):
  """Appends `context_input` to every timestep of `sequence_input`.

  A length-1 axis is inserted at position 1 of `context_input`. The result is
  tiled along that axis as many times as dimension 1 (`padded_length`) of
  `sequence_input`. The tiled tensor is then concatenated onto
  `sequence_input` along dimension 2 and returned.

  Args:
    context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.
    sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
      padded_length, d0]`.

  Returns:
    A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
    d0 + d1]`.

  Raises:
    ValueError: If `sequence_input` does not have rank 3 or `context_input` does
      not have rank 2.
    TypeError: If either input does not have dtype `float32`.
  """
  # Assertions on sequence_input come first, then those on context_input.
  checks = [
      check_ops.assert_rank(
          sequence_input,
          3,
          message='sequence_input must have rank 3',
          data=[array_ops.shape(sequence_input)]),
      check_ops.assert_type(
          sequence_input,
          dtypes.float32,
          message='sequence_input must have dtype float32; got {}.'.format(
              sequence_input.dtype)),
      check_ops.assert_rank(
          context_input,
          2,
          message='context_input must have rank 2',
          data=[array_ops.shape(context_input)]),
      check_ops.assert_type(
          context_input,
          dtypes.float32,
          message='context_input must have dtype float32; got {}.'.format(
              context_input.dtype)),
  ]
  with ops.control_dependencies(checks):
    # Read the length from the dynamic shape, which also covers a padded
    # length that is not statically known.
    num_steps = array_ops.shape(sequence_input)[1]
    # [batch_size, d1] -> [batch_size, 1, d1] -> [batch_size, num_steps, d1]
    expanded = array_ops.expand_dims(context_input, 1)
    multiples = array_ops.concat([[1], [num_steps], [1]], 0)
    tiled = array_ops.tile(expanded, multiples)
  return array_ops.concat([sequence_input, tiled], 2)
Example #2
0
File: rnn.py Project: yupbank/estimator
def _concatenate_context_input(sequence_input, context_input):
    """Appends `context_input` to every timestep of `sequence_input`.

    A length-1 axis is inserted at position 1 of `context_input`. The result
    is tiled along that axis as many times as dimension 1 (`padded_length`) of
    `sequence_input`. The tiled tensor is then concatenated onto
    `sequence_input` along dimension 2 and returned.

    Args:
      sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
        padded_length, d0]`.
      context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.

    Returns:
      A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
      d0 + d1]`.

    Raises:
      ValueError: If `sequence_input` does not have rank 3 or `context_input`
        does not have rank 2.
      TypeError: If either input does not have dtype `float32`.
    """
    # Assertions on sequence_input are built first, then those on
    # context_input.
    sequence_checks = [
        check_ops.assert_rank(
            sequence_input,
            3,
            message='sequence_input must have rank 3',
            data=[array_ops.shape(sequence_input)]),
        check_ops.assert_type(
            sequence_input,
            dtypes.float32,
            message='sequence_input must have dtype float32; got {}.'.format(
                sequence_input.dtype)),
    ]
    context_checks = [
        check_ops.assert_rank(
            context_input,
            2,
            message='context_input must have rank 2',
            data=[array_ops.shape(context_input)]),
        check_ops.assert_type(
            context_input,
            dtypes.float32,
            message='context_input must have dtype float32; got {}.'.format(
                context_input.dtype)),
    ]
    with ops.control_dependencies(sequence_checks + context_checks):
        # Read the length from the dynamic shape, which also covers a padded
        # length that is not statically known.
        timesteps = array_ops.shape(sequence_input)[1]
        # [batch_size, d1] -> [batch_size, 1, d1] -> [batch_size, timesteps, d1]
        context_3d = array_ops.expand_dims(context_input, 1)
        tile_multiples = array_ops.concat([[1], [timesteps], [1]], 0)
        repeated_context = array_ops.tile(context_3d, tile_multiples)
    return array_ops.concat([sequence_input, repeated_context], 2)
Example #3
0
def assert_type(tensor: ragged_tensor.Ragged,
                tf_type,
                message=None,
                name=None):
    """Asserts the dtype of a ragged tensor via its flat values.

    Delegates to `check_ops.assert_type` on `tensor.flat_values`. No other
    component of `tensor` is examined.

    Args:
      tensor: The ragged tensor whose values' dtype is checked.
      tf_type: The expected dtype.
      message: Optional message, forwarded unchanged to
        `check_ops.assert_type`.
      name: Optional op name, forwarded unchanged to `check_ops.assert_type`.

    Returns:
      Whatever `check_ops.assert_type` returns for `tensor.flat_values`.
    """
    flat = tensor.flat_values
    return check_ops.assert_type(flat, tf_type, message=message, name=name)
Example #4
0
 def test_raises_when_wrong_type(self):
   """Checking a float16 tensor against float32 raises TypeError."""
   floats = constant_op.constant([1.0, 2.0], dtype=dtypes.float16)
   # assertRaisesRegex is the canonical name; the assertRaisesRegexp alias was
   # deprecated in Python 3.2 and removed in Python 3.12.
   with self.assertRaisesRegex(TypeError, "must be of type.*float32"):
     check_ops.assert_type(floats, dtypes.float32)
Example #5
0
 def test_doesnt_raise_when_correct_type(self):
   """Checking an int64 tensor against int64 evaluates without error."""
   int64_values = constant_op.constant([1, 2], dtype=dtypes.int64)
   type_check = check_ops.assert_type(int64_values, dtypes.int64)
   # Evaluating the identity forces the type check to run first.
   with ops.control_dependencies([type_check]):
     result = array_ops.identity(int64_values)
   self.evaluate(result)
Example #6
0
 def testAssertType(self):
   """assert_type on a ragged float tensor passes and leaves values intact.

   Builds a ragged tensor from Python floats and asserts it is float32. It
   then checks that an identity computed under that dependency equals the
   input.
   """
   ragged = ragged_factory_ops.constant([[1., 2.], [3.]])
   float_check = check_ops.assert_type(ragged, dtypes.float32)
   with ops.control_dependencies([float_check]):
     result = array_ops.identity(ragged)
   self.assertAllEqual(ragged, result)