Exemplo n.º 1
0
 def transform_to_tff_known_type(
     type_spec: computation_types.Type) -> Tuple[computation_types.Type, bool]:
   """Rewrites a plain `StructType` as a `StructWithPythonType`.

   Types that are not structs, or that already are `StructWithPythonType`, are
   handed back untouched.

   Args:
     type_spec: The type to inspect.

   Returns:
     A `(type, mutated)` pair. `mutated` is `True` when a new
     `StructWithPythonType` was built and `False` when `type_spec` is returned
     as-is. The container is `collections.OrderedDict` when every field is
     named (this includes a struct with no fields) and `tuple` when no field
     is named.

   Raises:
     TypeError: If only some of the fields of the struct are named.
   """
   # Guard clause: only plain structs without a Python container are rewritten.
   if not type_spec.is_struct() or type_spec.is_struct_with_python():
     return type_spec, False

   named_flags = [
       field_name is not None
       for field_name, _ in structure.iter_elements(type_spec)
   ]
   if all(named_flags):
     # Checked first so that an empty struct maps to `OrderedDict`.
     container = collections.OrderedDict
   elif any(named_flags):
     raise TypeError('Cannot represent TFF type in TF because it contains '
                     f'partially named structures. Type: {type_spec}')
   else:
     container = tuple

   converted = computation_types.StructWithPythonType(
       elements=structure.iter_elements(type_spec), container_type=container)
   return converted, True
Exemplo n.º 2
0
 def is_tesnor_or_struct_with_py_type(t: computation_types.Type) -> bool:
     """Reports whether `t.is_tensor()` or `t.is_struct_with_python()` holds.

     The struct check is only evaluated when the tensor check is falsy.
     """
     tensor_check = t.is_tensor()
     if tensor_check:
         return tensor_check
     return t.is_struct_with_python()
Exemplo n.º 3
0
def transform_type_postorder(
    type_signature: computation_types.Type,
    transform_fn: Callable[[computation_types.Type],
                           Tuple[computation_types.Type, bool]]
) -> Tuple[computation_types.Type, bool]:
  """Walks type tree of `type_signature` postorder, calling `transform_fn`.

  Children are transformed before their parent. A parent is rebuilt from its
  transformed children only when at least one child reported a mutation, and
  `transform_fn` is then applied to the (possibly rebuilt) parent.

  Args:
    type_signature: Instance of `computation_types.Type` to transform
      recursively.
    transform_fn: Transformation function to apply to each node in the type tree
      of `type_signature`. Must be instance of Python function type. It takes a
      type and returns a `(type, mutated)` tuple, where `mutated` is a `bool`.

  Returns:
    A `(type, mutated)` tuple. `type` is a possibly transformed version of
    `type_signature`, with each node in its tree the result of applying
    `transform_fn` to the corresponding node in `type_signature`. `mutated` is
    `True` if `transform_fn` reported a mutation for this node or for any node
    beneath it.

  Raises:
    TypeError: If the types don't match the specification above, or if
      `type_signature` is of a kind this function does not know how to
      traverse.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)
  py_typecheck.check_callable(transform_fn)
  if type_signature.is_federated():
    transformed_member, member_mutated = transform_type_postorder(
        type_signature.member, transform_fn)
    if member_mutated:
      type_signature = computation_types.FederatedType(transformed_member,
                                                       type_signature.placement,
                                                       type_signature.all_equal)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or member_mutated
  elif type_signature.is_sequence():
    transformed_element, element_mutated = transform_type_postorder(
        type_signature.element, transform_fn)
    if element_mutated:
      type_signature = computation_types.SequenceType(transformed_element)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or element_mutated
  elif type_signature.is_function():
    # A function type may have no parameter; there is nothing to recurse into.
    if type_signature.parameter is not None:
      transformed_parameter, parameter_mutated = transform_type_postorder(
          type_signature.parameter, transform_fn)
    else:
      transformed_parameter, parameter_mutated = (None, False)
    transformed_result, result_mutated = transform_type_postorder(
        type_signature.result, transform_fn)
    if parameter_mutated or result_mutated:
      type_signature = computation_types.FunctionType(transformed_parameter,
                                                      transformed_result)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, (
        type_signature_mutated or parameter_mutated or result_mutated)
  elif type_signature.is_struct():
    elements = []
    elements_mutated = False
    for name, element_type in structure.iter_elements(type_signature):
      transformed_element, element_mutated = transform_type_postorder(
          element_type, transform_fn)
      elements_mutated = elements_mutated or element_mutated
      elements.append((name, transformed_element))
    if elements_mutated:
      # Keep the Python container when rebuilding a struct that carries one.
      if type_signature.is_struct_with_python():
        type_signature = computation_types.StructWithPythonType(
            elements, type_signature.python_container)
      else:
        type_signature = computation_types.StructType(elements)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or elements_mutated
  elif (type_signature.is_abstract() or type_signature.is_placement() or
        type_signature.is_tensor()):
    # Leaf kinds have no children; the result of `transform_fn` is the answer.
    return transform_fn(type_signature)
  else:
    # Fail loudly instead of implicitly returning `None`, which callers would
    # otherwise hit as an opaque tuple-unpacking error.
    raise TypeError(
        f'Unsupported type in type tree: {type_signature!r} '
        f'(of Python type {type(type_signature).__name__}).')