예제 #1
0
def _ParseSingleSequenceExampleShape(op):  # pylint: disable=invalid-name
  """Shape function for the ParseSingleSequenceExample op.

  Args:
    op: The ParseSingleSequenceExample `Operation`. Input 0 must be a scalar
      and input 1 (feature_list_dense_missing_assumed_empty) must be a vector.

  Returns:
    A list of `TensorShape`s in op-output order:
      * context sparse indices: `[None, 1]`, one per context sparse key;
      * context sparse values: `[None]`, one per context sparse key;
      * context sparse shapes: `[1]`, one per context sparse key;
      * context dense values: the `context_dense_shapes` attr entries;
      * feature-list sparse indices: `[None, 2]`, one per sparse key;
      * feature-list sparse values: `[None]`, one per sparse key;
      * feature-list sparse shapes: `[2]`, one per sparse key;
      * feature-list dense values: `[None] + feature_list_dense_shapes[i]`
        (a leading dimension of unknown length is prepended).

  Raises:
    ValueError: If input 0 is not rank 0 or input 1 is not rank 1.
    AssertionError: If `Ncontext_dense` / `Nfeature_list_dense` disagree with
      the number of entries in the corresponding dense-shapes attr.
  """
  op.inputs[0].get_shape().with_rank(0)  # input
  # feature_list_dense_missing_assumed_empty
  op.inputs[1].get_shape().with_rank(1)
  num_context_sparse = op.get_attr("Ncontext_sparse")
  num_context_dense = op.get_attr("Ncontext_dense")
  num_feature_list_dense = op.get_attr("Nfeature_list_dense")
  context_dense_shapes = op.get_attr("context_dense_shapes")
  num_feature_list_sparse = op.get_attr("Nfeature_list_sparse")
  feature_list_dense_shapes = op.get_attr("feature_list_dense_shapes")
  # Context sparse features are 1-D, so their indices have a single column.
  context_sparse_index_shapes = [
      tensor_shape.matrix(None, 1) for _ in range(num_context_sparse)]
  context_sparse_value_shapes = [
      tensor_shape.vector(None) for _ in range(num_context_sparse)]
  context_sparse_shape_shapes = [
      tensor_shape.vector(1) for _ in range(num_context_sparse)]
  context_dense_shapes = [
      tensor_shape.TensorShape(dense_shape)
      for dense_shape in context_dense_shapes]
  # Feature-list sparse features carry an extra (unknown-length) leading
  # dimension, so their indices have two columns.
  feature_list_sparse_index_shapes = [
      tensor_shape.matrix(None, 2) for _ in range(num_feature_list_sparse)]
  feature_list_sparse_value_shapes = [
      tensor_shape.vector(None) for _ in range(num_feature_list_sparse)]
  feature_list_sparse_shape_shapes = [
      tensor_shape.vector(2) for _ in range(num_feature_list_sparse)]
  feature_list_dense_shapes = [
      tensor_shape.vector(None).concatenate(dense_shape)
      for dense_shape in feature_list_dense_shapes]
  assert num_context_dense == len(context_dense_shapes)
  assert num_feature_list_dense == len(feature_list_dense_shapes)
  return (context_sparse_index_shapes + context_sparse_value_shapes +
          context_sparse_shape_shapes + context_dense_shapes +
          feature_list_sparse_index_shapes + feature_list_sparse_value_shapes +
          feature_list_sparse_shape_shapes + feature_list_dense_shapes)
  def testShapes(self):
    """function_def_to_graph applies, propagates and validates input_shapes."""
    func_def = self._build_function_def()

    # With no input_shapes, every input and output has unknown rank.
    graph = function_def_to_graph.function_def_to_graph(func_def)
    for tensor in (graph.inputs[0], graph.inputs[1],
                   graph.outputs[0], graph.outputs[1]):
      self.assertIsNone(tensor.shape.dims)  # Unknown dims.

    # Fully specified input shapes flow through to the outputs.
    vector_5 = tensor_shape.vector(5)
    graph = function_def_to_graph.function_def_to_graph(
        func_def, input_shapes=[vector_5, tensor_shape.vector(5)])
    for tensor in (graph.inputs[0], graph.inputs[1],
                   graph.outputs[0], graph.outputs[1]):
      self.assertSequenceEqual(tensor.shape.dims, [5])

    # A None entry leaves the corresponding input unconstrained.
    graph = function_def_to_graph.function_def_to_graph(
        func_def, input_shapes=[None, tensor_shape.matrix(5, 7)])
    self.assertIsNone(graph.inputs[0].shape.dims)
    for tensor in (graph.inputs[1], graph.outputs[0], graph.outputs[1]):
      self.assertSequenceEqual(tensor.shape.dims, [5, 7])

    # input_shapes must supply exactly one entry per
    # FunctionDef.signature.input_arg; otherwise a ValueError is raised.
    with self.assertRaises(ValueError):
      function_def_to_graph.function_def_to_graph(
          func_def, input_shapes=[tensor_shape.matrix(5, 7)])
예제 #3
0
def _ParseSingleSequenceExampleShape(op):
    """Shape function for the ParseSingleSequenceExample op.

    Args:
        op: The ParseSingleSequenceExample ``Operation``. Input 0 must be a
            scalar and input 1 (feature_list_dense_missing_assumed_empty)
            must be a vector.

    Returns:
        A list of ``TensorShape``s in op-output order: context sparse
        indices ``[None, 1]``, values ``[None]`` and shapes ``[1]`` (one of
        each per context sparse key); the ``context_dense_shapes`` entries;
        feature-list sparse indices ``[None, 2]``, values ``[None]`` and
        shapes ``[2]`` (one of each per feature-list sparse key); and
        ``[None] + feature_list_dense_shapes[i]`` for each feature-list
        dense key.

    Raises:
        ValueError: If input 0 is not rank 0 or input 1 is not rank 1.
        AssertionError: If ``Ncontext_dense`` / ``Nfeature_list_dense``
            disagree with the number of declared dense shapes.
    """
    op.inputs[0].get_shape().with_rank(0)  # input
    # feature_list_dense_missing_assumed_empty
    op.inputs[1].get_shape().with_rank(1)
    num_context_sparse = op.get_attr("Ncontext_sparse")
    num_context_dense = op.get_attr("Ncontext_dense")
    num_feature_list_dense = op.get_attr("Nfeature_list_dense")
    context_dense_shapes = op.get_attr("context_dense_shapes")
    num_feature_list_sparse = op.get_attr("Nfeature_list_sparse")
    feature_list_dense_shapes = op.get_attr("feature_list_dense_shapes")
    # Context sparse features are 1-D: their indices have one column.
    context_sparse_index_shapes = [tensor_shape.matrix(None, 1) for _ in range(num_context_sparse)]
    context_sparse_value_shapes = [tensor_shape.vector(None) for _ in range(num_context_sparse)]
    context_sparse_shape_shapes = [tensor_shape.vector(1) for _ in range(num_context_sparse)]
    context_dense_shapes = [tensor_shape.TensorShape(dense_shape) for dense_shape in context_dense_shapes]
    # Feature-list sparse features gain a leading dimension: two index columns.
    feature_list_sparse_index_shapes = [tensor_shape.matrix(None, 2) for _ in range(num_feature_list_sparse)]
    feature_list_sparse_value_shapes = [tensor_shape.vector(None) for _ in range(num_feature_list_sparse)]
    feature_list_sparse_shape_shapes = [tensor_shape.vector(2) for _ in range(num_feature_list_sparse)]
    feature_list_dense_shapes = [
        tensor_shape.vector(None).concatenate(dense_shape) for dense_shape in feature_list_dense_shapes
    ]
    assert num_context_dense == len(context_dense_shapes)
    assert num_feature_list_dense == len(feature_list_dense_shapes)
    return (
        context_sparse_index_shapes
        + context_sparse_value_shapes
        + context_sparse_shape_shapes
        + context_dense_shapes
        + feature_list_sparse_index_shapes
        + feature_list_sparse_value_shapes
        + feature_list_sparse_shape_shapes
        + feature_list_dense_shapes
    )
예제 #4
0
    def testShapes(self):
        """Checks how input_shapes shapes the graph built from a FunctionDef."""
        function_def = self._build_function_def()
        convert = function_def_to_graph.function_def_to_graph

        # No input_shapes: every input and output is of unknown rank.
        graph = convert(function_def)
        endpoints = [graph.inputs[0], graph.inputs[1],
                     graph.outputs[0], graph.outputs[1]]
        for endpoint in endpoints:
            self.assertIsNone(endpoint.shape.dims)  # Unknown dims.

        # Two length-5 vectors in: everything is a length-5 vector.
        graph = convert(
            function_def,
            input_shapes=[tensor_shape.vector(5), tensor_shape.vector(5)])
        endpoints = [graph.inputs[0], graph.inputs[1],
                     graph.outputs[0], graph.outputs[1]]
        for endpoint in endpoints:
            self.assertSequenceEqual(endpoint.shape.dims, [5])

        # None for the first input keeps it unconstrained; the second drives
        # the outputs.
        graph = convert(function_def,
                        input_shapes=[None, tensor_shape.matrix(5, 7)])
        self.assertIsNone(graph.inputs[0].shape.dims)
        endpoints = [graph.inputs[1], graph.outputs[0], graph.outputs[1]]
        for endpoint in endpoints:
            self.assertSequenceEqual(endpoint.shape.dims, [5, 7])

        # A length mismatch between input_shapes and
        # FunctionDef.signature.input_arg must raise ValueError.
        with self.assertRaises(ValueError):
            convert(function_def, input_shapes=[tensor_shape.matrix(5, 7)])
예제 #5
0
def _CTCGreedyDecoderShape(op):
  """Shape function for the CTCGreedyDecoder op.

  Args:
    op: The CTCGreedyDecoder `Operation`. Input 0 must have rank 3, with the
      batch size in dimension 1; input 1 (sequence lengths) must have rank 1,
      and its single dimension is also the batch size.

  Returns:
    The shapes of decoded_indices `[None, 2]`, decoded_values `[None]`,
    decoded_shape `[2]` and log_probability `[batch_size, 1]`, in that order.

  Raises:
    ValueError: If an input has the wrong rank, or the two batch sizes are
      incompatible.
  """
  inputs_shape = op.inputs[0].get_shape().with_rank(3)
  sequence_length_shape = op.inputs[1].get_shape().with_rank(1)
  # Dimension.merge_with returns a new Dimension and leaves both operands
  # unchanged, so its result must be kept. That both checks the two batch
  # sizes are compatible and lets a batch size known only from
  # sequence_length reach the log_probability output.
  batch_size = inputs_shape[1].merge_with(sequence_length_shape[0])
  # decoded_indices, decoded_values, decoded_shape, log_probability
  return [tensor_shape.matrix(None, 2),
          tensor_shape.vector(None),
          tensor_shape.vector(2),
          tensor_shape.matrix(batch_size, 1)]
예제 #6
0
def _CTCGreedyDecoderShape(op):
  """Shape function for the CTCGreedyDecoder op.

  Args:
    op: The CTCGreedyDecoder `Operation`. Input 0 must have rank 3, with the
      batch size in dimension 1; input 1 (sequence lengths) must have rank 1,
      and its single dimension is also the batch size.

  Returns:
    The shapes of decoded_indices `[None, 2]`, decoded_values `[None]`,
    decoded_shape `[2]` and log_probability `[batch_size, 1]`, in that order.

  Raises:
    ValueError: If an input has the wrong rank, or the two batch sizes are
      incompatible.
  """
  inputs_shape = op.inputs[0].get_shape().with_rank(3)
  sequence_length_shape = op.inputs[1].get_shape().with_rank(1)
  # Dimension.merge_with returns a new Dimension rather than updating either
  # operand. Keep the result so the batch size is validated AND as specific
  # as either input makes it.
  batch_size = inputs_shape[1].merge_with(sequence_length_shape[0])
  # decoded_indices, decoded_values, decoded_shape, log_probability
  return [tensor_shape.matrix(None, 2),
          tensor_shape.vector(None),
          tensor_shape.vector(2),
          tensor_shape.matrix(batch_size, 1)]
예제 #7
0
    def testBroadcast_many_dimensions(self):
        """Static broadcasting rules for fully known shapes of rank <= 2."""
        unknown = tensor_shape.unknown_shape()
        shape_0 = tensor_shape.scalar()
        shape_1 = tensor_shape.vector(1)
        shape_4 = tensor_shape.vector(4)
        shape_1x4 = tensor_shape.matrix(1, 4)
        shape_4x1 = tensor_shape.matrix(4, 1)
        shape_3x4 = tensor_shape.matrix(3, 4)
        shape_4x3 = tensor_shape.matrix(4, 3)
        shape_4x4 = tensor_shape.matrix(4, 4)
        larger_shapes = (shape_4, shape_1x4, shape_4x1, shape_3x4, shape_4x3)

        # Broadcasting a shape against itself yields that same shape.
        for shape in (shape_0, shape_1) + larger_shapes:
            self._assert_broadcast(expected=shape, shape1=shape, shape2=shape)

        # [] and [1] behave as the identity element.
        for identity in (shape_0, shape_1):
            for shape in larger_shapes:
                self._assert_broadcast(
                    expected=shape, shape1=identity, shape2=shape)

        # Broadcasting against an unknown shape gives an unknown shape.
        for shape in larger_shapes:
            self._assert_broadcast(
                expected=unknown, shape1=shape, shape2=unknown)

        # (expected, shape1, shape2); expected None marks an incompatible pair.
        pair_cases = (
            (shape_1x4, shape_4, shape_1x4),
            (shape_4x4, shape_4, shape_4x1),
            (shape_3x4, shape_4, shape_3x4),
            (None, shape_4, shape_4x3),
            (shape_4x4, shape_1x4, shape_4x1),
            (shape_3x4, shape_1x4, shape_3x4),
            (None, shape_1x4, shape_4x3),
            (None, shape_4x1, shape_3x4),
            (shape_4x3, shape_4x1, shape_4x3),
            (None, shape_3x4, shape_4x3),
        )
        for expected, lhs, rhs in pair_cases:
            if expected is None:
                self._assert_incompatible_broadcast(shape1=lhs, shape2=rhs)
            else:
                self._assert_broadcast(
                    expected=expected, shape1=lhs, shape2=rhs)
예제 #8
0
def _SparseConcatShape(op):
    """Shape function for SparseConcat op.

    The op receives ``N`` sparse tensors as three flattened input lists:
    ``N`` indices matrices, then ``N`` values vectors, then ``N``
    dense-shape vectors.

    Args:
        op: The SparseConcat ``Operation``; its ``N`` attr is the number of
            sparse tensors being concatenated.

    Returns:
        ``[indices_shape, values_shape, shape_shape]``: the output indices
        are ``[total_nnz, ndims]``, the output values are ``[total_nnz]``,
        and the output dense-shape vector is the merge of every input's
        dense-shape vector. ``total_nnz`` is the sum of the per-input
        non-zero counts (unknown if any of them is unknown).

    Raises:
        ValueError: If an input has the wrong rank, or the per-input row
            counts, column counts or shape-vector lengths are incompatible.
    """
    num_inputs = int(op.get_attr("N"))

    # TF flattens and concatenates all list inputs, so reconstruct the lists here.
    ind_shapes = [
        ind.get_shape().with_rank(2) for ind in op.inputs[0:num_inputs]
    ]
    val_shapes = [
        val.get_shape().with_rank(1)
        for val in op.inputs[num_inputs:2 * num_inputs]
    ]
    shape_shapes = [
        shape.get_shape().with_rank(1) for shape in op.inputs[2 * num_inputs:]
    ]

    output_ind_rows = tensor_shape.Dimension(0)
    output_ind_cols = tensor_shape.Dimension(None)
    output_shape_shape = tensor_shape.TensorShape(None)

    # Iterate with zip() instead of xrange(): it walks the three lists in
    # lockstep and does not depend on the Python-2-only xrange builtin.
    for ind_shape, val_shape, shape_shape in zip(ind_shapes, val_shapes,
                                                 shape_shapes):
        # Each input's indices row count and values length are both its
        # non-zero count, so they must agree.
        num_elems_i = ind_shape[0].merge_with(val_shape[0])
        output_ind_rows += num_elems_i
        output_ind_cols = output_ind_cols.merge_with(ind_shape[1])
        output_shape_shape = output_shape_shape.merge_with(shape_shape)

    output_ind_shape = tensor_shape.matrix(output_ind_rows, output_ind_cols)
    # The output has one value per index row, so the values length is the
    # same accumulated non-zero count.
    output_val_shape = tensor_shape.vector(output_ind_rows)

    return [output_ind_shape, output_val_shape, output_shape_shape]
예제 #9
0
def _SerializeManySparseShape(op):  # pylint: disable=invalid-name
    """Shape function for SerializeManySparse op.

    Only the ranks of the inputs are checked: input 0 (presumably the sparse
    indices) must be a matrix, and inputs 1 and 2 (presumably values and
    dense shape) must be vectors.

    Args:
        op: The SerializeManySparse ``Operation``.

    Returns:
        A one-element list: the output is a ``[None, 3]`` matrix (an unknown
        number of rows with three columns each).

    Raises:
        ValueError: If any input has the wrong rank.
    """
    op.inputs[0].get_shape().with_rank(2)
    op.inputs[1].get_shape().with_rank(1)
    op.inputs[2].get_shape().with_rank(1)

    return [tensor_shape.matrix(None, 3)]
예제 #10
0
파일: rnn.py 프로젝트: 4chin/tensorflow
def _reverse_seq(input_seq, lengths):
  """Reverse a list of Tensors up to specified lengths.

  Args:
    input_seq: Sequence of seq_len tensors of dimension (batch_size, depth)
    lengths:   A tensor of dimension batch_size, containing lengths for each
               sequence in the batch. If "None" is specified, simply reverses
               the list.

  Returns:
    time-reversed sequence
  """
  if lengths is None:
    return list(reversed(input_seq))

  input_shape = tensor_shape.matrix(None, None)
  for input_ in input_seq:
    input_shape.merge_with(input_.get_shape())
    input_.set_shape(input_shape)

  # Join into (time, batch_size, depth)
  s_joined = array_ops.pack(input_seq)

  # TODO(schuster, ebrevdo): Remove cast when reverse_sequence takes int32
  if lengths is not None:
    lengths = math_ops.to_int64(lengths)

  # Reverse along dimension 0
  s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)
  # Split again into list
  result = array_ops.unpack(s_reversed)
  for r in result:
    r.set_shape(input_shape)
  return result
예제 #11
0
def _reverse_seq(input_seq, lengths):
    """Reverse a list of Tensors up to specified lengths.

  Args:
    input_seq: Sequence of seq_len tensors of dimension (batch_size, depth)
    lengths:   A tensor of dimension batch_size, containing lengths for each
               sequence in the batch. If "None" is specified, simply reverses
               the list.

  Returns:
    time-reversed sequence
  """
    if lengths is None:
        return list(reversed(input_seq))

    input_shape = tensor_shape.matrix(None, None)
    for input_ in input_seq:
        input_shape.merge_with(input_.get_shape())
        input_.set_shape(input_shape)

    # Join into (time, batch_size, depth)
    s_joined = array_ops.pack(input_seq)

    # TODO(schuster, ebrevdo): Remove cast when reverse_sequence takes int32
    if lengths is not None:
        lengths = math_ops.to_int64(lengths)

    # Reverse along dimension 0
    s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)
    # Split again into list
    result = array_ops.unpack(s_reversed)
    for r in result:
        r.set_shape(input_shape)
    return result
def _ComputeAccidentalHitsShape(op):
    """Shape function for the ComputeAccidentalHits op.

    Args:
        op: The ``Operation``. Input 0 (the true candidates) must be
            compatible with ``[None, num_true]``, where ``num_true`` is the
            op's attr of that name.

    Returns:
        A three-element list of unknown-length vector shapes (``[None]``),
        one per output.

    Raises:
        ValueError: If input 0 is incompatible with ``[None, num_true]``.
    """
    num_true = op.get_attr("num_true")
    # Validate that the input shape matches the attrs, even though it
    # does not influence the shape of the output: merge_with raises on a
    # mismatch, and its merged result is intentionally not needed.
    op.inputs[0].get_shape().merge_with(tensor_shape.matrix(None, num_true))
    output_shape = tensor_shape.vector(None)
    return [output_shape] * 3
예제 #13
0
def _SparseConcatShape(op):
  """Shape function for SparseConcat op.

  The op receives `N` sparse tensors as three flattened input lists: `N`
  indices matrices, then `N` values vectors, then `N` dense-shape vectors.

  Args:
    op: The SparseConcat `Operation`; its `N` attr is the number of sparse
      tensors being concatenated.

  Returns:
    `[indices_shape, values_shape, shape_shape]`: the output indices are
    `[total_nnz, ndims]`, the output values are `[total_nnz]`, and the output
    dense-shape vector is the merge of every input's dense-shape vector.
    `total_nnz` is the sum of the per-input non-zero counts (unknown if any
    of them is unknown).

  Raises:
    ValueError: If an input has the wrong rank, or the per-input row counts,
      column counts or shape-vector lengths are incompatible.
  """
  num_inputs = int(op.get_attr("N"))

  # TF flattens and concatenates all list inputs, so reconstruct the lists here.
  ind_shapes = [ind.get_shape().with_rank(2) for ind in op.inputs[0:num_inputs]]
  val_shapes = [val.get_shape().with_rank(1)
                for val in op.inputs[num_inputs:2 * num_inputs]]
  shape_shapes = [shape.get_shape().with_rank(1)
                  for shape in op.inputs[2 * num_inputs:]]

  output_ind_rows = tensor_shape.Dimension(0)
  output_ind_cols = tensor_shape.Dimension(None)
  output_shape_shape = tensor_shape.TensorShape(None)

  # zip() walks the three lists in lockstep and, unlike xrange(), also
  # exists on Python 3.
  for ind_shape, val_shape, shape_shape in zip(ind_shapes, val_shapes,
                                               shape_shapes):
    # An input's indices row count and values length are both its nnz.
    num_elems_i = ind_shape[0].merge_with(val_shape[0])
    output_ind_rows += num_elems_i
    output_ind_cols = output_ind_cols.merge_with(ind_shape[1])
    output_shape_shape = output_shape_shape.merge_with(shape_shape)

  output_ind_shape = tensor_shape.matrix(output_ind_rows, output_ind_cols)
  # One output value per index row: the values length equals the row count.
  output_val_shape = tensor_shape.vector(output_ind_rows)

  return [output_ind_shape, output_val_shape, output_shape_shape]
예제 #14
0
def _SerializeManySparseShape(op):  # pylint: disable=invalid-name
  """Shape function for SerializeManySparse op.

  Only input ranks are checked: input 0 (presumably the sparse indices) must
  be a matrix, and inputs 1 and 2 (presumably values and dense shape) must be
  vectors.

  Args:
    op: The SerializeManySparse `Operation`.

  Returns:
    A one-element list: the output is a `[None, 3]` matrix (an unknown number
    of rows, three columns each).

  Raises:
    ValueError: If any input has the wrong rank.
  """
  op.inputs[0].get_shape().with_rank(2)
  op.inputs[1].get_shape().with_rank(1)
  op.inputs[2].get_shape().with_rank(1)

  return [tensor_shape.matrix(None, 3)]
예제 #15
0
 def testHelpers(self):
   """scalar/vector/matrix helpers agree with explicit TensorShape lists."""
   pairs = (
       ([], tensor_shape.scalar()),
       ([37], tensor_shape.vector(37)),
       ([94, 43], tensor_shape.matrix(94, 43)),
   )
   for dims, helper_shape in pairs:
     tensor_shape.TensorShape(dims).assert_is_compatible_with(helper_shape)
예제 #16
0
 def testHelpers(self):
     """Each helper constructor matches the equivalent explicit TensorShape."""
     scalar_shape = tensor_shape.scalar()
     tensor_shape.TensorShape([]).assert_is_compatible_with(scalar_shape)

     vector_shape = tensor_shape.vector(37)
     tensor_shape.TensorShape([37]).assert_is_compatible_with(vector_shape)

     matrix_shape = tensor_shape.matrix(94, 43)
     tensor_shape.TensorShape([94, 43]).assert_is_compatible_with(
         matrix_shape)
예제 #17
0
def _CandidateSamplerShape(op):
  """Shape function for a candidate-sampler op.

  Input 0 (true classes) must be a matrix; its first dimension is taken as
  the batch size.

  Returns:
    Three shapes, in order: `[num_sampled]`, `[batch_size, num_true]` and
    `[num_sampled]`, where `num_sampled` and `num_true` are op attrs.

  Raises:
    ValueError: If input 0 is not rank 2.
  """
  batch_size = op.inputs[0].get_shape().with_rank(2)[0]
  num_sampled = op.get_attr("num_sampled")
  num_true = op.get_attr("num_true")
  per_sample_shape = tensor_shape.vector(num_sampled)
  per_true_class_shape = tensor_shape.matrix(batch_size, num_true)
  return [per_sample_shape, per_true_class_shape,
          tensor_shape.vector(num_sampled)]
예제 #18
0
def _ComputeAccidentalHitsShape(op):
  """Shape function for the ComputeAccidentalHits op.

  Args:
    op: The `Operation`. Input 0 (the true candidates) must be compatible with
      `[None, num_true]`, where `num_true` is the op's attr of that name.

  Returns:
    A three-element list of unknown-length vector shapes (`[None]`), one per
    output.

  Raises:
    ValueError: If input 0 is incompatible with `[None, num_true]`.
  """
  num_true = op.get_attr("num_true")
  # Validate that the input shape matches the attrs, even though it
  # does not influence the shape of the output: merge_with raises on a
  # mismatch, and the merged shape itself is not needed.
  op.inputs[0].get_shape().merge_with(tensor_shape.matrix(None, num_true))
  output_shape = tensor_shape.vector(None)
  return [output_shape] * 3
예제 #19
0
def _DeserializeSparseShape(op):  # pylint: disable=invalid-name
  """Shape function for DeserializeManySparse op.

  Input 0 must be a rank-2 matrix that is compatible with `[None, 3]`.

  Returns:
    Three shapes of unknown size: a `[None, None]` matrix followed by two
    `[None]` vectors.

  Raises:
    ValueError: If input 0 is not rank 2 or does not have 3 columns.
  """
  rank_checked = op.inputs[0].get_shape().with_rank(2)
  expected = tensor_shape.TensorShape([None, 3])
  rank_checked.merge_with(expected)  # Validation only; raises on mismatch.

  indices_shape = tensor_shape.matrix(None, None)
  values_shape = tensor_shape.vector(None)
  dense_shape_shape = tensor_shape.vector(None)
  return [indices_shape, values_shape, dense_shape_shape]
예제 #20
0
def _SparseTensorDenseMatMulShape(op):  # pylint: disable=invalid-name
  """Shape function for SparseTensorDenseMatMul op.

  Inputs are a_indices (rank 2), a_values (rank 1), a_shape (compatible with
  `[2]`) and the dense matrix b (rank 2).

  Returns:
    A one-element list `[None, n]`, where `n` is b's dimension 0 when the
    `adjoint_b` attr is set and b's dimension 1 otherwise.

  Raises:
    ValueError: If any input has an incompatible shape.
  """
  adjoint_b = op.get_attr("adjoint_b")
  inputs = op.inputs
  inputs[0].get_shape().assert_has_rank(2)  # a_indices
  inputs[1].get_shape().assert_has_rank(1)  # a_values
  inputs[2].get_shape().merge_with(tensor_shape.vector(2))  # a_shape
  b_shape = inputs[3].get_shape().with_rank(2)
  if adjoint_b:
    output_cols = b_shape[0]
  else:
    output_cols = b_shape[1]
  return [tensor_shape.matrix(None, output_cols)]
예제 #21
0
def _DeserializeSparseShape(op):  # pylint: disable=invalid-name
  """Shape function for DeserializeManySparse op.

  Checks that input 0 is a matrix compatible with `[None, 3]`; the outputs
  are all of unknown size.

  Returns:
    `[matrix(None, None), vector(None), vector(None)]`.

  Raises:
    ValueError: If input 0 is not a rank-2 matrix with 3 columns.
  """
  (op.inputs[0].get_shape()
   .with_rank(2)
   .merge_with(tensor_shape.TensorShape([None, 3])))

  return [
      tensor_shape.matrix(None, None),
      tensor_shape.vector(None),
      tensor_shape.vector(None),
  ]
예제 #22
0
def _SparseTensorDenseMatMulShape(op):  # pylint: disable=invalid-name
    """Shape function for SparseTensorDenseMatMul op.

    Validates the sparse operand's components (indices rank 2, values rank 1,
    dense shape compatible with ``[2]``) and the dense operand b (rank 2).

    Returns:
        ``[TensorShape([None, n])]`` where ``n`` is ``b.shape[0]`` if the
        ``adjoint_b`` attr is true and ``b.shape[1]`` otherwise.

    Raises:
        ValueError: If any input has an incompatible shape.
    """
    use_b_rows = op.get_attr("adjoint_b")
    a_indices, a_values, a_shape = op.inputs[0], op.inputs[1], op.inputs[2]
    a_indices.get_shape().assert_has_rank(2)
    a_values.get_shape().assert_has_rank(1)
    a_shape.get_shape().merge_with(tensor_shape.vector(2))
    b_dims = op.inputs[3].get_shape().with_rank(2)
    right_dim = b_dims[0] if use_b_rows else b_dims[1]
    return [tensor_shape.matrix(None, right_dim)]
예제 #23
0
def _CTCBeamSearchDecoderShape(op):
  """Shape function for the CTCBeamSearchDecoder op.

  Args:
    op: The CTCBeamSearchDecoder `Operation`. Input 0 must have rank 3, with
      the batch size in dimension 1; input 1 (sequence lengths) must have
      rank 1, and its single dimension is also the batch size.

  Returns:
    `top_paths` decoded-index shapes `[None, 2]`, then `top_paths`
    decoded-value shapes `[None]`, then `top_paths` decoded-shape shapes
    `[2]`, and finally the log-probability shape `[batch_size, top_paths]`.

  Raises:
    ValueError: If an input has the wrong rank, or the two batch sizes are
      incompatible.
  """
  inputs_shape = op.inputs[0].get_shape().with_rank(3)
  sequence_length_shape = op.inputs[1].get_shape().with_rank(1)
  # Dimension.merge_with returns a new Dimension and does not update its
  # operands, so keep the result: it validates the batch sizes and carries
  # whichever of the two is known into the log_probability shape.
  batch_size = inputs_shape[1].merge_with(sequence_length_shape[0])
  top_paths = op.get_attr("top_paths")

  # first the decoded indices
  output_shapes = [tensor_shape.matrix(None, 2) for _ in range(top_paths)]
  # next the decoded values
  output_shapes.extend([tensor_shape.vector(None) for _ in range(top_paths)])
  # the shapes of the decoded values
  output_shapes.extend([tensor_shape.vector(2) for _ in range(top_paths)])
  # the log_probability matrix
  output_shapes.append(tensor_shape.matrix(batch_size, top_paths))
  return output_shapes
예제 #24
0
def _CTCBeamSearchDecoderShape(op):
    """Shape function for the CTCBeamSearchDecoder op.

    Args:
        op: The CTCBeamSearchDecoder ``Operation``. Input 0 must have rank 3
            with the batch size in dimension 1; input 1 (sequence lengths)
            must have rank 1, its single dimension also being the batch size.

    Returns:
        ``top_paths`` decoded-index shapes ``[None, 2]``, ``top_paths``
        decoded-value shapes ``[None]``, ``top_paths`` decoded-shape shapes
        ``[2]``, then the log-probability shape ``[batch_size, top_paths]``.

    Raises:
        ValueError: If an input has the wrong rank, or the two batch sizes
            are incompatible.
    """
    inputs_shape = op.inputs[0].get_shape().with_rank(3)
    sequence_length_shape = op.inputs[1].get_shape().with_rank(1)
    # Dimension.merge_with returns a new Dimension rather than mutating
    # either side; keep its result so the merged batch size is actually used.
    batch_size = inputs_shape[1].merge_with(sequence_length_shape[0])
    top_paths = op.get_attr("top_paths")

    # first the decoded indices
    output_shapes = [tensor_shape.matrix(None, 2) for _ in range(top_paths)]
    # next the decoded values
    output_shapes.extend([tensor_shape.vector(None) for _ in range(top_paths)])
    # the shapes of the decoded values
    output_shapes.extend([tensor_shape.vector(2) for _ in range(top_paths)])
    # the log_probability matrix
    output_shapes.append(tensor_shape.matrix(batch_size, top_paths))
    return output_shapes
예제 #25
0
  def testStr(self):
    """str(TensorShape): '?' for unknown dims, '<unknown>' for unknown rank."""
    expectations = [
        ("<unknown>", tensor_shape.unknown_shape()),
        ("(?,)", tensor_shape.unknown_shape(ndims=1)),
        ("(?, ?)", tensor_shape.unknown_shape(ndims=2)),
        ("(?, ?, ?)", tensor_shape.unknown_shape(ndims=3)),
        ("()", tensor_shape.scalar()),
        ("(7,)", tensor_shape.vector(7)),
        ("(3, 8)", tensor_shape.matrix(3, 8)),
        ("(4, 5, 2)", tensor_shape.TensorShape([4, 5, 2])),
        ("(32, ?, 1, 9)", tensor_shape.TensorShape([32, None, 1, 9])),
    ]
    for expected, shape in expectations:
      self.assertEqual(expected, str(shape))
예제 #26
0
    def testStr(self):
        """Checks the string rendering of unknown, known and partial shapes."""
        self.assertEqual("<unknown>", str(tensor_shape.unknown_shape()))
        for ndims, expected in ((1, "(?,)"), (2, "(?, ?)"), (3, "(?, ?, ?)")):
            self.assertEqual(expected,
                             str(tensor_shape.unknown_shape(ndims=ndims)))

        known_cases = (
            ("()", tensor_shape.scalar()),
            ("(7,)", tensor_shape.vector(7)),
            ("(3, 8)", tensor_shape.matrix(3, 8)),
            ("(4, 5, 2)", tensor_shape.TensorShape([4, 5, 2])),
        )
        for expected, shape in known_cases:
            self.assertEqual(expected, str(shape))

        partial = tensor_shape.TensorShape([32, None, 1, 9])
        self.assertEqual("(32, ?, 1, 9)", str(partial))
예제 #27
0
def _ParseExampleShape(op):
    """Shape function for the ParseExample op.

    Input 0 and input 1 (names) must both be vectors; input 0's shape is the
    leading (batch) dimension of every dense output.

    Returns:
        ``Nsparse`` index shapes ``[None, 2]``, ``Nsparse`` value shapes
        ``[None]``, ``Nsparse`` shape-vector shapes ``[2]``, then one shape
        per ``dense_shapes`` entry: input 0's shape followed by that entry.

    Raises:
        ValueError: If input 0 or input 1 is not rank 1.
        AssertionError: If ``Ndense`` differs from ``len(dense_shapes)``.
    """
    batch_shape = op.inputs[0].get_shape().with_rank(1)
    op.inputs[1].get_shape().with_rank(1)  # names
    num_sparse = op.get_attr("Nsparse")
    num_dense = op.get_attr("Ndense")
    dense_shapes = op.get_attr("dense_shapes")
    sparse_output_shapes = (
        [tensor_shape.matrix(None, 2) for _ in range(num_sparse)]
        + [tensor_shape.vector(None) for _ in range(num_sparse)]
        + [tensor_shape.vector(2) for _ in range(num_sparse)]
    )
    assert num_dense == len(dense_shapes)
    dense_output_shapes = [batch_shape.concatenate(per_example) for per_example in dense_shapes]
    return sparse_output_shapes + dense_output_shapes
예제 #28
0
  def testBroadcast_many_dimensions(self):
    """Broadcast results for fully defined shapes of rank 0, 1 and 2."""
    unknown = tensor_shape.unknown_shape()
    shape_0 = tensor_shape.scalar()
    shape_1 = tensor_shape.vector(1)
    shape_4 = tensor_shape.vector(4)
    shape_1x4 = tensor_shape.matrix(1, 4)
    shape_4x1 = tensor_shape.matrix(4, 1)
    shape_3x4 = tensor_shape.matrix(3, 4)
    shape_4x3 = tensor_shape.matrix(4, 3)
    shape_4x4 = tensor_shape.matrix(4, 4)
    non_identity = (shape_4, shape_1x4, shape_4x1, shape_3x4, shape_4x3)

    # Same shape on both sides: the result is that shape.
    for shape in (shape_0, shape_1) + non_identity:
      self._assert_broadcast(expected=shape, shape1=shape, shape2=shape)

    # [] and [1] act like identity.
    for identity in (shape_0, shape_1):
      for shape in non_identity:
        self._assert_broadcast(expected=shape, shape1=identity, shape2=shape)

    # Unknown in, unknown out.
    for shape in non_identity:
      self._assert_broadcast(expected=unknown, shape1=shape, shape2=unknown)

    # Each row is (shape1, shape2, expected); expected=None means the pair
    # must be rejected as incompatible.
    for lhs, rhs, expected in (
        (shape_4, shape_1x4, shape_1x4),
        (shape_4, shape_4x1, shape_4x4),
        (shape_4, shape_3x4, shape_3x4),
        (shape_4, shape_4x3, None),
        (shape_1x4, shape_4x1, shape_4x4),
        (shape_1x4, shape_3x4, shape_3x4),
        (shape_1x4, shape_4x3, None),
        (shape_4x1, shape_3x4, None),
        (shape_4x1, shape_4x3, shape_4x3),
        (shape_3x4, shape_4x3, None),
    ):
      if expected is not None:
        self._assert_broadcast(expected=expected, shape1=lhs, shape2=rhs)
      else:
        self._assert_incompatible_broadcast(shape1=lhs, shape2=rhs)
예제 #29
0
def _ParseExampleShape(op):
  """Shape function for the ParseExample op.

  Only input 0 is checked; it must be a vector and its shape becomes the
  leading dimension of every dense output. Each `dense_shapes` entry is read
  through its `dim` list's `size` fields (presumably a `TensorShapeProto`;
  confirm against the op registration).

  Returns:
    `Nsparse` index shapes `[None, 2]`, `Nsparse` value shapes `[None]`,
    `Nsparse` shape-vector shapes `[2]`, then one dense output shape per
    `dense_shapes` entry.

  Raises:
    ValueError: If input 0 is not rank 1.
    AssertionError: If `Ndense` differs from `len(dense_shapes)`.
  """
  input_shape = op.inputs[0].get_shape().with_rank(1)
  num_sparse = op.get_attr("Nsparse")
  num_dense = op.get_attr("Ndense")
  dense_shapes = op.get_attr("dense_shapes")

  index_shapes, value_shapes, shape_shapes = [], [], []
  for _ in range(num_sparse):
    index_shapes.append(tensor_shape.matrix(None, 2))
    value_shapes.append(tensor_shape.vector(None))
    shape_shapes.append(tensor_shape.vector(2))

  assert num_dense == len(dense_shapes)
  dense_output_shapes = []
  for shape_proto in dense_shapes:
    dim_sizes = (d.size for d in shape_proto.dim)
    dense_output_shapes.append(input_shape.concatenate(dim_sizes))

  return index_shapes + value_shapes + shape_shapes + dense_output_shapes
예제 #30
0
def _ParseExampleShape(op):  # pylint: disable=invalid-name
  """Shape function for the ParseExample op.

  Input 0 and input 1 (names) must be vectors. Every dense output has input
  0's shape prepended to its declared `dense_shapes` entry.

  Returns:
    A flat list: `Nsparse` shapes `[None, 2]`, `Nsparse` shapes `[None]`,
    `Nsparse` shapes `[2]`, then the dense output shapes.

  Raises:
    ValueError: If input 0 or input 1 is not rank 1.
    AssertionError: If `Ndense` differs from `len(dense_shapes)`.
  """
  serialized_shape = op.inputs[0].get_shape().with_rank(1)
  op.inputs[1].get_shape().with_rank(1)  # names
  sparse_count = op.get_attr("Nsparse")
  dense_count = op.get_attr("Ndense")
  declared_dense = op.get_attr("dense_shapes")

  result = [tensor_shape.matrix(None, 2) for _ in range(sparse_count)]
  result += [tensor_shape.vector(None) for _ in range(sparse_count)]
  result += [tensor_shape.vector(2) for _ in range(sparse_count)]
  assert dense_count == len(declared_dense)
  result += [serialized_shape.concatenate(shape) for shape in declared_dense]
  return result
예제 #31
0
 def testStr(self):
   """str(TensorShape) output, with '?' normalized to 'None' where needed."""
   def render(shape, normalize):
     text = str(shape)
     return text.replace("?", "None") if normalize else text

   # (expected, shape, whether to normalize '?' to 'None')
   expectations = (
       ("<unknown>", tensor_shape.unknown_shape(), False),
       ("(None,)", tensor_shape.unknown_shape(rank=1), True),
       ("(None, None)", tensor_shape.unknown_shape(rank=2), True),
       ("(None, None, None)", tensor_shape.unknown_shape(rank=3), True),
       ("(32, None, 1, 9)", tensor_shape.TensorShape([32, None, 1, 9]), True),
       ("()", tensor_shape.scalar(), False),
       ("(7,)", tensor_shape.vector(7), False),
       ("(3, 8)", tensor_shape.matrix(3, 8), False),
       ("(4, 5, 2)", tensor_shape.TensorShape([4, 5, 2]), False),
   )
   for expected, shape, normalize in expectations:
     self.assertEqual(expected, render(shape, normalize))
예제 #32
0
 def testStr(self):
     """Checks str() of unknown, partially known and fully known shapes."""

     def as_text(shape):
         # Unknown dims may print as "?"; compare them spelled as "None".
         return str(shape).replace("?", "None")

     self.assertEqual("<unknown>", str(tensor_shape.unknown_shape()))
     for rank, expected in ((1, "(None,)"),
                            (2, "(None, None)"),
                            (3, "(None, None, None)")):
         self.assertEqual(expected,
                          as_text(tensor_shape.unknown_shape(rank=rank)))
     self.assertEqual("(32, None, 1, 9)",
                      as_text(tensor_shape.TensorShape([32, None, 1, 9])))
     for expected, shape in (
             ("()", tensor_shape.scalar()),
             ("(7,)", tensor_shape.vector(7)),
             ("(3, 8)", tensor_shape.matrix(3, 8)),
             ("(4, 5, 2)", tensor_shape.TensorShape([4, 5, 2]))):
         self.assertEqual(expected, str(shape))
예제 #33
0
  def testBroadcast_unknown_dims(self):
    """Broadcast results when some dimensions are unknown."""
    unknown = tensor_shape.unknown_shape()
    shape_0 = tensor_shape.scalar()
    shape_1 = tensor_shape.vector(1)
    # pylint: disable=invalid-name
    shape_U = tensor_shape.vector(None)
    shape_1xU = tensor_shape.matrix(1, None)
    shape_Ux1 = tensor_shape.matrix(None, 1)
    shape_4xU = tensor_shape.matrix(4, None)
    shape_Ux4 = tensor_shape.matrix(None, 4)
    shape_UxU = tensor_shape.matrix(None, None)
    # pylint: enable=invalid-name
    shape_4x4 = tensor_shape.matrix(4, 4)
    partial_shapes = (shape_U, shape_1xU, shape_Ux1, shape_4xU, shape_Ux4)
    check = self._assert_broadcast_with_unknown_dims

    # A shape broadcast against itself is unchanged.
    for shape in partial_shapes:
      check(expected=shape, shape1=shape, shape2=shape)

    # [] and [1] act like identity.
    for identity in (shape_0, shape_1):
      for shape in partial_shapes:
        check(expected=shape, shape1=identity, shape2=shape)

    # Unknown in, unknown out.
    for shape in partial_shapes:
      check(expected=unknown, shape1=shape, shape2=unknown)

    # (expected, shape1, shape2) for mixed pairs.
    mixed_cases = (
        (shape_1xU, shape_U, shape_1xU),
        (shape_UxU, shape_U, shape_Ux1),
        (shape_4xU, shape_U, shape_4xU),
        (shape_Ux4, shape_U, shape_Ux4),
        (shape_UxU, shape_1xU, shape_Ux1),
        (shape_4xU, shape_1xU, shape_4xU),
        (shape_Ux4, shape_1xU, shape_Ux4),
        (shape_4xU, shape_Ux1, shape_4xU),
        (shape_Ux4, shape_Ux1, shape_Ux4),
        (shape_4x4, shape_4xU, shape_Ux4),
    )
    for expected, lhs, rhs in mixed_cases:
      check(expected=expected, shape1=lhs, shape2=rhs)
예제 #34
0
  def testBroadcast_unknown_dims(self):
    """Static broadcasting where some dimensions are statically unknown."""
    unknown = tensor_shape.unknown_shape()
    identities = (tensor_shape.scalar(), tensor_shape.vector(1))
    # pylint: disable=invalid-name
    shape_U = tensor_shape.vector(None)
    shape_1xU = tensor_shape.matrix(1, None)
    shape_Ux1 = tensor_shape.matrix(None, 1)
    shape_4xU = tensor_shape.matrix(4, None)
    shape_Ux4 = tensor_shape.matrix(None, 4)
    shape_UxU = tensor_shape.matrix(None, None)
    # pylint: enable=invalid-name
    shape_4x4 = tensor_shape.matrix(4, 4)
    candidates = (shape_U, shape_1xU, shape_Ux1, shape_4xU, shape_Ux4)

    # Broadcasting with itself, with an identity, and with an unknown shape.
    for shape in candidates:
      self._assert_broadcast_with_unknown_dims(
          expected=shape, shape1=shape, shape2=shape)
    for identity in identities:
      for shape in candidates:
        self._assert_broadcast_with_unknown_dims(
            expected=shape, shape1=identity, shape2=shape)
    for shape in candidates:
      self._assert_broadcast_with_unknown_dims(
          expected=unknown, shape1=shape, shape2=unknown)

    # Pairs of distinct partially known shapes: (shape1, shape2) -> expected.
    for (lhs, rhs), expected in (
        ((shape_U, shape_1xU), shape_1xU),
        ((shape_U, shape_Ux1), shape_UxU),
        ((shape_U, shape_4xU), shape_4xU),
        ((shape_U, shape_Ux4), shape_Ux4),
        ((shape_1xU, shape_Ux1), shape_UxU),
        ((shape_1xU, shape_4xU), shape_4xU),
        ((shape_1xU, shape_Ux4), shape_Ux4),
        ((shape_Ux1, shape_4xU), shape_4xU),
        ((shape_Ux1, shape_Ux4), shape_Ux4),
        ((shape_4xU, shape_Ux4), shape_4x4),
    ):
      self._assert_broadcast_with_unknown_dims(
          expected=expected, shape1=lhs, shape2=rhs)
예제 #35
0
def _PadShape(op):
  """Infers the output shape of a Pad op.

  The op consumes two inputs:

  * input: a tensor of rank N.
  * paddings: an N-by-2 matrix; row i holds how many padding elements to
    insert before and after `input` along dimension i.

  The single output keeps the rank of `input`, with each dimension grown by
  the corresponding before/after padding amounts.

  Args:
    op: A Pad Operation.

  Returns:
    A one-element list holding the inferred output shape.

  Raises:
    ValueError: If the input shapes are incompatible, or if a statically
      known `paddings` value contains a negative entry.
  """
  paddings_shape = op.inputs[1].get_shape().with_rank(2)
  num_rows = paddings_shape[0].value
  input_shape = op.inputs[0].get_shape()
  if input_shape.ndims != 0 or num_rows != 1:
    input_shape = input_shape.with_rank(num_rows)
  else:
    # A scalar input with a single padding row is treated as a length-1
    # vector.
    # TODO(irving): Remove once !kAllowLegacyScalars.
    input_shape = tensor_shape.TensorShape([1])
  rank = input_shape.ndims
  # Called only for its compatibility check; the merged shape is not needed.
  paddings_shape.merge_with(tensor_shape.matrix(rank, 2))

  paddings = tensor_util.ConstantValue(op.inputs[1])
  if paddings is None:
    # Without a constant `paddings` only the rank can be inferred.
    return [tensor_shape.unknown_shape(ndims=rank)]

  padded_dims = []
  for row, dim in enumerate(input_shape.dims):
    before, after = paddings[row, 0], paddings[row, 1]
    if before < 0 or after < 0:
      raise ValueError("paddings must be non-negative")
    padded_dims.append(dim + before + after)
  return [tensor_shape.TensorShape(padded_dims)]
예제 #36
0
def _PadShape(op):
    """Shape function for the Pad op.

    The op takes two inputs:

    * input: a rank-N tensor.
    * paddings: an N-by-2 matrix whose i-th row gives the number of padding
      elements to add before and after `input` in dimension i.

    The single output has the same rank as `input`; every dimension is
    enlarged by the before/after amounts taken from `paddings`.

    Args:
        op: A Pad Operation.

    Returns:
        A single-element list containing the shape of the output.

    Raises:
        ValueError: If the input shapes are incompatible, or if a statically
            known `paddings` value has a negative entry.
    """
    paddings_shape = op.inputs[1].get_shape().with_rank(2)
    input_shape = op.inputs[0].get_shape()
    is_legacy_scalar = (input_shape.ndims == 0
                        and paddings_shape[0].value == 1)
    # TODO(irving): Remove the legacy-scalar case once !kAllowLegacyScalars.
    input_shape = (tensor_shape.TensorShape([1]) if is_legacy_scalar
                   else input_shape.with_rank(paddings_shape[0].value))
    # Kept only for its compatibility check; the merged result is unused.
    paddings_shape.merge_with(tensor_shape.matrix(input_shape.ndims, 2))

    paddings = tensor_util.ConstantValue(op.inputs[1])
    if paddings is None:
        # Non-constant paddings: only the rank is known.
        return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]

    output_dims = []
    for i, dim in enumerate(input_shape.dims):
        pad_before, pad_after = paddings[i, 0], paddings[i, 1]
        if pad_before < 0 or pad_after < 0:
            raise ValueError("paddings must be non-negative")
        output_dims.append(dim + pad_before + pad_after)
    return [tensor_shape.TensorShape(output_dims)]
예제 #37
0
파일: model.py 프로젝트: nguyenlab/SentSum
def reverse_seq(input_seq, lengths):
    """Reverses a sequence of per-time-step tensors.

    Args:
        input_seq: Sequence of tensors, one per time step. Each step is
            constrained to rank 2; presumably (batch_size, depth), which
            should be confirmed against callers.
        lengths: Optional tensor of per-batch-entry sequence lengths. When
            None, the Python sequence of steps is simply reversed.

    Returns:
        A list with one tensor per time step, reversed along time. When
        `lengths` is given, the reversal is delegated to
        `array_ops.reverse_sequence` using those lengths.
    """
    if lengths is None:
        return list(reversed(input_seq))

    # `merge_with` returns a new shape instead of updating in place, so the
    # result must be reassigned for static shape information to accumulate
    # across time steps. Merge everything first so every step (and every
    # output below) receives the fully combined shape.
    input_shape = tensor_shape.matrix(None, None)
    for input_ in input_seq:
        input_shape = input_shape.merge_with(input_.get_shape())
    for input_ in input_seq:
        input_.set_shape(input_shape)

    # Join into (time, batch_size, depth)
    s_joined = array_ops.pack(input_seq)

    # TODO(schuster, ebrevdo): Remove cast when reverse_sequence takes int32
    lengths = math_ops.to_int64(lengths)

    # Reverse along dimension 0 (time), with dimension 1 (batch) indexing
    # `lengths`.
    s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)
    # Split again into list
    result = array_ops.unpack(s_reversed)
    for r in result:
        r.set_shape(input_shape)
    return result
예제 #38
0
def _MultinomialShape(op):  # pylint: disable=invalid-name
    """Infers the [batch_size, num_samples] output shape of Multinomial.

    The batch dimension is read from the rank-2 logits input. The sample
    count is the constant value of the second input, or None when that value
    is not statically known.
    """
    batch_size = op.inputs[0].get_shape().with_rank(2)[0]
    num_samples = tensor_util.constant_value(op.inputs[1])
    return [tensor_shape.matrix(batch_size, num_samples)]
def _SparseFeatureCrossShape(unused_op):  # pylint: disable=invalid-name
    """Returns the fixed output shapes of SparseFeatureCross.

    The outputs are, in order: a [?, 2] matrix, a vector of unknown length
    and a length-2 vector (presumably indices, values and dense shape of a
    2-D sparse result). None of them depend on the op's inputs.
    """
    indices_shape = tensor_shape.matrix(None, 2)
    values_shape = tensor_shape.vector(None)
    dense_shape_shape = tensor_shape.vector(2)
    return [indices_shape, values_shape, dense_shape_shape]
예제 #40
0
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
                    test_util.TensorFlowTestCase):

    # pylint: disable=g-long-lambda,protected-access
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0), tensor_spec.TensorSpec,
         [dtypes.float32], [[]]),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0),
         tensor_array_ops.TensorArraySpec, [dtypes.variant], [[]]),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
         sparse_tensor.SparseTensorSpec, [dtypes.variant], [None]),
        ("RaggedTensor",
         lambda: ragged_factory_ops.constant([[1, 2], [], [4]]),
         ragged_tensor.RaggedTensorSpec, [dtypes.variant], [None]),
        ("Nested_0", lambda:
         (constant_op.constant(37.0), constant_op.constant([1, 2, 3])), tuple,
         [dtypes.float32, dtypes.int32], [[], [3]]),
        ("Nested_1", lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, dict, [dtypes.float32, dtypes.int32], [[], [3]]),
        ("Nested_2", lambda: {
            "a":
            constant_op.constant(37.0),
            "b":
            (sparse_tensor.
             SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
             sparse_tensor.SparseTensor(
                 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
        }, dict, [dtypes.float32, dtypes.variant, dtypes.variant], [[], None,
                                                                    None]),
    )
    def testFlatStructure(self, value_fn, expected_structure, expected_types,
                          expected_shapes):
        """Checks the spec class and the flattened dtypes/shapes of a value."""
        spec = structure.type_spec_from_value(value_fn())
        self.assertIsInstance(spec, expected_structure)
        self.assertEqual(expected_types, structure.get_flat_tensor_types(spec))
        shapes = structure.get_flat_tensor_shapes(spec)
        self.assertLen(shapes, len(expected_shapes))
        for want, got in zip(expected_shapes, shapes):
            if want is not None:
                self.assertEqual(got.as_list(), want)
            else:
                # A `None` expectation means the flat shape has unknown rank.
                self.assertEqual(got.ndims, None)

    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0), lambda: [
            constant_op.constant(38.0),
            array_ops.placeholder(dtypes.float32),
            variables.Variable(100.0), 42.0,
            np.array(42.0, dtype=np.float32)
        ],
         lambda: [constant_op.constant([1.0, 2.0]),
                  constant_op.constant(37)]),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0), lambda: [
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(3, ), size=0),
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(3, ), size=10)
            ], lambda: [
                tensor_array_ops.TensorArray(
                    dtype=dtypes.int32, element_shape=(3, ), size=0),
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(), size=0)
            ]),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), lambda: [
                sparse_tensor.SparseTensor(indices=[[1, 1], [3, 4]],
                                           values=[10, -1],
                                           dense_shape=[4, 5]),
                sparse_tensor.SparseTensorValue(indices=[[1, 1], [3, 4]],
                                                values=[10, -1],
                                                dense_shape=[4, 5]),
                array_ops.sparse_placeholder(dtype=dtypes.int32),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None])
            ], lambda: [
                constant_op.constant(37, shape=[4, 5]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None, None]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
            ]),
        ("RaggedTensor",
         lambda: ragged_factory_ops.constant([[1, 2], [], [3]]), lambda: [
             ragged_factory_ops.constant([[1, 2], [3, 4], []]),
             ragged_factory_ops.constant([[1], [2, 3, 4], [5]]),
         ], lambda: [
             ragged_factory_ops.constant(1),
             ragged_factory_ops.constant([1, 2]),
             ragged_factory_ops.constant([[1], [2]]),
             ragged_factory_ops.constant([["a", "b"]]),
         ]),
        ("Nested", lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, lambda: [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6])
        }], lambda: [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6, 7])
        }, {
            "a": constant_op.constant(15),
            "b": constant_op.constant([4, 5, 6])
        }, {
            "a":
            constant_op.constant(15),
            "b":
            sparse_tensor.SparseTensor(
                indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
        }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
    )
    @test_util.run_deprecated_v1
    def testIsCompatibleWithStructure(self, original_value_fn,
                                      compatible_values_fn,
                                      incompatible_values_fn):
        """Checks `are_compatible` against compatible and incompatible values."""
        original_value = original_value_fn()
        compatible_values = compatible_values_fn()
        incompatible_values = incompatible_values_fn()
        spec = structure.type_spec_from_value(original_value)

        def _matches_spec(value):
            return structure.are_compatible(
                spec, structure.type_spec_from_value(value))

        for value in compatible_values:
            self.assertTrue(_matches_spec(value))
        for value in incompatible_values:
            self.assertFalse(_matches_spec(value))

    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         lambda: constant_op.constant(42.0),
         lambda: constant_op.constant([5])),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0),
         lambda: tensor_array_ops.TensorArray(
             dtype=dtypes.float32, element_shape=(3, ), size=0),
         lambda: tensor_array_ops.TensorArray(
             dtype=dtypes.int32, element_shape=(), size=0)),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[3]], values=[-1], dense_shape=[5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[3, 4]], values=[1.0], dense_shape=[4, 5])),
        ("RaggedTensor",
         lambda: ragged_factory_ops.constant([[[1, 2]], [[3]]]),
         lambda: ragged_factory_ops.constant([[[5]], [[8], [3, 2]]]), lambda:
         ragged_factory_ops.constant([[[1]], [[2], [3]]], ragged_rank=1),
         lambda: ragged_factory_ops.constant([[[1.0, 2.0]], [[3.0]]]),
         lambda: ragged_factory_ops.constant([[[1]], [[2]], [[3]]])),
        ("Nested", lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, lambda: {
            "a": constant_op.constant(42.0),
            "b": constant_op.constant([4, 5, 6])
        }, lambda: {
            "a": constant_op.constant([1, 2, 3]),
            "b": constant_op.constant(37.0)
        }),
    )  # pyformat: disable
    def testStructureFromValueEquality(self, value1_fn, value2_fn,
                                       *not_equal_value_fns):
        """Checks __eq__, __ne__ and hashing of specs built from values."""
        # pylint: disable=g-generic-assert
        s1 = structure.type_spec_from_value(value1_fn())
        s2 = structure.type_spec_from_value(value2_fn())
        # Equal specs: assertEqual exercises __eq__, `!=` exercises __ne__.
        for other in (s1, s2):
            self.assertEqual(s1, other)
        for other in (s1, s2):
            self.assertFalse(s1 != other)
        for c1, c2 in zip(nest.flatten(s1), nest.flatten(s2)):
            self.assertEqual(hash(c1), hash(c1))
            self.assertEqual(hash(c1), hash(c2))
        for value_fn in not_equal_value_fns:
            s3 = structure.type_spec_from_value(value_fn())
            # Unequal specs: assertNotEqual exercises __ne__, `==` exercises
            # __eq__.
            for spec in (s1, s2):
                self.assertNotEqual(spec, s3)
            for spec in (s1, s2):
                self.assertFalse(spec == s3)

    @parameterized.named_parameters(
        ("RaggedTensor_RaggedRank",
         structure.RaggedTensorStructure(dtypes.int32, None, 1),
         structure.RaggedTensorStructure(dtypes.int32, None, 2)),
        ("RaggedTensor_Shape",
         structure.RaggedTensorStructure(dtypes.int32, [3, None], 1),
         structure.RaggedTensorStructure(dtypes.int32, [5, None], 1)),
        ("RaggedTensor_DType",
         structure.RaggedTensorStructure(dtypes.int32, None, 1),
         structure.RaggedTensorStructure(dtypes.float32, None, 1)),
    )
    def testRaggedStructureInequality(self, s1, s2):
        """Checks that differing RaggedTensorStructures compare unequal."""
        # pylint: disable=g-generic-assert
        self.assertNotEqual(s1, s2)  # Goes through __ne__.
        equal_via_eq = (s1 == s2)  # Goes through __eq__.
        self.assertFalse(equal_via_eq)

    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         lambda: constant_op.constant(42.0),
         lambda: constant_op.constant([5])),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0),
         lambda: tensor_array_ops.TensorArray(
             dtype=dtypes.float32, element_shape=(3, ), size=0),
         lambda: tensor_array_ops.TensorArray(
             dtype=dtypes.int32, element_shape=(), size=0)),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[3]], values=[-1], dense_shape=[5])),
        ("Nested", lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, lambda: {
            "a": constant_op.constant(42.0),
            "b": constant_op.constant([4, 5, 6])
        }, lambda: {
            "a": constant_op.constant([1, 2, 3]),
            "b": constant_op.constant(37.0)
        }),
    )
    def testHash(self, value1_fn, value2_fn, value3_fn):
        """Checks that component hashes agree only for equal specs.

        The first two value functions build equal specs; the third builds a
        spec whose components must hash differently.
        """
        specs = [structure.type_spec_from_value(fn())
                 for fn in (value1_fn, value2_fn, value3_fn)]
        flattened = [nest.flatten(spec) for spec in specs]
        for c1, c2, c3 in zip(*flattened):
            self.assertEqual(hash(c1), hash(c1))
            self.assertEqual(hash(c1), hash(c2))
            self.assertNotEqual(hash(c1), hash(c3))
            self.assertNotEqual(hash(c2), hash(c3))

    @parameterized.named_parameters(
        (
            "Tensor",
            lambda: constant_op.constant(37.0),
        ),
        (
            "SparseTensor",
            lambda: sparse_tensor.SparseTensor(
                indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
        ),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(), size=1).write(0, 7)),
        (
            "RaggedTensor",
            lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
        ),
        (
            "Nested_0",
            lambda: {
                "a": constant_op.constant(37.0),
                "b": constant_op.constant([1, 2, 3])
            },
        ),
        (
            "Nested_1",
            lambda: {
                "a":
                constant_op.constant(37.0),
                "b": (sparse_tensor.SparseTensor(
                    indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                      sparse_tensor.SparseTensor(
                          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
            },
        ),
    )
    def testRoundTripConversion(self, value_fn):
        """Checks that to_tensor_list followed by from_tensor_list round-trips."""
        value = value_fn()
        spec = structure.type_spec_from_value(value)

        def _stack_if_tensor_array(v):
            # TensorArrays are compared through their stacked contents.
            if not isinstance(v, tensor_array_ops.TensorArray):
                return v
            return v.stack()

        before = self.evaluate(_stack_if_tensor_array(value))
        tensor_list = structure.to_tensor_list(spec, value)
        rebuilt = structure.from_tensor_list(spec, tensor_list)
        after = self.evaluate(_stack_if_tensor_array(rebuilt))

        for b, a in zip(nest.flatten(before), nest.flatten(after)):
            if isinstance(b, sparse_tensor.SparseTensorValue):
                # Compare the sparse components field by field.
                self.assertAllEqual(b.indices, a.indices)
                self.assertAllEqual(b.values, a.values)
                self.assertAllEqual(b.dense_shape, a.dense_shape)
            else:
                # Dense and ragged values are compared directly.
                self.assertAllEqual(b, a)

    # pylint: enable=g-long-lambda

    def preserveStaticShape(self):
        """Checks that a spec round-trip keeps the expected static shapes.

        NOTE(review): the name has no `test` prefix, so unittest-style runners
        never collect this method. Its expectations (e.g. `[None]` for the
        values) look like graph-mode static shapes; confirm they hold before
        renaming it to enable it.
        """
        def round_trip(value):
            spec = structure.type_spec_from_value(value)
            return structure.from_tensor_list(
                spec, structure.to_tensor_list(spec, value))

        rt = ragged_factory_ops.constant([[1, 2], [], [3]])
        rt_after = round_trip(rt)
        self.assertEqual(rt_after.row_splits.shape.as_list(),
                         rt.row_splits.shape.as_list())
        self.assertEqual(rt_after.values.shape.as_list(), [None])

        st = sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
        st_after = round_trip(st)
        self.assertEqual(st_after.indices.shape.as_list(), [None, 2])
        self.assertEqual(st_after.values.shape.as_list(), [None])
        self.assertEqual(st_after.dense_shape.shape.as_list(),
                         st.dense_shape.shape.as_list())

    def testIncompatibleStructure(self):
        """Checks conversions between mutually incompatible structures.

        Three mutually incompatible values are built: a dense tensor, a sparse
        tensor and a dict of tensors. The test asserts that:
        1. flattening a value with another value's structure fails, and
        2. restructuring one value's flat tensor list with another value's
           structure fails.
        """
        value_tensor = constant_op.constant(42.0)
        s_tensor = structure.type_spec_from_value(value_tensor)
        flat_tensor = structure.to_tensor_list(s_tensor, value_tensor)

        value_sparse_tensor = sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1])
        s_sparse_tensor = structure.type_spec_from_value(value_sparse_tensor)
        flat_sparse_tensor = structure.to_tensor_list(s_sparse_tensor,
                                                      value_sparse_tensor)

        value_nest = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }
        s_nest = structure.type_spec_from_value(value_nest)
        flat_nest = structure.to_tensor_list(s_nest, value_nest)

        nested_mismatch = (
            "The two structures don't have the same nested structure.")

        # Each entry: (exception type, message regexp, spec, mismatched value).
        to_tensor_list_cases = [
            (ValueError,
             r"SparseTensor.* is not convertible to a tensor with "
             r"dtype.*float32.* and shape \(\)",
             s_tensor, value_sparse_tensor),
            (ValueError, nested_mismatch, s_tensor, value_nest),
            (TypeError, "Neither a SparseTensor nor SparseTensorValue",
             s_sparse_tensor, value_tensor),
            (ValueError, nested_mismatch, s_sparse_tensor, value_nest),
            (ValueError, nested_mismatch, s_nest, value_tensor),
            (ValueError, nested_mismatch, s_nest, value_sparse_tensor),
        ]
        for error_type, pattern, spec, value in to_tensor_list_cases:
            with self.assertRaisesRegexp(error_type, pattern):
                structure.to_tensor_list(spec, value)

        # Each entry: (ValueError message regexp, spec, mismatched flat list).
        from_tensor_list_cases = [
            (r"Incompatible input:", s_tensor, flat_sparse_tensor),
            ("Expected 1 tensors but got 2.", s_tensor, flat_nest),
            ("Incompatible input: ", s_sparse_tensor, flat_tensor),
            ("Expected 1 tensors but got 2.", s_sparse_tensor, flat_nest),
            ("Expected 2 tensors but got 1.", s_nest, flat_tensor),
            ("Expected 2 tensors but got 1.", s_nest, flat_sparse_tensor),
        ]
        for pattern, spec, flat in from_tensor_list_cases:
            with self.assertRaisesRegexp(ValueError, pattern):
                structure.from_tensor_list(spec, flat)

    def testIncompatibleNestedStructure(self):
        """Checks conversions between mutually incompatible nested structures.

        The test asserts that:
        1. flattening a value with an incompatible structure fails, and
        2. restructuring a flat tensor list with an incompatible structure
           fails.
        """
        value_0 = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }
        s_0 = structure.type_spec_from_value(value_0)
        flat_s_0 = structure.to_tensor_list(s_0, value_0)

        # Same nesting as `value_0`, but "b" is sparse rather than dense.
        value_1 = {
            "a": constant_op.constant(37.0),
            "b": sparse_tensor.SparseTensor(
                indices=[[0, 0]], values=[1], dense_shape=[1, 1])
        }
        s_1 = structure.type_spec_from_value(value_1)
        flat_s_1 = structure.to_tensor_list(s_1, value_1)

        # Different nesting from both `value_0` and `value_1`: "b" is a pair.
        value_2 = {
            "a": constant_op.constant(37.0),
            "b": (sparse_tensor.SparseTensor(
                indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                  sparse_tensor.SparseTensor(
                      indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
        }
        s_2 = structure.type_spec_from_value(value_2)
        flat_s_2 = structure.to_tensor_list(s_2, value_2)

        # NOTE(mrry): dict reprs are not sorted, so the nesting-mismatch
        # checks below match only the fixed message text, never the repr of
        # either structure.
        nested_mismatch = (
            "The two structures don't have the same nested structure.")

        # Each entry: (exception type, message regexp, spec, mismatched value).
        to_tensor_list_cases = [
            (ValueError,
             r"SparseTensor.* is not convertible to a tensor with "
             r"dtype.*int32.* and shape \(3,\)",
             s_0, value_1),
            (ValueError, nested_mismatch, s_0, value_2),
            (TypeError, "Neither a SparseTensor nor SparseTensorValue",
             s_1, value_0),
            (ValueError, nested_mismatch, s_1, value_2),
            (ValueError, nested_mismatch, s_2, value_0),
            (ValueError, nested_mismatch, s_2, value_1),
        ]
        for error_type, pattern, spec, value in to_tensor_list_cases:
            with self.assertRaisesRegexp(error_type, pattern):
                structure.to_tensor_list(spec, value)

        # Each entry: (ValueError message regexp, spec, mismatched flat list).
        from_tensor_list_cases = [
            (r"Incompatible input:", s_0, flat_s_1),
            ("Expected 2 tensors but got 3.", s_0, flat_s_2),
            ("Incompatible input: ", s_1, flat_s_0),
            ("Expected 2 tensors but got 3.", s_1, flat_s_2),
            ("Expected 3 tensors but got 2.", s_2, flat_s_0),
            ("Expected 3 tensors but got 2.", s_2, flat_s_1),
        ]
        for pattern, spec, flat in from_tensor_list_cases:
            with self.assertRaisesRegexp(ValueError, pattern):
                structure.from_tensor_list(spec, flat)

    @parameterized.named_parameters(
        ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor,
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor", dtypes.int32, tensor_shape.matrix(
            2, 2), sparse_tensor.SparseTensor,
         structure.SparseTensorStructure(dtypes.int32, [2, 2])),
        ("TensorArray_0", dtypes.int32,
         tensor_shape.as_shape([None, True, 2, 2
                                ]), tensor_array_ops.TensorArray,
         structure.TensorArrayStructure(
             dtypes.int32, [2, 2], dynamic_size=None, infer_shape=True)),
        ("TensorArray_1", dtypes.int32,
         tensor_shape.as_shape([True, None, 2, 2
                                ]), tensor_array_ops.TensorArray,
         structure.TensorArrayStructure(
             dtypes.int32, [2, 2], dynamic_size=True, infer_shape=None)),
        ("TensorArray_2", dtypes.int32,
         tensor_shape.as_shape([True, False, 2, 2
                                ]), tensor_array_ops.TensorArray,
         structure.TensorArrayStructure(
             dtypes.int32, [2, 2], dynamic_size=True, infer_shape=False)),
        ("RaggedTensor", dtypes.int32, tensor_shape.matrix(2, None),
         structure.RaggedTensorStructure(dtypes.int32, [2, None], 1),
         structure.RaggedTensorStructure(dtypes.int32, [2, None], 1)),
        ("Nested", {
            "a": dtypes.float32,
            "b": (dtypes.int32, dtypes.string)
        }, {
            "a": tensor_shape.scalar(),
            "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar())
        }, {
            "a": ops.Tensor,
            "b": (sparse_tensor.SparseTensor, ops.Tensor)
        }, {
            "a":
            structure.TensorStructure(dtypes.float32, []),
            "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                  structure.TensorStructure(dtypes.string, []))
        }),
    )
    def testConvertLegacyStructure(self, output_types, output_shapes,
                                   output_classes, expected_structure):
        """Checks conversion of legacy (types, shapes, classes) triples."""
        converted = structure.convert_legacy_structure(output_types,
                                                       output_shapes,
                                                       output_classes)
        self.assertEqual(converted, expected_structure)

    def testNestedNestedStructure(self):
        """Checks that flattening and unflattening a doubly nested tuple spec
        passes the original tensor objects through unchanged."""
        s = (structure.TensorStructure(dtypes.int64, []),
             (structure.TensorStructure(dtypes.float32, []),
              structure.TensorStructure(dtypes.string, [])))

        int64_t = constant_op.constant(37, dtype=dtypes.int64)
        float32_t = constant_op.constant(42.0)
        string_t = constant_op.constant("Foo")

        tensor_list = structure.to_tensor_list(
            s, (int64_t, (float32_t, string_t)))
        for want, got in zip([int64_t, float32_t, string_t], tensor_list):
            self.assertIs(want, got)

        # Both unflattening entry points must return the very same objects.
        for unflatten in (structure.from_tensor_list,
                          structure.from_compatible_tensor_list):
            got_int64, (got_float32, got_string) = unflatten(s, tensor_list)
            self.assertIs(int64_t, got_int64)
            self.assertIs(float32_t, got_float32)
            self.assertIs(string_t, got_string)

    @parameterized.named_parameters(
        ("Tensor", structure.TensorStructure(dtypes.float32, []), 32,
         structure.TensorStructure(dtypes.float32, [32])),
        ("TensorUnknown", structure.TensorStructure(dtypes.float32, []), None,
         structure.TensorStructure(dtypes.float32, [None])),
        ("SparseTensor", structure.SparseTensorStructure(
            dtypes.float32, [None]), 32,
         structure.SparseTensorStructure(dtypes.float32, [32, None])),
        ("SparseTensorUnknown",
         structure.SparseTensorStructure(dtypes.float32, [4]), None,
         structure.SparseTensorStructure(dtypes.float32, [None, 4])),
        ("RaggedTensor",
         structure.RaggedTensorStructure(dtypes.float32, [2, None], 1), 32,
         structure.RaggedTensorStructure(dtypes.float32, [32, 2, None], 2)),
        ("RaggedTensorUnknown",
         structure.RaggedTensorStructure(dtypes.float32, [4, None], 1), None,
         structure.RaggedTensorStructure(dtypes.float32, [None, 4, None], 2)),
        ("Nested", {
            "a":
            structure.TensorStructure(dtypes.float32, []),
            "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                  structure.TensorStructure(dtypes.string, []))
        }, 128, {
            "a":
            structure.TensorStructure(dtypes.float32, [128]),
            "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
                  structure.TensorStructure(dtypes.string, [128]))
        }),
    )
    def testBatch(self, element_structure, batch_size,
                  expected_batched_structure):
        """Checks that `_batch` adds a leading batch dimension to each spec.

        Every leaf spec in `element_structure` is batched independently with
        `batch_size` (`None` in the "Unknown" cases). The resulting nested
        structure must equal `expected_batched_structure`.
        """
        def _batch_spec(spec):
            return spec._batch(batch_size)

        actual = nest.map_structure(_batch_spec, element_structure)
        self.assertEqual(actual, expected_batched_structure)

    @parameterized.named_parameters(
        ("Tensor", structure.TensorStructure(dtypes.float32, [32]),
         structure.TensorStructure(dtypes.float32, [])),
        ("TensorUnknown", structure.TensorStructure(dtypes.float32, [None]),
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor",
         structure.SparseTensorStructure(dtypes.float32, [32, None]),
         structure.SparseTensorStructure(dtypes.float32, [None])),
        ("SparseTensorUnknown",
         structure.SparseTensorStructure(dtypes.float32, [None, 4]),
         structure.SparseTensorStructure(dtypes.float32, [4])),
        ("RaggedTensor",
         structure.RaggedTensorStructure(dtypes.float32, [32, None, None], 2),
         structure.RaggedTensorStructure(dtypes.float32, [None, None], 1)),
        ("RaggedTensorUnknown",
         structure.RaggedTensorStructure(dtypes.float32, [None, None, None],
                                         2),
         structure.RaggedTensorStructure(dtypes.float32, [None, None], 1)),
        ("Nested", {
            "a":
            structure.TensorStructure(dtypes.float32, [128]),
            "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
                  structure.TensorStructure(dtypes.string, [None]))
        }, {
            "a":
            structure.TensorStructure(dtypes.float32, []),
            "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                  structure.TensorStructure(dtypes.string, []))
        }),
    )
    def testUnbatch(self, element_structure, expected_unbatched_structure):
        """Checks that `_unbatch` drops the leading dimension of each spec.

        Every leaf spec in `element_structure` is unbatched independently. The
        result must equal `expected_unbatched_structure`.
        """
        def _unbatch_spec(spec):
            return spec._unbatch()

        actual = nest.map_structure(_unbatch_spec, element_structure)
        self.assertEqual(actual, expected_unbatched_structure)

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
         lambda: constant_op.constant([1.0, 2.0])),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[0]], values=[13], dense_shape=[2])),
        ("RaggedTensor", lambda: ragged_factory_ops.constant([[[1]], [[2]]]),
         lambda: ragged_factory_ops.constant([[1]])),
        ("Nest", lambda:
         (constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
          sparse_tensor.SparseTensor(
              indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2])),
         lambda: (constant_op.constant([1.0, 2.0]),
                  sparse_tensor.SparseTensor(
                      indices=[[0]], values=[13], dense_shape=[2]))),
    )
    def testToBatchedTensorList(self, value_fn, element_0_fn):
        """Checks batched tensor-list conversion and recovery of element 0.

        `value_fn` builds a batched value, and `element_0_fn` builds the
        expected first element of that batch.
        """
        batched_value = value_fn()
        spec = structure.type_spec_from_value(batched_value)
        flat_batched = structure.to_batched_tensor_list(spec, batched_value)

        # The batch dimension is 2 for all of the test cases.
        # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
        # tensors in which we store sparse tensors, so those are skipped.
        for component in flat_batched:
            if component.dtype == dtypes.variant:
                continue
            self.assertEqual(2, self.evaluate(array_ops.shape(component)[0]))

        # Rebuild element 0 from the first slice of each batched component and
        # compare it with the independently constructed expected value.
        expected_element_0 = self.evaluate(element_0_fn())
        element_spec = nest.map_structure(lambda c: c._unbatch(), spec)
        first_slices = [component[0] for component in flat_batched]
        actual_element_0 = structure.from_tensor_list(element_spec,
                                                      first_slices)

        flat_expected = nest.flatten(expected_element_0)
        flat_actual = nest.flatten(actual_element_0)
        for expected, actual in zip(flat_expected, flat_actual):
            self.assertValuesEqual(expected, actual)
예제 #41
0
 def output_shapes(self):
   """Returns the static shapes of the three output components.

   With N unknown and W = self._row_shape.shape[0] + 1, the shapes are
   [N, W], [N] and [W] (presumably SparseTensor indices, values and
   dense_shape -- TODO confirm).
   """
   unknown_count = tensor_shape.Dimension(None)
   width = self._row_shape.shape[0] + 1
   return (
       tensor_shape.matrix(unknown_count, width),
       tensor_shape.vector(unknown_count),
       tensor_shape.vector(width),
   )
예제 #42
0
 def _to_legacy_output_shapes(self):
     """Packs this structure's metadata into a single legacy `TensorShape`.

     The first two dimensions smuggle `self._dynamic_size` and
     `self._infer_shape` into the legacy shape. They are followed by the
     dimensions of `self._element_shape`.
     """
     header = tensor_shape.matrix(self._dynamic_size, self._infer_shape)
     return header.concatenate(self._element_shape)
예제 #43
0
class StructureTest(test.TestCase, parameterized.TestCase):
    """Tests for `structure.Structure` flattening, compatibility and errors."""

    # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they
    # will be executed before the (eager- or graph-mode) test environment has been
    # set up.
    # pylint: disable=g-long-lambda,protected-access
    @parameterized.parameters(
        (lambda: constant_op.constant(37.0), structure.TensorStructure,
         [dtypes.float32], [[]]),
        (lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
         structure.SparseTensorStructure, [dtypes.variant], [None]),
        (lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
         structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
        (lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
        (lambda: {
            "a":
            constant_op.constant(37.0),
            "b":
            (sparse_tensor.
             SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
             sparse_tensor.SparseTensor(
                 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
        }, structure.NestedStructure,
         [dtypes.float32, dtypes.variant, dtypes.variant], [[], None, None]))
    def testFlatStructure(self, value_fn, expected_structure, expected_types,
                          expected_shapes):
        """Checks the structure class and flat types/shapes of a value."""
        value = value_fn()
        s = structure.Structure.from_value(value)
        self.assertIsInstance(s, expected_structure)
        self.assertEqual(expected_types, s._flat_types)
        for expected, actual in zip(expected_shapes, s._flat_shapes):
            # Shape compatibility is checked in both directions.
            self.assertTrue(actual.is_compatible_with(expected))
            self.assertTrue(
                tensor_shape.as_shape(expected).is_compatible_with(actual))

    @parameterized.parameters(
        (lambda: constant_op.constant(37.0), lambda: [
            constant_op.constant(38.0),
            array_ops.placeholder(dtypes.float32),
            variables.Variable(100.0), 42.0,
            np.array(42.0, dtype=np.float32)
        ],
         lambda: [constant_op.constant([1.0, 2.0]),
                  constant_op.constant(37)]),
        (lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), lambda: [
                sparse_tensor.SparseTensor(indices=[[1, 1], [3, 4]],
                                           values=[10, -1],
                                           dense_shape=[4, 5]),
                sparse_tensor.SparseTensorValue(indices=[[1, 1], [3, 4]],
                                                values=[10, -1],
                                                dense_shape=[4, 5]),
                array_ops.sparse_placeholder(dtype=dtypes.int32),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None])
            ], lambda: [
                constant_op.constant(37, shape=[4, 5]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None, None]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
            ]),
        (lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, lambda: [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6])
        }], lambda: [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6, 7])
        }, {
            "a": constant_op.constant(15),
            "b": constant_op.constant([4, 5, 6])
        }, {
            "a":
            constant_op.constant(15),
            "b":
            sparse_tensor.SparseTensor(
                indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
        }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
    )
    def testIsCompatibleWithStructure(self, original_value_fn,
                                      compatible_values_fn,
                                      incompatible_values_fn):
        """Checks `is_compatible_with` against compatible/incompatible values."""
        original_value = original_value_fn()
        compatible_values = compatible_values_fn()
        incompatible_values = incompatible_values_fn()
        s = structure.Structure.from_value(original_value)
        for compatible_value in compatible_values:
            self.assertTrue(
                s.is_compatible_with(
                    structure.Structure.from_value(compatible_value)))
        for incompatible_value in incompatible_values:
            self.assertFalse(
                s.is_compatible_with(
                    structure.Structure.from_value(incompatible_value)))

    @parameterized.parameters(
        (lambda: constant_op.constant(37.0), ),
        (lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), ),
        (lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, ),
        (lambda: {
            "a":
            constant_op.constant(37.0),
            "b":
            (sparse_tensor.
             SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
             sparse_tensor.SparseTensor(
                 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
        }, ),
    )
    def testRoundTripConversion(self, value_fn):
        """Checks that `_from_tensor_list(_to_tensor_list(v))` reproduces `v`."""
        value = value_fn()
        s = structure.Structure.from_value(value)
        before = self.evaluate(value)
        after = self.evaluate(s._from_tensor_list(s._to_tensor_list(value)))

        flat_before = nest.flatten(before)
        flat_after = nest.flatten(after)
        for b, a in zip(flat_before, flat_after):
            if isinstance(b, sparse_tensor.SparseTensorValue):
                self.assertAllEqual(b.indices, a.indices)
                self.assertAllEqual(b.values, a.values)
                self.assertAllEqual(b.dense_shape, a.dense_shape)
            else:
                self.assertAllEqual(b, a)

    # pylint: enable=g-long-lambda

    def testIncompatibleStructure(self):
        """Checks errors when mixing Tensor, SparseTensor and nested structures."""
        # Define three mutually incompatible values/structures, and assert that:
        # 1. Using one structure to flatten a value with an incompatible structure
        #    fails.
        # 2. Using one structure to restructure a flattened value with an
        #    incompatible structure fails.
        value_tensor = constant_op.constant(42.0)
        s_tensor = structure.Structure.from_value(value_tensor)
        flat_tensor = s_tensor._to_tensor_list(value_tensor)

        value_sparse_tensor = sparse_tensor.SparseTensor(indices=[[0, 0]],
                                                         values=[1],
                                                         dense_shape=[1, 1])
        s_sparse_tensor = structure.Structure.from_value(value_sparse_tensor)
        flat_sparse_tensor = s_sparse_tensor._to_tensor_list(
            value_sparse_tensor)

        value_nest = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }
        s_nest = structure.Structure.from_value(value_nest)
        flat_nest = s_nest._to_tensor_list(value_nest)

        with self.assertRaisesRegexp(
                ValueError,
                r"SparseTensor.* is not convertible to a tensor with "
                r"dtype.*float32.* and shape \(\)"):
            s_tensor._to_tensor_list(value_sparse_tensor)
        with self.assertRaisesRegexp(
                ValueError,
                r"Value \{.*\} is not convertible to a tensor with "
                r"dtype.*float32.* and shape \(\)"):
            s_tensor._to_tensor_list(value_nest)

        with self.assertRaisesRegexp(TypeError,
                                     "Input must be a SparseTensor"):
            s_sparse_tensor._to_tensor_list(value_tensor)

        with self.assertRaisesRegexp(TypeError,
                                     "Input must be a SparseTensor"):
            s_sparse_tensor._to_tensor_list(value_nest)

        with self.assertRaisesRegexp(
                ValueError,
                "Tensor.* not compatible with the nested structure "
                ".*TensorStructure.*TensorStructure"):
            s_nest._to_tensor_list(value_tensor)

        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensor.* not compatible with the nested structure "
                ".*TensorStructure.*TensorStructure"):
            s_nest._to_tensor_list(value_sparse_tensor)

        with self.assertRaisesRegexp(
                ValueError,
                r"Cannot convert.*with dtype.*float32.* and shape \(\)"):
            s_tensor._from_tensor_list(flat_sparse_tensor)

        with self.assertRaisesRegexp(
                ValueError,
                "TensorStructure corresponds to a single tf.Tensor."):
            s_tensor._from_tensor_list(flat_nest)

        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensorStructure corresponds to a single tf.variant "
                "vector of length 3."):
            s_sparse_tensor._from_tensor_list(flat_tensor)

        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensorStructure corresponds to a single tf.variant "
                "vector of length 3."):
            s_sparse_tensor._from_tensor_list(flat_nest)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 2 flat values in NestedStructure but got 1."):
            s_nest._from_tensor_list(flat_tensor)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 2 flat values in NestedStructure but got 1."):
            s_nest._from_tensor_list(flat_sparse_tensor)

    def testIncompatibleNestedStructure(self):
        """Checks errors when mixing differently-shaped nested structures."""
        # Define three mutually incompatible nested values/structures, and assert
        # that:
        # 1. Using one structure to flatten a value with an incompatible structure
        #    fails.
        # 2. Using one structure to restructure a flattened value with an
        #    incompatible structure fails.

        value_0 = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }
        s_0 = structure.Structure.from_value(value_0)
        flat_s_0 = s_0._to_tensor_list(value_0)

        # `value_1` has compatible nested structure with `value_0`, but different
        # classes.
        value_1 = {
            "a":
            constant_op.constant(37.0),
            "b":
            sparse_tensor.SparseTensor(indices=[[0, 0]],
                                       values=[1],
                                       dense_shape=[1, 1])
        }
        s_1 = structure.Structure.from_value(value_1)
        flat_s_1 = s_1._to_tensor_list(value_1)

        # `value_2` has incompatible nested structure with `value_0` and `value_1`.
        value_2 = {
            "a":
            constant_op.constant(37.0),
            "b": (sparse_tensor.SparseTensor(indices=[[0, 0]],
                                             values=[1],
                                             dense_shape=[1, 1]),
                  sparse_tensor.SparseTensor(indices=[[3, 4]],
                                             values=[-1],
                                             dense_shape=[4, 5]))
        }
        s_2 = structure.Structure.from_value(value_2)
        flat_s_2 = s_2._to_tensor_list(value_2)

        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensor.* not compatible with the nested structure "
                ".*TensorStructure"):
            s_0._to_tensor_list(value_1)

        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensor.*SparseTensor.* not compatible with the "
                "nested structure .*TensorStructure"):
            s_0._to_tensor_list(value_2)

        with self.assertRaisesRegexp(
                ValueError,
                "Tensor.* not compatible with the nested structure "
                ".*SparseTensorStructure"):
            s_1._to_tensor_list(value_0)

        # NOTE(review): the `s_1._to_tensor_list(value_2)` pairing is not
        # covered here; confirm which error type/message it raises before
        # adding an assertion for it.

        # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
        # needs to account for "a" coming before or after "b". It might be worth
        # adding a deterministic repr for these error messages (among other
        # improvements).
        with self.assertRaisesRegexp(
                ValueError,
                "Tensor.*Tensor.* not compatible with the nested structure "
                ".*(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
                "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"
        ):
            s_2._to_tensor_list(value_0)

        with self.assertRaisesRegexp(
                ValueError, "(Tensor.*SparseTensor|SparseTensor.*Tensor).* "
                "not compatible with the nested structure .*"
                "(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
                "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"
        ):
            s_2._to_tensor_list(value_1)

        with self.assertRaisesRegexp(
                ValueError,
                r"Cannot convert.*with dtype.*int32.* and shape \(3,\)"):
            s_0._from_tensor_list(flat_s_1)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 2 flat values in NestedStructure but got 3."):
            s_0._from_tensor_list(flat_s_2)

        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensorStructure corresponds to a single tf.variant "
                "vector of length 3."):
            s_1._from_tensor_list(flat_s_0)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 2 flat values in NestedStructure but got 3."):
            s_1._from_tensor_list(flat_s_2)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 3 flat values in NestedStructure but got 2."):
            s_2._from_tensor_list(flat_s_0)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 3 flat values in NestedStructure but got 2."):
            s_2._from_tensor_list(flat_s_1)

    @parameterized.named_parameters(
        ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor,
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor", dtypes.int32, tensor_shape.matrix(
            2, 2), sparse_tensor.SparseTensor,
         structure.SparseTensorStructure(dtypes.int32, [2, 2])),
        ("Nest", {
            "a": dtypes.float32,
            "b": (dtypes.int32, dtypes.string)
        }, {
            "a": tensor_shape.scalar(),
            "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar())
        }, {
            "a": ops.Tensor,
            "b": (sparse_tensor.SparseTensor, ops.Tensor)
        },
         structure.NestedStructure({
             "a":
             structure.TensorStructure(dtypes.float32, []),
             "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                   structure.TensorStructure(dtypes.string, []))
         })),
    )
    def testFromLegacyStructure(self, output_types, output_shapes,
                                output_classes, expected_structure):
        """Checks `_from_legacy_structure` against an explicit structure."""
        actual_structure = structure.Structure._from_legacy_structure(
            output_types, output_shapes, output_classes)
        # Mutual compatibility stands in for equality here.
        self.assertTrue(
            expected_structure.is_compatible_with(actual_structure))
        self.assertTrue(
            actual_structure.is_compatible_with(expected_structure))
예제 #44
0
def _WhereShape(op):
  """Shape function for the Where op.

  The output is a matrix with an unknown number of rows. Its column count
  equals the rank of `op.inputs[0]`, which is itself `None` when that rank is
  unknown.
  """
  rank = op.inputs[0].get_shape().ndims
  return [tensor_shape.matrix(None, rank)]
예제 #45
0
def _MultinomialShape(op):  # pylint: disable=invalid-name
  """Shape function for the Multinomial op.

  `op.inputs[0]` (the logits) is required to be rank 2 via `with_rank(2)`.
  The output is a [batch_size, num_samples] matrix:
    - batch_size is the logits' leading dimension;
    - num_samples is the constant value of `op.inputs[1]`, or None (unknown)
      when it cannot be determined statically.
  """
  batch_size = op.inputs[0].get_shape().with_rank(2)[0]
  num_samples = tensor_util.constant_value(op.inputs[1])
  return [tensor_shape.matrix(batch_size, num_samples)]
예제 #46
0
 def output_shapes(self):
   """Returns the static shapes of the three output components.

   With N unknown and W = self._row_shape.shape[0] + 1, the shapes are
   [N, W], [N] and [W] (presumably SparseTensor indices, values and
   dense_shape -- TODO confirm).
   """
   unknown_count = tensor_shape.Dimension(None)
   width = self._row_shape.shape[0] + 1
   return (
       tensor_shape.matrix(unknown_count, width),
       tensor_shape.vector(unknown_count),
       tensor_shape.vector(width),
   )
예제 #47
0
 def _to_legacy_output_shapes(self):
   """Packs this structure's metadata into a single legacy `TensorShape`.

   The first two dimensions smuggle `self._dynamic_size` and
   `self._infer_shape` into the legacy shape. They are followed by the
   dimensions of `self._element_shape`.
   """
   header = tensor_shape.matrix(self._dynamic_size, self._infer_shape)
   return header.concatenate(self._element_shape)
def _SparseFeatureCrossShape(unused_op):  # pylint: disable=invalid-name
  """Shape function for the SparseFeatureCross op.

  The shapes do not depend on the op. With N unknown, they are [N, 2], [N]
  and [2] (presumably indices, values and dense_shape of a 2-D sparse
  result -- TODO confirm).
  """
  num_entries = None
  rank = 2
  return [tensor_shape.matrix(num_entries, rank),
          tensor_shape.vector(num_entries),
          tensor_shape.vector(rank)]
예제 #49
0
def _WhereShape(op):
  """Shape function for the Where op.

  Produces one matrix shape: the row count is statically unknown, and the
  column count is the rank of the op's first input (None if unknown).
  """
  return [tensor_shape.matrix(None, op.inputs[0].get_shape().ndims)]
예제 #50
0
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):

    # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they
    # will be executed before the (eager- or graph-mode) test environment has been
    # set up.
    # pylint: disable=g-long-lambda,protected-access
    @parameterized.parameters(
        (lambda: constant_op.constant(37.0), structure.TensorStructure,
         [dtypes.float32], [[]]),
        # NOTE(review): expected_shapes is [None, 3] rather than [[None, 3]];
        # with a single flat shape only `None` is actually compared by the zip
        # below -- confirm this is intended.
        (lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0),
         structure.TensorArrayStructure, [dtypes.variant], [None, 3]),
        (lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
         structure.SparseTensorStructure, [dtypes.variant], [None]),
        (lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
         structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
        (lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]
                                                                       ]),
        (lambda: {
            "a":
            constant_op.constant(37.0),
            "b":
            (sparse_tensor.
             SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
             sparse_tensor.SparseTensor(
                 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
        }, structure.NestedStructure,
         [dtypes.float32, dtypes.variant, dtypes.variant], [[], None, None]))
    def testFlatStructure(self, value_fn, expected_structure, expected_types,
                          expected_shapes):
        """Checks the structure class and flat types/shapes of a value."""
        spec = structure.Structure.from_value(value_fn())
        self.assertIsInstance(spec, expected_structure)
        self.assertEqual(expected_types, spec._flat_types)
        for want, got in zip(expected_shapes, spec._flat_shapes):
            # Shape compatibility must hold in both directions.
            self.assertTrue(got.is_compatible_with(want))
            self.assertTrue(
                tensor_shape.as_shape(want).is_compatible_with(got))

    @parameterized.parameters(
        (lambda: constant_op.constant(37.0), lambda: [
            constant_op.constant(38.0),
            array_ops.placeholder(dtypes.float32),
            variables.Variable(100.0), 42.0,
            np.array(42.0, dtype=np.float32)
        ],
         lambda: [constant_op.constant([1.0, 2.0]),
                  constant_op.constant(37)]),
        (lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0), lambda: [
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(3, ), size=0),
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(3, ), size=10)
            ], lambda: [
                tensor_array_ops.TensorArray(
                    dtype=dtypes.int32, element_shape=(3, ), size=0),
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(), size=0)
            ]),
        (lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), lambda: [
                sparse_tensor.SparseTensor(indices=[[1, 1], [3, 4]],
                                           values=[10, -1],
                                           dense_shape=[4, 5]),
                sparse_tensor.SparseTensorValue(indices=[[1, 1], [3, 4]],
                                                values=[10, -1],
                                                dense_shape=[4, 5]),
                array_ops.sparse_placeholder(dtype=dtypes.int32),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None])
            ], lambda: [
                constant_op.constant(37, shape=[4, 5]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None, None]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
            ]),
        (lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, lambda: [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6])
        }], lambda: [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6, 7])
        }, {
            "a": constant_op.constant(15),
            "b": constant_op.constant([4, 5, 6])
        }, {
            "a":
            constant_op.constant(15),
            "b":
            sparse_tensor.SparseTensor(
                indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
        }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
    )
    @test_util.run_deprecated_v1
    def testIsCompatibleWithStructure(self, original_value_fn,
                                      compatible_values_fn,
                                      incompatible_values_fn):
        """Checks `is_compatible_with` against compatible/incompatible values."""
        original = original_value_fn()
        compatible = compatible_values_fn()
        incompatible = incompatible_values_fn()
        spec = structure.Structure.from_value(original)

        def _is_compatible(other):
            # Builds a structure for `other` and compares it with `spec`.
            return spec.is_compatible_with(
                structure.Structure.from_value(other))

        for candidate in compatible:
            self.assertTrue(_is_compatible(candidate))
        for candidate in incompatible:
            self.assertFalse(_is_compatible(candidate))

    @parameterized.parameters(
        (lambda: constant_op.constant(37.0), ),
        (lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), ),
        # Wrapped in a 1-tuple like its siblings. Previously this case was a
        # bare lambda that only worked through parameterized's single-argument
        # fallback.
        (lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(), size=1).write(0, 7), ),
        (lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, ),
        (lambda: {
            "a":
            constant_op.constant(37.0),
            "b":
            (sparse_tensor.
             SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
             sparse_tensor.SparseTensor(
                 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
        }, ),
    )
    def testRoundTripConversion(self, value_fn):
        """Checks that `_from_tensor_list(_to_tensor_list(v))` reproduces `v`.

        TensorArrays are compared through their stacked contents. SparseTensor
        values are compared component by component.
        """
        value = value_fn()
        s = structure.Structure.from_value(value)

        def maybe_stack_ta(v):
            # Compare TensorArrays via their stacked contents; pass other
            # values through unchanged.
            if isinstance(v, tensor_array_ops.TensorArray):
                return v.stack()
            return v

        before = self.evaluate(maybe_stack_ta(value))
        after = self.evaluate(
            maybe_stack_ta(s._from_tensor_list(s._to_tensor_list(value))))

        flat_before = nest.flatten(before)
        flat_after = nest.flatten(after)
        for b, a in zip(flat_before, flat_after):
            if isinstance(b, sparse_tensor.SparseTensorValue):
                self.assertAllEqual(b.indices, a.indices)
                self.assertAllEqual(b.values, a.values)
                self.assertAllEqual(b.dense_shape, a.dense_shape)
            else:
                self.assertAllEqual(b, a)

    # pylint: enable=g-long-lambda

    def testIncompatibleStructure(self):
        """Checks that mismatched structures and values are rejected.

        Three mutually incompatible values are built: a scalar float tensor,
        a 1x1 `SparseTensor`, and a dict of two dense tensors. The structure
        of each one is then used to (1) flatten each of the other values and
        (2) rebuild a value from each of the other values' flat tensor lists.
        Every such combination must raise with the expected message.
        """
        dense_value = constant_op.constant(42.0)
        dense_structure = structure.Structure.from_value(dense_value)
        dense_flat = dense_structure._to_tensor_list(dense_value)

        sparse_value = sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1])
        sparse_structure = structure.Structure.from_value(sparse_value)
        sparse_flat = sparse_structure._to_tensor_list(sparse_value)

        nest_value = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }
        nest_structure = structure.Structure.from_value(nest_value)
        nest_flat = nest_structure._to_tensor_list(nest_value)

        # Each entry is (expected exception type, message regexp, call that
        # must raise). Entries are checked in order.
        expected_failures = [
            # Flattening a value with an incompatible structure.
            (ValueError,
             r"SparseTensor.* is not convertible to a tensor with "
             r"dtype.*float32.* and shape \(\)",
             lambda: dense_structure._to_tensor_list(sparse_value)),
            (ValueError,
             r"Value \{.*\} is not convertible to a tensor with "
             r"dtype.*float32.* and shape \(\)",
             lambda: dense_structure._to_tensor_list(nest_value)),
            (TypeError,
             "Input must be a SparseTensor",
             lambda: sparse_structure._to_tensor_list(dense_value)),
            (TypeError,
             "Input must be a SparseTensor",
             lambda: sparse_structure._to_tensor_list(nest_value)),
            (ValueError,
             "Tensor.* not compatible with the nested structure "
             ".*TensorStructure.*TensorStructure",
             lambda: nest_structure._to_tensor_list(dense_value)),
            (ValueError,
             "SparseTensor.* not compatible with the nested structure "
             ".*TensorStructure.*TensorStructure",
             lambda: nest_structure._to_tensor_list(sparse_value)),
            # Restructuring a flat list produced by an incompatible structure.
            (ValueError,
             r"Cannot convert.*with dtype.*float32.* and shape \(\)",
             lambda: dense_structure._from_tensor_list(sparse_flat)),
            (ValueError,
             "TensorStructure corresponds to a single tf.Tensor.",
             lambda: dense_structure._from_tensor_list(nest_flat)),
            (ValueError,
             "SparseTensorStructure corresponds to a single tf.variant "
             "vector of length 3.",
             lambda: sparse_structure._from_tensor_list(dense_flat)),
            (ValueError,
             "SparseTensorStructure corresponds to a single tf.variant "
             "vector of length 3.",
             lambda: sparse_structure._from_tensor_list(nest_flat)),
            (ValueError,
             "Expected 2 flat values in NestedStructure but got 1.",
             lambda: nest_structure._from_tensor_list(dense_flat)),
            (ValueError,
             "Expected 2 flat values in NestedStructure but got 1.",
             lambda: nest_structure._from_tensor_list(sparse_flat)),
        ]

        for exception_type, pattern, failing_call in expected_failures:
            with self.assertRaisesRegexp(exception_type, pattern):
                failing_call()

    def testIncompatibleNestedStructure(self):
        """Checks that mismatched nested structures and values are rejected.

        Three mutually incompatible nested values/structures are defined. For
        every ordered pair of distinct (structure, value), the test asserts
        that:

        1. Using one structure to flatten a value with an incompatible
           structure fails.
        2. Using one structure to restructure a flattened value with an
           incompatible structure fails.
        """
        value_0 = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }
        s_0 = structure.Structure.from_value(value_0)
        flat_s_0 = s_0._to_tensor_list(value_0)

        # `value_1` has compatible nested structure with `value_0`, but different
        # classes.
        value_1 = {
            "a":
            constant_op.constant(37.0),
            "b":
            sparse_tensor.SparseTensor(indices=[[0, 0]],
                                       values=[1],
                                       dense_shape=[1, 1])
        }
        s_1 = structure.Structure.from_value(value_1)
        flat_s_1 = s_1._to_tensor_list(value_1)

        # `value_2` has incompatible nested structure with `value_0` and `value_1`.
        value_2 = {
            "a":
            constant_op.constant(37.0),
            "b": (sparse_tensor.SparseTensor(indices=[[0, 0]],
                                             values=[1],
                                             dense_shape=[1, 1]),
                  sparse_tensor.SparseTensor(indices=[[3, 4]],
                                             values=[-1],
                                             dense_shape=[4, 5]))
        }
        s_2 = structure.Structure.from_value(value_2)
        flat_s_2 = s_2._to_tensor_list(value_2)

        # --- Flattening: each structure vs. each other value. ---
        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensor.* not compatible with the nested structure "
                ".*TensorStructure"):
            s_0._to_tensor_list(value_1)

        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensor.*SparseTensor.* not compatible with the "
                "nested structure .*TensorStructure"):
            s_0._to_tensor_list(value_2)

        with self.assertRaisesRegexp(
                ValueError,
                "Tensor.* not compatible with the nested structure "
                ".*SparseTensorStructure"):
            s_1._to_tensor_list(value_0)

        # `s_1` expects a single SparseTensor under "b", but `value_2` holds a
        # tuple of two SparseTensors there.
        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensor.*SparseTensor.* not compatible with the "
                "nested structure .*SparseTensorStructure"):
            s_1._to_tensor_list(value_2)

        # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
        # needs to account for "a" coming before or after "b". It might be worth
        # adding a deterministic repr for these error messages (among other
        # improvements).
        with self.assertRaisesRegexp(
                ValueError,
                "Tensor.*Tensor.* not compatible with the nested structure "
                ".*(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
                "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"
        ):
            s_2._to_tensor_list(value_0)

        with self.assertRaisesRegexp(
                ValueError, "(Tensor.*SparseTensor|SparseTensor.*Tensor).* "
                "not compatible with the nested structure .*"
                "(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
                "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"
        ):
            s_2._to_tensor_list(value_1)

        # --- Restructuring: each structure vs. each other flat list. ---
        with self.assertRaisesRegexp(
                ValueError,
                r"Cannot convert.*with dtype.*int32.* and shape \(3,\)"):
            s_0._from_tensor_list(flat_s_1)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 2 flat values in NestedStructure but got 3."):
            s_0._from_tensor_list(flat_s_2)

        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensorStructure corresponds to a single tf.variant "
                "vector of length 3."):
            s_1._from_tensor_list(flat_s_0)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 2 flat values in NestedStructure but got 3."):
            s_1._from_tensor_list(flat_s_2)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 3 flat values in NestedStructure but got 2."):
            s_2._from_tensor_list(flat_s_0)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 3 flat values in NestedStructure but got 2."):
            s_2._from_tensor_list(flat_s_1)

    @parameterized.named_parameters(
        ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor,
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor", dtypes.int32, tensor_shape.matrix(2, 2),
         sparse_tensor.SparseTensor,
         structure.SparseTensorStructure(dtypes.int32, [2, 2])),
        # In the TensorArray cases below, the first two legacy shape entries
        # line up with `dynamic_size` and `infer_shape` of the expected
        # structure; the remaining entries are the element shape.
        ("TensorArray0", dtypes.int32,
         tensor_shape.as_shape([None, True, 2, 2]),
         tensor_array_ops.TensorArray,
         structure.TensorArrayStructure(
             dtypes.int32, [2, 2], dynamic_size=None, infer_shape=True)),
        ("TensorArray1", dtypes.int32,
         tensor_shape.as_shape([True, None, 2, 2]),
         tensor_array_ops.TensorArray,
         structure.TensorArrayStructure(
             dtypes.int32, [2, 2], dynamic_size=True, infer_shape=None)),
        ("TensorArray2", dtypes.int32,
         tensor_shape.as_shape([True, False, 2, 2]),
         tensor_array_ops.TensorArray,
         structure.TensorArrayStructure(
             dtypes.int32, [2, 2], dynamic_size=True, infer_shape=False)),
        ("Nest",
         {"a": dtypes.float32, "b": (dtypes.int32, dtypes.string)},
         {"a": tensor_shape.scalar(),
          "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar())},
         {"a": ops.Tensor, "b": (sparse_tensor.SparseTensor, ops.Tensor)},
         structure.NestedStructure({
             "a": structure.TensorStructure(dtypes.float32, []),
             "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                   structure.TensorStructure(dtypes.string, [])),
         })),
    )
    def testConvertLegacyStructure(self, output_types, output_shapes,
                                   output_classes, expected_structure):
        """Checks `convert_legacy_structure` against a hand-built structure.

        The converted structure and the expected one must be compatible in
        both directions (expected-vs-converted is checked first).
        """
        converted = structure.convert_legacy_structure(
            output_types, output_shapes, output_classes)
        for lhs, rhs in ((expected_structure, converted),
                         (converted, expected_structure)):
            self.assertTrue(lhs.is_compatible_with(rhs))

    def testNestedNestedStructure(self):
        """Round-trips values through a `NestedStructure` that nests another.

        Although `Structure.from_value()` will not construct one, a nested
        structure containing nested `NestedStructure` objects can occur if a
        structure is constructed manually. Flattening must yield the original
        tensor objects in depth-first order. Both `_from_tensor_list` and
        `_from_compatible_tensor_list` must rebuild the original nesting from
        those same objects.
        """
        s = structure.NestedStructure(
            (structure.TensorStructure(dtypes.int64, []),
             structure.NestedStructure(
                 (structure.TensorStructure(dtypes.float32, []),
                  structure.TensorStructure(dtypes.string, [])))))

        int64_t = constant_op.constant(37, dtype=dtypes.int64)
        float32_t = constant_op.constant(42.0)
        string_t = constant_op.constant("Foo")

        nested_tensors = (int64_t, (float32_t, string_t))
        expected_flat = [int64_t, float32_t, string_t]

        tensor_list = s._to_tensor_list(nested_tensors)
        # `zip` stops at the shorter input, so check the length explicitly;
        # otherwise a missing or extra component would go unnoticed.
        self.assertEqual(len(expected_flat), len(tensor_list))
        for expected, actual in zip(expected_flat, tensor_list):
            self.assertIs(expected, actual)

        # Both decoding entry points must hand back the very same objects.
        for decode in (s._from_tensor_list, s._from_compatible_tensor_list):
            (actual_int64_t,
             (actual_float32_t, actual_string_t)) = decode(tensor_list)
            self.assertIs(int64_t, actual_int64_t)
            self.assertIs(float32_t, actual_float32_t)
            self.assertIs(string_t, actual_string_t)

    @parameterized.named_parameters(
        ("Tensor", structure.TensorStructure(dtypes.float32, []), 32,
         structure.TensorStructure(dtypes.float32, [32])),
        ("TensorUnknown", structure.TensorStructure(dtypes.float32, []), None,
         structure.TensorStructure(dtypes.float32, [None])),
        ("SparseTensor",
         structure.SparseTensorStructure(dtypes.float32, [None]), 32,
         structure.SparseTensorStructure(dtypes.float32, [32, None])),
        ("SparseTensorUnknown",
         structure.SparseTensorStructure(dtypes.float32, [4]), None,
         structure.SparseTensorStructure(dtypes.float32, [None, 4])),
        ("Nest",
         structure.NestedStructure({
             "a": structure.TensorStructure(dtypes.float32, []),
             "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                   structure.TensorStructure(dtypes.string, [])),
         }),
         128,
         structure.NestedStructure({
             "a": structure.TensorStructure(dtypes.float32, [128]),
             "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
                   structure.TensorStructure(dtypes.string, [128])),
         })),
    )
    def testBatch(self, element_structure, batch_size,
                  expected_batched_structure):
        """Checks `_batch(batch_size)` against the expected batched structure.

        The result and the expectation must be compatible in both directions
        (batched-vs-expected is checked first).
        """
        batched = element_structure._batch(batch_size)
        directions = [(batched, expected_batched_structure),
                      (expected_batched_structure, batched)]
        for first, second in directions:
            self.assertTrue(first.is_compatible_with(second))

    @parameterized.named_parameters(
        ("Tensor", structure.TensorStructure(dtypes.float32, [32]),
         structure.TensorStructure(dtypes.float32, [])),
        ("TensorUnknown", structure.TensorStructure(dtypes.float32, [None]),
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor",
         structure.SparseTensorStructure(dtypes.float32, [32, None]),
         structure.SparseTensorStructure(dtypes.float32, [None])),
        ("SparseTensorUnknown",
         structure.SparseTensorStructure(dtypes.float32, [None, 4]),
         structure.SparseTensorStructure(dtypes.float32, [4])),
        ("Nest",
         structure.NestedStructure({
             "a": structure.TensorStructure(dtypes.float32, [128]),
             "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
                   structure.TensorStructure(dtypes.string, [None])),
         }),
         structure.NestedStructure({
             "a": structure.TensorStructure(dtypes.float32, []),
             "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                   structure.TensorStructure(dtypes.string, [])),
         })),
    )
    def testUnbatch(self, element_structure, expected_unbatched_structure):
        """Checks `_unbatch()` against the expected per-element structure.

        The result and the expectation must be compatible in both directions
        (unbatched-vs-expected is checked first).
        """
        unbatched = element_structure._unbatch()
        for first, second in ((unbatched, expected_unbatched_structure),
                              (expected_unbatched_structure, unbatched)):
            self.assertTrue(first.is_compatible_with(second))

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
         lambda: constant_op.constant([1.0, 2.0])),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[0]], values=[13], dense_shape=[2])),
        ("Nest", lambda:
         (constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
          sparse_tensor.SparseTensor(
              indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2])),
         lambda: (constant_op.constant([1.0, 2.0]),
                  sparse_tensor.SparseTensor(
                      indices=[[0]], values=[13], dense_shape=[2]))),
    )
    def testToBatchedTensorList(self, value_fn, element_0_fn):
        """Checks `_to_batched_tensor_list` and unbatching of its first row.

        Args:
          value_fn: Zero-arg callable building a batched value (batch size 2
            in every parameterized case).
          element_0_fn: Zero-arg callable building the expected 0th element
            of that batch.
        """
        batched_value = value_fn()
        s = structure.Structure.from_value(batched_value)
        batched_tensor_list = s._to_batched_tensor_list(batched_value)

        # The batch dimension is 2 for all of the test cases.
        # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
        # tensors in which we store sparse tensors.
        for t in batched_tensor_list:
            if t.dtype != dtypes.variant:
                self.assertEqual(2, self.evaluate(array_ops.shape(t)[0]))

        # Test that the 0th element from the unbatched tensor is equal to the
        # expected value.
        expected_element_0 = self.evaluate(element_0_fn())
        unbatched_s = s._unbatch()
        # Slice row 0 out of every batched component and rebuild it with the
        # per-element (unbatched) structure.
        actual_element_0 = unbatched_s._from_tensor_list(
            [t[0] for t in batched_tensor_list])

        # NOTE(review): `zip` stops at the shorter of the two flattened lists,
        # so a component-count mismatch would not be reported here.
        for expected, actual in zip(nest.flatten(expected_element_0),
                                    nest.flatten(actual_element_0)):
            if sparse_tensor.is_sparse(expected):
                self.assertSparseValuesEqual(expected, actual)
            else:
                self.assertAllEqual(expected, actual)