Example #1
0
 def testFullTypesForFlatTensors(self):
     """Each flat tensor of a custom two-tensor spec maps to TFT_UNSET."""
     two_tensor_spec = TwoTensorsSpec(
         [5], dtypes.int32, [5, 8], dtypes.float32, "red")
     actual = fulltypes_for_flat_tensors(two_tensor_spec)
     # One independent TFT_UNSET entry per flat tensor of the spec.
     expected = [
         full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_UNSET)
         for _ in range(2)
     ]
     flat_specs = two_tensor_spec._flat_tensor_specs  # pylint: disable=protected-access
     self.assertEqual(len(flat_specs), len(actual))
     self.assertEqual(expected, actual)
    def test_fn():
        # Create a TensorArray and view its flow as a variant handle.
        initial_array = tensor_array_ops.TensorArray(dtypes.int32, size=1)
        flow_handle = math_ops.cast(initial_array.flow, dtypes.variant)

        # Explicitly annotate the handle's producing op with the full type
        # TFT_PRODUCT[TFT_ARRAY].
        array_type = full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_ARRAY)
        product_type = full_type_pb2.FullTypeDef(
            type_id=full_type_pb2.TFT_PRODUCT, args=[array_type])
        flow_handle.op.experimental_set_type(product_type)

        # Rebuild a TensorArray on top of the annotated flow, write one element
        # at index 0 and return the stacked result.
        rebuilt = tensor_array_ops.TensorArray(dtypes.int32, flow=flow_handle)
        written = rebuilt.write(0, constant_op.constant(1))
        return written.stack()
Example #3
0
def full_type_from_spec(element_spec):
    """Returns a FullTypeDef for the element tensor representation.

    The result is a TFT_PRODUCT with one argument per leaf of `element_spec`.
    `NoneTensorSpec` leaves are skipped. Other leaves map as follows:

    * `SparseTensorSpec` -> TFT_TENSOR[dtype]. The spec must flatten to a
      single variant tensor.
    * `RaggedTensorSpec` -> TFT_RAGGED[dtype].
    * `TensorSpec` -> TFT_TENSOR[dtype].
    * anything else -> a bare TFT_UNSET.

    The dtype argument is looked up in `DT_TO_FT`, falling back to TFT_ANY.

    Args:
      element_spec: A nested structure of `tf.TypeSpec` objects representing the
        element type specification.

    Returns:
      A FullTypeDef for the element tensor representation.

    Raises:
      TypeError: If a `SparseTensorSpec` does not flatten to exactly one
        variant tensor.
    """
    product_args = []
    for leaf_spec in nest.flatten(element_spec):
        if isinstance(leaf_spec, NoneTensorSpec):
            # A NoneTensorSpec does not correspond to an output.
            continue

        if isinstance(leaf_spec, sparse_tensor.SparseTensorSpec):
            # Currently, this represents a SparseTensor spec as a single output
            # (that is a variant) as a TFT_TENSOR. When shape information is
            # added to fulltype, either the shape needs to reflect this
            # (e.g. TFT_TENSOR[..., shape=COOSparseShape]) or a new TFT_SPARSE
            # data type should be created.
            flat_specs = get_flat_tensor_specs(leaf_spec)
            if len(flat_specs) != 1 or flat_specs[0].dtype != dtypes.variant:
                raise TypeError("Only sparse tensors as variants is supported")
            container_id = full_type_pb2.TFT_TENSOR
        elif isinstance(leaf_spec, ragged_tensor.RaggedTensorSpec):
            container_id = full_type_pb2.TFT_RAGGED
        elif isinstance(leaf_spec, tensor_spec.TensorSpec):
            container_id = full_type_pb2.TFT_TENSOR
        else:
            # Defaulting to TFT_UNSET lets other cases (including user-defined
            # type specs) fall back to the behavior prior to full type.
            container_id = full_type_pb2.TFT_UNSET

        if container_id == full_type_pb2.TFT_UNSET:
            product_args.append(
                full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_UNSET))
            continue

        element_id = DT_TO_FT.get(leaf_spec.dtype, full_type_pb2.TFT_ANY)
        product_args.append(
            full_type_pb2.FullTypeDef(
                type_id=container_id,
                args=[full_type_pb2.FullTypeDef(type_id=element_id)]))

    return full_type_pb2.FullTypeDef(
        type_id=full_type_pb2.TFT_PRODUCT, args=product_args)
Example #4
0
 def testFullTypesForFlatTensors(self):
   """Both ragged components of a TwoComposites map to TFT_RAGGED[TFT_INT32]."""
   first = ragged_factory_ops.constant([[1, 2], [3]])
   second = ragged_factory_ops.constant([[5], [6, 7, 8]])
   composite_spec = type_spec.type_spec_from_value(TwoComposites(first, second))
   actual = fulltypes_for_flat_tensors(composite_spec)
   # One independent TFT_RAGGED[TFT_INT32] entry per ragged component.
   expected = [
       full_type_pb2.FullTypeDef(
           type_id=full_type_pb2.TFT_RAGGED,
           args=[full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_INT32)])
       for _ in range(2)
   ]
   flat_specs = composite_spec._flat_tensor_specs  # pylint: disable=protected-access
   self.assertEqual(len(flat_specs), len(actual))
   self.assertEqual(expected, actual)
Example #5
0
def iterator_full_type_from_spec(element_spec):
    """Returns a FullTypeDef for an iterator for the elements.

    The result has the form TFT_PRODUCT[TFT_ITERATOR[TFT_PRODUCT[...]]]. The
    innermost product holds the per-flat-tensor types returned by
    `fulltypes_for_flat_tensors(element_spec)`.

    Args:
      element_spec: A nested structure of `tf.TypeSpec` objects representing the
        element type specification.

    Returns:
      A FullTypeDef for an iterator for the element tensor representation.
    """
    element_product = full_type_pb2.FullTypeDef(
        type_id=full_type_pb2.TFT_PRODUCT,
        args=fulltypes_for_flat_tensors(element_spec))
    iterator_type = full_type_pb2.FullTypeDef(
        type_id=full_type_pb2.TFT_ITERATOR, args=[element_product])
    return full_type_pb2.FullTypeDef(
        type_id=full_type_pb2.TFT_PRODUCT, args=[iterator_type])
Example #6
0
def _translate_to_fulltype_for_flat_tensors(
        spec: type_spec.TypeSpec) -> List[full_type_pb2.FullTypeDef]:
    """Convert a TypeSpec to a list of FullTypeDef.

    The FullTypeDef created corresponds to the encoding used with datasets
    (and map_fn) that uses variants (and not FullTypeDef corresponding to the
    default "component" encoding).

    Currently, the only use of this is for information about the contents of
    ragged tensors, so only ragged tensors return useful full type information
    and other types return TFT_UNSET. While this could be improved in the
    future, this function is intended for temporary use and expected to be
    removed when type inference support is sufficient.

    A ragged spec whose dtype has no entry in `_DT_TO_FT`, or maps to
    TFT_LEGACY_VARIANT, is logged at verbosity 1. It then falls through to the
    TFT_UNSET result like any other spec.

    Args:
      spec: A TypeSpec for one element of a dataset or map_fn.

    Returns:
      A list of FullTypeDef corresponding to SPEC. The length of this list
      is always the same as the length of spec._flat_tensor_specs.
    """
    if isinstance(spec, RaggedTensorSpec):
        dt = spec.dtype
        elem_t = _DT_TO_FT.get(dt)
        if elem_t is None:
            logging.vlog(1, "dtype %s that has no conversion to fulltype.", dt)
        elif elem_t == full_type_pb2.TFT_LEGACY_VARIANT:
            # The message has no %-placeholder. Passing `dt` as a format
            # argument made the record fail to format ("not all arguments
            # converted") when emitted, so no argument is passed.
            logging.vlog(
                1, "Ragged tensors containing variants are not supported.")
        else:
            # Internal invariant: the variant encoding of a ragged spec is a
            # single flat tensor, so exactly one FullTypeDef is returned.
            assert len(spec._flat_tensor_specs) == 1  # pylint: disable=protected-access
            return [
                full_type_pb2.FullTypeDef(
                    type_id=full_type_pb2.TFT_RAGGED,
                    args=[full_type_pb2.FullTypeDef(type_id=elem_t)])
            ]
    # Fallback: one TFT_UNSET per flat tensor keeps the length contract.
    return [
        full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_UNSET)
        for _ in spec._flat_tensor_specs  # pylint: disable=protected-access
    ]
Example #7
0
def _set_handle_data(list_handle, element_shape, element_dtype):
    """Sets type information on `list_handle` for consistency with graphs.

    Only `ops.EagerTensor` handles are modified; any other handle is left
    untouched. For eager handles, `_handle_data` is replaced with a fresh
    `HandleData` proto. That proto records the element shape, the element dtype
    and a TFT_ARRAY full type.

    Args:
      list_handle: The list handle to annotate.
      element_shape: A `TensorShape`, a value convertible to one, or a tensor.
        A tensor-valued shape is recorded as fully unknown.
      element_dtype: The element dtype. It must provide `as_datatype_enum`.
    """
    # TODO(b/169968286): It would be better if we had a consistent story for
    # creating handle data from eager operations (shared with VarHandleOp).
    if not isinstance(list_handle, ops.EagerTensor):
        return

    if tensor_util.is_tf_type(element_shape):
        # A tensor-valued shape is not inspected here; record it as unknown.
        shape = tensor_shape.TensorShape(None)
    elif isinstance(element_shape, tensor_shape.TensorShape):
        shape = element_shape
    else:
        shape = tensor_shape.TensorShape(element_shape)

    result_pb2 = cpp_shape_inference_pb2.CppShapeInferenceResult
    handle_data = result_pb2.HandleData()
    handle_data.is_set = True
    # TODO(b/191472076): This duplicates type inference. Clean up.
    shape_and_type = result_pb2.HandleShapeAndType(
        shape=shape.as_proto(),
        dtype=element_dtype.as_datatype_enum,
        type=full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_ARRAY))
    handle_data.shape_and_type.append(shape_and_type)
    list_handle._handle_data = handle_data  # pylint: disable=protected-access
 def testFullTypesForFlatTensors(self):
     """A plain TensorSpec yields a single TFT_UNSET entry."""
     float_spec = tensor_spec.TensorSpec([1], np.float32)
     actual = fulltypes_for_flat_tensors(float_spec)
     expected = [full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_UNSET)]
     flat_specs = float_spec._flat_tensor_specs  # pylint: disable=protected-access
     self.assertEqual(len(flat_specs), len(actual))
     self.assertEqual(expected, actual)
    def testFulTypesForFlatTensors(self):
        """Checks the per-flat-tensor full types of a StructuredTensor spec."""
        # Note that the batchable tensor list encoding for a StructuredTensor
        # contains a separate tensor for each leaf field.
        # In this example, _flat_tensor_specs in class StructuredTensorSpec is
        # called three times and it returns results with length 2, 3 and 11
        # for `g`, `h` and `struct` respectively.
        fields = self._lambda_for_fields()
        struct_rank = 4
        if callable(fields):
            # Deferred construction: fields may include tensors.
            fields = fields()

        struct = StructuredTensor.from_fields_and_rank(fields, struct_rank)
        struct_spec = type_spec.type_spec_from_value(struct)
        flat_specs = struct_spec._flat_tensor_specs  # pylint: disable=protected-access
        actual = fulltypes_for_flat_tensors(struct_spec)

        def unset():
            return full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_UNSET)

        def ragged(elem_type_id):
            return full_type_pb2.FullTypeDef(
                type_id=full_type_pb2.TFT_RAGGED,
                args=[full_type_pb2.FullTypeDef(type_id=elem_type_id)])

        expected_ft_list = [
            # a, b
            unset(),
            unset(),
            # c, d, e, f
            ragged(full_type_pb2.TFT_UINT8),
            ragged(full_type_pb2.TFT_FLOAT),
            ragged(full_type_pb2.TFT_FLOAT),
            ragged(full_type_pb2.TFT_FLOAT),
            # g
            ragged(full_type_pb2.TFT_INT32),
            ragged(full_type_pb2.TFT_INT32),
            # h
            ragged(full_type_pb2.TFT_INT32),
            ragged(full_type_pb2.TFT_INT32),
            ragged(full_type_pb2.TFT_INT32),
        ]
        self.assertEqual(len(expected_ft_list), len(flat_specs))
        self.assertEqual(actual, expected_ft_list)
 def testFullTypesForFlatTensors(self, dt):
     """A SparseTensorSpec of dtype `dt` yields a single TFT_UNSET entry."""
     sparse_spec = sparse_tensor.SparseTensorSpec(dtype=dt)
     actual = fulltypes_for_flat_tensors(sparse_spec)
     expected = [full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_UNSET)]
     flat_specs = sparse_spec._flat_tensor_specs  # pylint: disable=protected-access
     self.assertEqual(len(flat_specs), len(actual))
     self.assertEqual(expected, actual)
Example #11
0
def fulltype_list_to_product(fulltype_list):
    """Wrap a list of FullTypeDef messages as the arguments of one TFT_PRODUCT.

    Args:
      fulltype_list: A list of `full_type_pb2.FullTypeDef` messages.

    Returns:
      A single FullTypeDef with type_id TFT_PRODUCT whose args are the given
      types, in order.
    """
    product = full_type_pb2.FullTypeDef(
        type_id=full_type_pb2.TFT_PRODUCT, args=fulltype_list)
    return product