예제 #1
0
def _visit_type(type_signature, function):
  """Calls `function` on each type visited under `type_signature`.

  The walk is performed by `type_transformations.transform_type_postorder`.
  Each visited type is handed back unchanged with a "not mutated" flag, and the
  traversal's result is discarded, so the only effect is the calls to
  `function` (whose return value is ignored).

  Args:
    type_signature: The root type to walk.
    function: A one-argument callable invoked with every visited type.
  """

  def _apply_and_keep(visited):
    function(visited)
    # Hand the node back untouched and report that nothing changed.
    return visited, False

  type_transformations.transform_type_postorder(type_signature,
                                                _apply_and_keep)
예제 #2
0
def _visit_postorder(type_signature: computation_types.Type,
                     function: Callable[[computation_types.Type], None]):
    """Invokes `function` on every type reachable from `type_signature`.

    The traversal is delegated to `type_transformations.transform_type_postorder`.
    Every visited node is returned unchanged, so the walk has no effect other
    than the calls to `function`.

    Args:
      type_signature: The `computation_types.Type` to traverse; validated with
        `py_typecheck.check_type` before any traversal happens.
      function: Callable taking a single `computation_types.Type`; its return
        value is ignored.
    """
    py_typecheck.check_type(type_signature, computation_types.Type)

    def _observe(node):
        function(node)
        unchanged = False
        return node, unchanged

    type_transformations.transform_type_postorder(type_signature, _observe)
 def test_recurses_under_sequence(self):
     """Tensors nested under a `SequenceType` are converted; no-ops are not."""
     source = computation_types.SequenceType([tf.int32])
     want = computation_types.SequenceType([tf.float32])
     converted, changed = type_transformations.transform_type_postorder(
         source, _convert_tensor_to_float)
     untouched, unchanged_flag = type_transformations.transform_type_postorder(
         source, _convert_abstract_type_to_tensor)
     self.assertEqual(converted, want)
     self.assertEqual(untouched, source)
     self.assertTrue(changed)
     self.assertFalse(unchanged_flag)
 def test_transforms_unnamed_tuple_type(self):
     """Every tensor field of an unnamed `StructType` is converted."""
     two_field_struct = computation_types.StructType([tf.int32, tf.float64])
     float_result, float_mutated = (
         type_transformations.transform_type_postorder(
             two_field_struct, _convert_tensor_to_float))
     same_result, same_mutated = (
         type_transformations.transform_type_postorder(
             two_field_struct, _convert_abstract_type_to_tensor))
     self.assertEqual(
         float_result,
         computation_types.StructType([tf.float32, tf.float32]))
     self.assertEqual(same_result, two_field_struct)
     self.assertTrue(float_mutated)
     self.assertFalse(same_mutated)
 def test_transforms_tensor(self):
     """A bare `TensorType` root is itself converted by the transform."""
     int_tensor = computation_types.TensorType(tf.int32)
     float_tensor = computation_types.TensorType(tf.float32)
     (transformed,
      was_mutated) = type_transformations.transform_type_postorder(
          int_tensor, _convert_tensor_to_float)
     (passthrough,
      passthrough_mutated) = type_transformations.transform_type_postorder(
          int_tensor, _convert_abstract_type_to_tensor)
     self.assertEqual(transformed, float_tensor)
     self.assertEqual(passthrough, int_tensor)
     self.assertTrue(was_mutated)
     self.assertFalse(passthrough_mutated)
예제 #6
0
 def test_transforms_federated_type(self):
   """The member type of a `CLIENTS`-placed `FederatedType` is converted."""
   clients = placement_literals.CLIENTS
   int_at_clients = computation_types.FederatedType(tf.int32, clients)
   float_at_clients = computation_types.FederatedType(tf.float32, clients)
   converted, converted_mutated = (
       type_transformations.transform_type_postorder(
           int_at_clients, _convert_tensor_to_float))
   kept, kept_mutated = type_transformations.transform_type_postorder(
       int_at_clients, _convert_abstract_type_to_tensor)
   self.assertEqual(converted, float_at_clients)
   self.assertEqual(kept, int_at_clients)
   self.assertTrue(converted_mutated)
   self.assertFalse(kept_mutated)
예제 #7
0
 def test_transforms_named_tuple_type_preserving_tuple_container(self):
   """Named fields are converted; the expected type also uses a `dict`."""
   make_type = computation_types.NamedTupleTypeWithPyContainerType
   int_dict_type = make_type([('a', tf.int32), ('b', tf.float64)], dict)
   float_dict_type = make_type([('a', tf.float32), ('b', tf.float32)], dict)
   converted, converted_mutated = (
       type_transformations.transform_type_postorder(
           int_dict_type, _convert_tensor_to_float))
   kept, kept_mutated = type_transformations.transform_type_postorder(
       int_dict_type, _convert_abstract_type_to_tensor)
   self.assertEqual(converted, float_dict_type)
   self.assertEqual(kept, int_dict_type)
   self.assertTrue(converted_mutated)
   self.assertFalse(kept_mutated)
예제 #8
0
 def test_recurses_under_named_tuple_type(self):
   """Tensors inside a named tuple wrapped in an outer tuple are converted."""
   nested_ints = computation_types.to_type([[('a', tf.int32),
                                             ('b', tf.float64)]])
   nested_floats = computation_types.to_type([[('a', tf.float32),
                                               ('b', tf.float32)]])
   as_floats, as_floats_mutated = (
       type_transformations.transform_type_postorder(
           nested_ints, _convert_tensor_to_float))
   as_is, as_is_mutated = type_transformations.transform_type_postorder(
       nested_ints, _convert_abstract_type_to_tensor)
   self.assertEqual(as_floats, nested_floats)
   self.assertEqual(as_is, nested_ints)
   self.assertTrue(as_floats_mutated)
   self.assertFalse(as_is_mutated)
예제 #9
0
def is_binary_op_with_upcast_compatible_pair(
        possibly_nested_type: Optional[computation_types.Type],
        type_to_upcast: Optional[computation_types.Type]) -> bool:
    """Checks unambiguity in applying `type_to_upcast` to `possibly_nested_type`.

    That is, checks that either these types are equivalent and contain only
    tuples and tensors, or that `possibly_nested_type` is perhaps a nested
    structure containing only tensors with `dtype` of `type_to_upcast` at the
    leaves, where `type_to_upcast` must be a scalar tensor type. Notice that
    this relationship is not symmetric, since binary operators need not respect
    this symmetry in general. For example, it makes perfect sense to divide a
    nested structure of tensors by a scalar, but not the other way around.

    Args:
      possibly_nested_type: A `computation_types.Type`, or `None`.
      type_to_upcast: A `computation_types.Type`, or `None`.

    Returns:
      Boolean indicating whether `type_to_upcast` can be upcast to
      `possibly_nested_type` in the manner described above. Two `None`
      arguments are compatible; exactly one `None` argument is not.
    """
    if possibly_nested_type is not None:
        py_typecheck.check_type(possibly_nested_type, computation_types.Type)
    if type_to_upcast is not None:
        py_typecheck.check_type(type_to_upcast, computation_types.Type)
    if not (is_generic_op_compatible_type(possibly_nested_type)
            and is_generic_op_compatible_type(type_to_upcast)):
        return False
    if possibly_nested_type is None:
        return type_to_upcast is None
    if type_to_upcast is None:
        # A concrete type never pairs with a missing operand. Decide this here
        # rather than relying on `is_equivalent_to` to tolerate `None`.
        return False
    if possibly_nested_type.is_equivalent_to(type_to_upcast):
        return True
    if not (isinstance(type_to_upcast, computation_types.TensorType)
            and type_to_upcast.shape == tf.TensorShape(())):
        return False

    only_allowed_dtype = type_to_upcast.dtype
    types_are_ok = True

    def _check_tensor_types(type_spec):
        # Side-effecting visitor: flags any tensor leaf whose dtype differs
        # from the scalar's dtype, and never rewrites the type.
        nonlocal types_are_ok
        if (isinstance(type_spec, computation_types.TensorType)
                and type_spec.dtype != only_allowed_dtype):
            types_are_ok = False
        return type_spec, False

    type_transformations.transform_type_postorder(possibly_nested_type,
                                                  _check_tensor_types)
    return types_are_ok
예제 #10
0
    async def _embed_tf_secure_sum_mask_value(self, type_spec,
                                              extended_bitwidth):
        """Embeds the emulated secure-sum mask for `type_spec` as a constant.

        The mask type is `type_spec` with every tensor replaced by a scalar of
        the same dtype. For each leaf the mask value is `2**bits - 1`, clipped
        (with a warning) to the maximum of the leaf's dtype.

        Args:
          type_spec: The TFF type whose tensor leaves each receive a mask.
          extended_bitwidth: Bit widths; they are paired with the leaves of the
            mask type by `_map_numpy_or_structure` (presumably one width per
            leaf — confirm against that helper).

        Returns:
          Whatever `executor_utils.embed_tf_constant` produces for the mask
          type and value in `self._executor`.
        """

        def _to_scalar(leaf):
            """Converts all `tff.TensorType` to scalar shapes."""
            if leaf.is_tensor():
                scalar = computation_types.TensorType(dtype=leaf.dtype,
                                                      shape=[])
                return scalar, True
            return leaf, False

        mask_type, _ = type_transformations.transform_type_postorder(
            type_spec, _to_scalar)

        def _mask_constant(bits, type_spec):
            # The parameter names `bits` / `type_spec` are kept in case
            # `_map_numpy_or_structure` passes them by keyword.
            leaf_dtype = type_spec.dtype
            mask = 2**bits - 1
            if mask > leaf_dtype.max:
                logging.warning(
                    'Emulated secure sum mask exceeds maximum value of '
                    'dtype. Dtype %s, mask bits: %d', leaf_dtype, bits)
                mask = leaf_dtype.max
            return tf.constant(mask, leaf_dtype)

        mask_value = _map_numpy_or_structure(extended_bitwidth,
                                             mask_type,
                                             fn=_mask_constant)
        logging.debug('Emulated secure sum using mask: %s', mask_value)
        return await executor_utils.embed_tf_constant(
            self._executor, mask_type, mask_value)
예제 #11
0
def _check_container_compat_with_tf_nest(type_spec: computation_types.Type):
    """Asserts that all `StructTypes` with names have OrderedDict containers.

    More precisely, every struct nested in `type_spec` whose (non-`None`) names
    are not already in alphabetical order must be a `StructWithPythonType` whose
    container is `collections.OrderedDict`. Structs without names, or whose
    names are already sorted, are accepted as-is.

    Args:
      type_spec: The `computation_types.Type` to check.

    Raises:
      ValueError: If a named struct with non-alphabetical name order has no
        Python container, or has a container other than
        `collections.OrderedDict`.
    """

    def _names_are_in_sorted_order(name_sequence: Sequence[str]) -> bool:
        # `sorted` always returns a list, and a list never equals a tuple, so
        # normalize to a list before comparing; otherwise any non-list
        # sequence would be reported as unsorted.
        return list(name_sequence) == sorted(name_sequence)

    def _check_ordereddict_container_for_struct(type_to_check):
        if not type_to_check.is_struct():
            return type_to_check, False
        # We can't use `dir` here, since it sorts the names before returning. We
        # also must filter to names which are actually present.
        names_in_sequence_order = structure.name_list(type_to_check)
        if (not names_in_sequence_order
                or _names_are_in_sorted_order(names_in_sequence_order)):
            # If alphabetical order matches sequence order, TFF's deserialization
            # will traverse the structure correctly; there is no ambiguity here.
            # On the other hand, if there are no names, sequence order is the
            # only method of traversal, so there is no ambiguity here either.
            return type_to_check, False
        # From here on the names are present and not in sorted order.
        if not type_to_check.is_struct_with_python():
            raise ValueError(
                'Attempting to serialize a named struct type with '
                'ambiguous traversal order (sequence order distinct '
                'from alphabetical order) without a Python container; '
                'this is an unsafe operation, as TFF cannot determine '
                'the intended traversal order after deserializing the '
                'proto due to inconsistent behavior of tf.nest.')

        container_type = computation_types.StructWithPythonType.get_container_type(
            type_to_check)
        if container_type is not collections.OrderedDict:
            raise ValueError(
                'Attempted to serialize a dataset yielding named '
                'elements in non-sorted sequence order with '
                f'non-OrderedDict container (type {container_type}). '
                'This is an ambiguous operation; `tf.nest` behaves in '
                'a manner which depends on the Python type of this '
                'container, so coercing the dataset reconstructed '
                'from the resulting Value proto depends on assuming a '
                'single Python type here. Please prefer to use '
                '`collections.OrderedDict` containers for the elements '
                'your dataset yields.')
        return type_to_check, False

    type_transformations.transform_type_postorder(
        type_spec, _check_ordereddict_container_for_struct)
예제 #12
0
def _build_reservoir_type(
        sample_value_type: computation_types.Type) -> computation_types.Type:
    """Create the TFF type for the reservoir's state.

    `UnweightedReservoirSamplingFactory` will use this type as the "state" type
    in a `tff.federated_aggregate` (an input to `accumulate`, `merge` and
    `report`).

    Args:
      sample_value_type: The `tff.Type` of the values that will be aggregated
        from clients. It may contain only `TensorType` and
        `StructWithPythonType` nodes.

    Returns:
      A `computation_types.Type` built (via `computation_types.to_type`) from a
      `collections.OrderedDict` with three keys:
        random_seed: A `tf.int64` tensor of shape `[2]`. This keeps track of
          the `seed` parameter for `tf.random.stateless_uniform` calls for
          sampling during aggregation.
        random_values: A 1-d `tf.int32` tensor of unknown length, intended to
          be filled from `tf.random.stateless_uniform`. Its size is the same as
          the first dimension of each leaf in `samples` (these can be thought
          of as parallel lists). These values are used to determine whether a
          sample stays in the reservoir, or is evicted, as the values are
          aggregated. If the i-th value of this list is evicted, then the i-th
          (in the first dimension) tensor in each of the `samples` structure
          leaves must be evicted.
        samples: `sample_value_type` with one additional dimension of unknown
          size prepended to every tensor. If a structure, its shape matches
          that of `sample_value_type`. The extra dimension is used to
          concatenate samples and store them in the reservoir.

    Raises:
      TypeError: If `sample_value_type` contains anything other than
        `TensorType` or `StructWithPythonType`.
    """

    # TODO(b/181365504): relax this to allow `StructType` once a `Struct` can be
    # returned from `tf.function` decorated methods.
    def is_tensor_or_struct_with_py_type(t: computation_types.Type) -> bool:
        return t.is_tensor() or t.is_struct_with_python()

    if not type_analysis.contains_only(sample_value_type,
                                       is_tensor_or_struct_with_py_type):
        raise TypeError(
            'Cannot create a reservoir for type structure. Sample type '
            'must only contain `TensorType` or `StructWithPythonType`, '
            f'got a {sample_value_type!r}.')

    def add_unknown_dimension(t):
        """Prepends a `None` (unknown-size) dimension to each tensor type."""
        if t.is_tensor():
            return (computation_types.TensorType(dtype=t.dtype,
                                                 shape=[None] + t.shape), True)
        return t, False

    samples_type = type_transformations.transform_type_postorder(
        sample_value_type, add_unknown_dimension)[0]

    # TODO(b/181155367): creating a value from a type for the `zero` is a common
    # pattern for users of `tff.federated_aggregate` that could be made easier
    # for TFF users. Replace this once such helper exists.
    return computation_types.to_type(
        collections.OrderedDict(
            random_seed=computation_types.TensorType(tf.int64, shape=[2]),
            random_values=computation_types.TensorType(tf.int32, shape=[None]),
            samples=samples_type))
예제 #13
0
 def _remove_lambda_placement(comp):
   """Rebuilds Lambda `comp` with `_remove_placement_from_type` applied.

   The transformation is applied throughout the parameter type. The parameter
   name and result are reused unchanged, and a lambda without a parameter name
   gets a `None` parameter type.
   """
   if comp.parameter_name is not None:
     stripped_type, _ = type_transformations.transform_type_postorder(
         comp.parameter_type, _remove_placement_from_type)
   else:
     stripped_type = None
   return building_blocks.Lambda(comp.parameter_name, stripped_type,
                                 comp.result)
예제 #14
0
def count_tensors_in_type(
    type_spec: computation_types.Type,
    tensor_filter: Optional[Callable[[computation_types.TensorType],
                                     bool]] = None
) -> collections.OrderedDict:
    """Counts tensors and fully-specified elements under `type_spec`.

    Args:
      type_spec: Instance of `computation_types.Type` to count tensors under.
      tensor_filter: Optional filtering function. Callable which takes an
        argument of type `computation_types.TensorType` and returns a boolean.
        If specified, only tensor types which pass this filter (i.e., on which
        this function returns `True`) are counted. If `None`, every tensor is
        counted.

    Returns:
      A `collections.OrderedDict` with three keys, each counting only tensors
      that pass `tensor_filter`:
        num_tensors: the number of `computation_types.TensorType`s.
        parameters: the total number of elements over those tensors whose
          shape is fully specified. A tensor with any `None` dimension (i.e.,
          of unspecified size) contributes nothing here.
        num_unspecified_tensors: how many of those tensors have unspecified
          size (and were therefore left out of `parameters`).

    Note:
      `type_spec` and `tensor_filter` are validated with
      `py_typecheck.check_type` / `py_typecheck.check_callable` before
      counting (expected to raise `TypeError` — confirm in `py_typecheck`).
    """
    py_typecheck.check_type(type_spec, computation_types.Type)
    if tensor_filter is None:
        tensor_filter = lambda _: True
    py_typecheck.check_callable(tensor_filter)

    tensors_and_params = collections.OrderedDict(num_tensors=0,
                                                 parameters=0,
                                                 num_unspecified_tensors=0)

    def _capture_tensors(type_signature):
        # Visitor that only accumulates counts; it never rewrites the type.
        if type_signature.is_tensor() and tensor_filter(type_signature):
            tensors_and_params['num_tensors'] += 1
            # `num_elements()` is `None` when the shape is not fully defined.
            num_parameters = type_signature.shape.num_elements()
            if num_parameters is not None:
                tensors_and_params['parameters'] += num_parameters
            else:
                tensors_and_params['num_unspecified_tensors'] += 1
        return type_signature, False

    type_transformations.transform_type_postorder(type_spec, _capture_tensors)
    return tensors_and_params
def _remove_client_all_equals_from_type(type_signature):
  """Returns `type_signature` with CLIENTS all-equal federated types relaxed.

  Every federated type placed at clients with `all_equal` set is rebuilt as a
  `FederatedType` with the same member and placement but `all_equal=False`.
  Every other node is kept as-is.
  """

  def _transform(inner_type):
    is_all_equal_at_clients = (
        inner_type.is_federated() and inner_type.placement.is_clients() and
        inner_type.all_equal)
    if not is_all_equal_at_clients:
      return inner_type, False
    relaxed = computation_types.FederatedType(inner_type.member,
                                             inner_type.placement, False)
    return relaxed, True

  transformation_result = type_transformations.transform_type_postorder(
      type_signature, _transform)
  return transformation_result[0]
 def test_raises_on_non_type_first_arg(self):
     """A raw dtype in place of a TFF type is rejected with `TypeError`."""
     not_a_tff_type = tf.int32
     with self.assertRaises(TypeError):
         type_transformations.transform_type_postorder(not_a_tff_type, None)
 def test_raises_on_none_function(self):
     """A valid type paired with a `None` transform raises `TypeError`."""
     int_tensor = computation_types.TensorType(tf.int32)
     with self.assertRaises(TypeError):
         type_transformations.transform_type_postorder(int_tensor, None)
 def test_raises_on_none_type(self):
     """`None` in place of a type signature raises `TypeError`."""
     identity = lambda x: x
     with self.assertRaises(TypeError):
         type_transformations.transform_type_postorder(None, identity)
 def test_updates_mutated_bit_at_tuple(self):
     """Transforming a struct with `_convert_tuple_to_tensor` reports mutation."""
     struct_type = computation_types.StructType([tf.int32, tf.float64])
     _, did_mutate = type_transformations.transform_type_postorder(
         struct_type, _convert_tuple_to_tensor)
     self.assertTrue(did_mutate)
예제 #20
0
 def _remove_reference_placement(comp):
   """Rebuilds Reference `comp` with `_remove_placement_from_type` applied.

   The transformation is applied throughout the reference's type signature,
   and the reference keeps its original name.
   """
   stripped_type, _ = type_transformations.transform_type_postorder(
       comp.type_signature, _remove_placement_from_type)
   return building_blocks.Reference(comp.name, stripped_type)
 def test_updates_mutated_bit_at_function(self):
     """Transforming with `_convert_function_to_tensor` reports mutation."""
     function_type = computation_types.FunctionType(tf.int32, tf.int64)
     _, did_mutate = type_transformations.transform_type_postorder(
         function_type, _convert_function_to_tensor)
     self.assertTrue(did_mutate)
 def test_updates_mutated_bit_at_sequence(self):
     """Transforming with `_convert_sequence_to_tensor` reports mutation."""
     sequence_type = computation_types.SequenceType(tf.int32)
     _, did_mutate = type_transformations.transform_type_postorder(
         sequence_type, _convert_sequence_to_tensor)
     self.assertTrue(did_mutate)
 def test_updates_mutated_bit_at_federated(self):
     """Transforming with `_convert_federated_to_tensor` reports mutation."""
     at_clients = computation_types.FederatedType(tf.int32, placements.CLIENTS)
     _, did_mutate = type_transformations.transform_type_postorder(
         at_clients, _convert_federated_to_tensor)
     self.assertTrue(did_mutate)
예제 #24
0
def create_constant(scalar_value,
                    type_spec: computation_types.Type) -> ProtoAndType:
  """Builds a TensorFlow computation of type `( -> type_spec)`.

  Every tensor leaf of `type_spec` is filled with `scalar_value`. `type_spec`
  may nest named tuples and tensors arbitrarily but contain nothing else, and
  each tensor leaf must have a fully defined shape. `scalar_value` must be a
  scalar. If any leaf of `type_spec` has an integer dtype, `scalar_value` must
  be an integer too.

  Args:
    scalar_value: The scalar written into every tensor leaf of `type_spec`.
    type_spec: The `computation_types.Type` of the constant to produce.

  Returns:
    The value produced by `_tensorflow_comp` for the serialized graph and the
    function type `( -> type_spec)`.

  Raises:
    TypeError: If `type_spec` is not made only of tuples and tensors, if
      `scalar_value` is not a scalar, or if a non-integer scalar is paired with
      a `type_spec` containing integer tensors.
  """
  if not type_analysis.is_generic_op_compatible_type(type_spec):
    raise TypeError(
        'Type spec {} cannot be constructed as a TensorFlow constant in TFF; '
        ' only nested tuples and tensors are permitted.'.format(type_spec))

  scalar_type = type_conversions.infer_type(scalar_value)
  if not scalar_type.is_tensor() or scalar_type.shape != tf.TensorShape(()):
    raise TypeError(
        'Must pass a scalar value to `create_tensorflow_constant`; encountered '
        'a value {}'.format(scalar_value))

  leaf_dtypes = []

  def _collect_dtype(node):
    """Records the dtype of each tensor node; never rewrites the type."""
    if node.is_tensor():
      leaf_dtypes.append(node.dtype)
    return node, False

  type_transformations.transform_type_postorder(type_spec, _collect_dtype)

  has_integer_leaf = any(dtype.is_integer for dtype in leaf_dtypes)
  if has_integer_leaf and not scalar_type.dtype.is_integer:
    raise TypeError(
        'Only integers can be used as scalar values if our desired constant '
        'type spec contains any integer tensors; passed scalar {} of dtype {} '
        'for type spec {}.'.format(scalar_value, scalar_type.dtype, type_spec))

  def _fill(spec):
    """Builds a tensor, or a nested list of tensors, shaped like `spec`."""
    if spec.is_tensor():
      spec.shape.assert_is_fully_defined()
      return tf.constant(scalar_value, dtype=spec.dtype, shape=spec.shape)
    return [_fill(child) for _, child in structure.iter_elements(spec)]

  with tf.Graph().as_default() as graph:
    packed = _fill(type_spec)
    _, result_binding = tensorflow_utils.capture_result_from_graph(
        packed, graph)

  function_type = computation_types.FunctionType(None, type_spec)
  proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=result_binding)
  return _tensorflow_comp(proto, function_type)
def create_constant(
        value, type_spec: computation_types.Type) -> ComputationProtoAndType:
    """Returns a tensorflow computation returning a constant `value`.

    The returned computation has the type signature `( -> T)`, where `T` is
    `type_spec`.

    `value` must be a value convertible to a tensor or a structure of values,
    such that the dtype and shapes match `type_spec`. `type_spec` must contain
    only named tuples and tensor types, but these can be arbitrarily nested. A
    non-structure `value` is broadcast to every tensor leaf of `type_spec`.

    Args:
      value: A value to embed as a constant in the tensorflow graph.
      type_spec: The `computation_types.Type` of the constant to construct;
        must contain only named tuples and tensor types, and every tensor leaf
        must have a fully defined shape.

    Returns:
      The result of `_tensorflow_comp` for the serialized graph and the
      function type `( -> type_spec)`.

    Raises:
      TypeError: If `type_spec` is not made only of tuples and tensors, if a
        structured `value` is not assignable to `type_spec`, or if a
        non-integer tensor `value` is given for a `type_spec` that contains
        integer tensors.
    """
    if not type_analysis.is_generic_op_compatible_type(type_spec):
        raise TypeError(
            'Type spec {} cannot be constructed as a TensorFlow constant in TFF; '
            ' only nested tuples and tensors are permitted.'.format(type_spec))
    inferred_value_type = type_conversions.infer_type(value)
    if (inferred_value_type.is_struct()
            and not type_spec.is_assignable_from(inferred_value_type)):
        raise TypeError(
            'Must pass a only tensor or structure of tensor values to '
            '`create_tensorflow_constant`; encountered a value {v} with inferred '
            'type {t!r}, but needed {s!r}'.format(v=value,
                                                  t=inferred_value_type,
                                                  s=type_spec))
    if inferred_value_type.is_struct():
        # Normalize nested Python containers into `structure.Struct` so values
        # can be zipped positionally against `type_spec` below.
        value = structure.from_container(value, recursive=True)
    tensor_dtypes_in_type_spec = []

    def _pack_dtypes(type_signature):
        """Appends dtype of `type_signature` to nonlocal variable."""
        if type_signature.is_tensor():
            tensor_dtypes_in_type_spec.append(type_signature.dtype)
        return type_signature, False

    type_transformations.transform_type_postorder(type_spec, _pack_dtypes)

    if (any(x.is_integer for x in tensor_dtypes_in_type_spec)
            and (inferred_value_type.is_tensor()
                 and not inferred_value_type.dtype.is_integer)):
        raise TypeError(
            'Only integers can be used as scalar values if our desired constant '
            'type spec contains any integer tensors; passed scalar {} of dtype {} '
            'for type spec {}.'.format(value, inferred_value_type.dtype,
                                       type_spec))

    result_type = type_spec

    def _create_result_tensor(type_spec, value):
        """Packs `value` into `type_spec` recursively."""
        if type_spec.is_tensor():
            type_spec.shape.assert_is_fully_defined()
            return tf.constant(value,
                               dtype=type_spec.dtype,
                               shape=type_spec.shape)
        if inferred_value_type.is_struct():
            # Copy the leaf values according to the type_spec structure. The
            # loop uses `elem_value` so the `value` being zipped over is not
            # shadowed by its own elements.
            elements = [
                (name, _create_result_tensor(elem_type, elem_value))
                for (name, elem_type), elem_value in zip(
                    structure.iter_elements(type_spec), value)
            ]
        else:
            # "Broadcast" the value to each level of the type_spec structure.
            elements = [
                (None, _create_result_tensor(elem_type, value))
                for _, elem_type in structure.iter_elements(type_spec)
            ]
        return structure.Struct(elements)

    with tf.Graph().as_default() as graph:
        result = _create_result_tensor(result_type, value)
        _, result_binding = tensorflow_utils.capture_result_from_graph(
            result, graph)

    type_signature = computation_types.FunctionType(None, result_type)
    tensorflow = pb.TensorFlow(graph_def=serialization_utils.pack_graph_def(
        graph.as_graph_def()),
                               parameter=None,
                               result=result_binding)
    return _tensorflow_comp(tensorflow, type_signature)
예제 #26
0
def _deserialize_dataset_from_graph_def(serialized_graph_def: bytes,
                                        element_type: computation_types.Type):
    """Deserializes a serialized `tf.compat.v1.GraphDef` to a `tf.data.Dataset`.

    Args:
      serialized_graph_def: `bytes` object produced by
        `tensorflow_serialization.serialize_dataset`
      element_type: a `tff.Type` object representing the type structure of the
        elements yielded from the dataset.

    Returns:
      A `tf.data.Dataset` instance, passed through
      `tensorflow_utils.coerce_dataset_elements_to_tff_type_spec` with the
      Python-container form of `element_type`.

    Raises:
      TypeError: If `element_type` contains a struct mixing named and unnamed
        fields.
    """
    py_typecheck.check_type(element_type, computation_types.Type)
    type_analysis.check_tensorflow_compatible_type(element_type)

    def transform_to_tff_known_type(
        type_spec: computation_types.Type
    ) -> Tuple[computation_types.Type, bool]:
        """Transforms `StructType` to `StructWithPythonType`.

        Fully named structs get a `collections.OrderedDict` container, and
        fully unnamed structs get a `tuple` container. Non-structs and structs
        that already carry a Python container are returned unchanged.
        """
        if not type_spec.is_struct() or type_spec.is_struct_with_python():
            return type_spec, False
        # Materialize once; the elements are inspected and then reused below.
        elements = list(structure.iter_elements(type_spec))
        field_is_named = [name is not None for name, _ in elements]
        if all(field_is_named):
            container_type = collections.OrderedDict
        elif not any(field_is_named):
            container_type = tuple
        else:
            raise TypeError(
                'Cannot represent TFF type in TF because it contains '
                f'partially named structures. Type: {type_spec}')
        return computation_types.StructWithPythonType(
            elements=elements, container_type=container_type), True

    if element_type.is_struct():
        # TF doesn't support `structure.Struct` types, so we must transform the
        # `StructType` into a `StructWithPythonType` for use as the
        # `tf.data.Dataset.element_spec` later.
        tf_compatible_type, _ = type_transformations.transform_type_postorder(
            element_type, transform_to_tff_known_type)
    else:
        # We've checked this is only a struct or tensors, so we know this is a
        # `TensorType` here and will use as-is.
        tf_compatible_type = element_type

    def type_to_tensorspec(t: computation_types.TensorType) -> tf.TensorSpec:
        return tf.TensorSpec(shape=t.shape, dtype=t.dtype)

    element_spec = type_conversions.structure_from_tensor_type_tree(
        type_to_tensorspec, tf_compatible_type)
    ds = tf.data.experimental.from_variant(
        tf.raw_ops.DatasetFromGraph(graph_def=serialized_graph_def),
        structure=element_spec)
    # If a serialized dataset had elements of nested structures of tensors (e.g.
    # `dict`, `OrderedDict`), the deserialized dataset will return `dict`,
    # `tuple`, or `namedtuple` (loses `collections.OrderedDict` in a conversion).
    #
    # Since the dataset will only be used inside TFF, we wrap the dictionary
    # coming from TF in an `OrderedDict` when necessary (a type that both TF and
    # TFF understand), using the field order stored in the TFF type stored during
    # serialization.
    return tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
        ds, tf_compatible_type)