Ejemplo n.º 1
0
def is_sum_compatible(type_spec: computation_types.Type) -> bool:
    """Checks whether values of `type_spec` can be added to one another.

    A type is sum-compatible when it is a tensor with a numeric dtype and a
    fully defined shape, a struct whose elements are all sum-compatible, or a
    federated type whose member is sum-compatible. Any other kind of type
    (for example sequences, functions, abstract types, and placements) is
    reported as sum-incompatible.

    Args:
      type_spec: A `computation_types.Type`.

    Returns:
      `True` iff `type_spec` is sum-compatible, `False` otherwise.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)

    if type_spec.is_tensor():
        dtype_is_numeric = is_numeric_dtype(type_spec.dtype)
        return dtype_is_numeric and type_spec.shape.is_fully_defined()

    if type_spec.is_struct():
        # Stop at the first element that cannot be summed.
        for _, element_type in structure.iter_elements(type_spec):
            if not is_sum_compatible(element_type):
                return False
        return True

    if type_spec.is_federated():
        return is_sum_compatible(type_spec.member)

    return False
Ejemplo n.º 2
0
 def ingest_value(
     self, value: Any, type_signature: computation_types.Type
 ) -> executor_value_base.ExecutorValue:
   """Wraps `value` in a `FederatedResolvingStrategyValue`.

   When `type_signature` is federated, or is a function whose result is
   federated, the corresponding placement is first passed to
   `_check_strategy_compatible_with_placement`.

   Args:
     value: The value to wrap.
     type_signature: The type associated with `value`; may be `None`, in which
       case no placement check is performed.

   Returns:
     A `FederatedResolvingStrategyValue` built from `value` and
     `type_signature`.
   """
   # Find the federated type (if any) whose placement must be checked.
   federated_type = None
   if type_signature is not None:
     if type_signature.is_federated():
       federated_type = type_signature
     elif (type_signature.is_function() and
           type_signature.result.is_federated()):
       federated_type = type_signature.result

   if federated_type is not None:
     self._check_strategy_compatible_with_placement(federated_type.placement)
   return FederatedResolvingStrategyValue(value, type_signature)
Ejemplo n.º 3
0
def is_structure_of_integers(type_spec: computation_types.Type) -> bool:
  """Checks whether `type_spec` is built solely from integer tensors.

  Tensor types qualify when their dtype is an integer dtype. Struct types
  qualify when every element qualifies (so an empty struct qualifies), and
  federated types qualify when their member does. All other types do not.

  Args:
    type_spec: A `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is a structure of integers, otherwise `False`.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)

  if type_spec.is_tensor():
    py_typecheck.check_type(type_spec.dtype, tf.DType)
    return type_spec.dtype.is_integer

  if type_spec.is_struct():
    # Stop at the first element that is not a structure of integers.
    for _, element_type in structure.iter_elements(type_spec):
      if not is_structure_of_integers(element_type):
        return False
    return True

  if type_spec.is_federated():
    return is_structure_of_integers(type_spec.member)

  return False
Ejemplo n.º 4
0
def is_average_compatible(type_spec: computation_types.Type) -> bool:
  """Checks whether values of `type_spec` can be averaged.

  A type is average-compatible when it is a tensor with a floating-point or
  complex dtype, a struct whose elements are all average-compatible, or a
  federated type whose member is average-compatible.

  Args:
    type_spec: a `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is average-compatible, `False` otherwise.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)

  if type_spec.is_tensor():
    return type_spec.dtype.is_floating or type_spec.dtype.is_complex

  if type_spec.is_struct():
    # Stop at the first element that cannot be averaged.
    for _, element_type in structure.iter_elements(type_spec):
      if not is_average_compatible(element_type):
        return False
    return True

  if type_spec.is_federated():
    return is_average_compatible(type_spec.member)

  return False
Ejemplo n.º 5
0
def is_structure_of_floats(type_spec: computation_types.Type) -> bool:
    """Checks whether `type_spec` is built solely from floating-point tensors.

    An empty `computation_types.StructType` yields `True`, because it holds no
    non-floating element.

    Args:
      type_spec: A `computation_types.Type`.

    Returns:
      `True` iff `type_spec` is a structure of floats, otherwise `False`.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)

    if type_spec.is_tensor():
        py_typecheck.check_type(type_spec.dtype, tf.DType)
        return type_spec.dtype.is_floating

    if type_spec.is_struct():
        # Stop at the first element that is not a structure of floats.
        for _, element_type in structure.iter_elements(type_spec):
            if not is_structure_of_floats(element_type):
                return False
        return True

    if type_spec.is_federated():
        return is_structure_of_floats(type_spec.member)

    return False
Ejemplo n.º 6
0
def _federated_same_placement(x: computation_types.Type,
                              y: computation_types.Type) -> bool:
    """Returns whether `x` and `y` are both federated with equal placements."""
    both_federated = x.is_federated() and y.is_federated()
    # Placements are only compared once both types are known to be federated.
    return both_federated and x.placement == y.placement
Ejemplo n.º 7
0
 def _concretize_abstract_types(
     abstract_type_spec: computation_types.Type,
     concrete_type_spec: computation_types.Type) -> computation_types.Type:
   """Recursive helper function to construct concrete type spec.

   Walks `abstract_type_spec` and `concrete_type_spec` in parallel. Each
   abstract type is replaced by the type previously bound to its label, or is
   bound to the corresponding part of `concrete_type_spec` when no binding
   exists yet. Bindings are stored in `bound_abstract_types`, which (like
   `type_error_string`) is defined outside this function.

   Args:
     abstract_type_spec: The type that may contain abstract types.
     concrete_type_spec: The type whose structure is matched against
       `abstract_type_spec`.

   Returns:
     A `computation_types.Type` built from `abstract_type_spec` with its
     abstract types replaced.

   Raises:
     TypeError: If the two types are structurally incompatible, or if
       `abstract_type_spec` is of an unexpected kind.
   """
   if abstract_type_spec.is_abstract():
     bound_type = bound_abstract_types.get(str(abstract_type_spec.label))
     if bound_type:
       return bound_type
     else:
       bound_abstract_types[str(abstract_type_spec.label)] = concrete_type_spec
       return concrete_type_spec
   elif abstract_type_spec.is_tensor():
     return abstract_type_spec
   elif abstract_type_spec.is_struct():
     if not concrete_type_spec.is_struct():
       raise TypeError(type_error_string)
     abstract_elements = structure.to_elements(abstract_type_spec)
     concrete_elements = structure.to_elements(concrete_type_spec)
     if len(abstract_elements) != len(concrete_elements):
       raise TypeError(type_error_string)
     concretized_tuple_elements = []
     # Lengths were checked above, so `zip` pairs every element.
     for (abstract_name, abstract_elem), (concrete_name, concrete_elem) in zip(
         abstract_elements, concrete_elements):
       if abstract_name != concrete_name:
         raise TypeError(type_error_string)
       concretized_tuple_elements.append(
           (abstract_name,
            _concretize_abstract_types(abstract_elem, concrete_elem)))
     return computation_types.StructType(concretized_tuple_elements)
   elif abstract_type_spec.is_sequence():
     if not concrete_type_spec.is_sequence():
       raise TypeError(type_error_string)
     return computation_types.SequenceType(
         _concretize_abstract_types(abstract_type_spec.element,
                                    concrete_type_spec.element))
   elif abstract_type_spec.is_function():
     if not concrete_type_spec.is_function():
       raise TypeError(type_error_string)
     if abstract_type_spec.parameter is None:
       if concrete_type_spec.parameter is not None:
         # This mismatch must be raised like the others; returning the
         # exception would hand callers a `TypeError` instance as a type.
         raise TypeError(type_error_string)
       concretized_param = None
     else:
       concretized_param = _concretize_abstract_types(
           abstract_type_spec.parameter, concrete_type_spec.parameter)
     concretized_result = _concretize_abstract_types(abstract_type_spec.result,
                                                     concrete_type_spec.result)
     return computation_types.FunctionType(concretized_param,
                                           concretized_result)
   elif abstract_type_spec.is_placement():
     if not concrete_type_spec.is_placement():
       raise TypeError(type_error_string)
     return abstract_type_spec
   elif abstract_type_spec.is_federated():
     if not concrete_type_spec.is_federated():
       raise TypeError(type_error_string)
     new_member = _concretize_abstract_types(abstract_type_spec.member,
                                             concrete_type_spec.member)
     return computation_types.FederatedType(new_member,
                                            abstract_type_spec.placement,
                                            abstract_type_spec.all_equal)
   else:
     raise TypeError(
         'Unexpected abstract typespec {}.'.format(abstract_type_spec))
Ejemplo n.º 8
0
def transform_type_postorder(
    type_signature: computation_types.Type,
    transform_fn: Callable[[computation_types.Type],
                           Tuple[computation_types.Type, bool]]
) -> Tuple[computation_types.Type, bool]:
  """Walks type tree of `type_signature` postorder, calling `transform_fn`.

  Children of a node are transformed before the node itself. When any child
  was mutated, the node is rebuilt from the transformed children before
  `transform_fn` is applied to it.

  Args:
    type_signature: Instance of `computation_types.Type` to transform
      recursively.
    transform_fn: Transformation function to apply to each node in the type tree
      of `type_signature`. Must be callable, and must return a pair of the
      (possibly transformed) type and a `bool` telling whether it changed
      the type.

  Returns:
    A tuple `(transformed_type, mutated)`. `transformed_type` is a possibly
    transformed version of `type_signature`, with each node in its tree the
    result of applying `transform_fn` to the corresponding node in
    `type_signature`. `mutated` is truthy if `transform_fn` reported a
    mutation for the node itself or for any node beneath it.

  Raises:
    TypeError: If the types don't match the specification above, or if
      `type_signature` is of a kind this function does not handle.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)
  py_typecheck.check_callable(transform_fn)
  if type_signature.is_federated():
    transformed_member, member_mutated = transform_type_postorder(
        type_signature.member, transform_fn)
    if member_mutated:
      type_signature = computation_types.FederatedType(transformed_member,
                                                       type_signature.placement,
                                                       type_signature.all_equal)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or member_mutated
  elif type_signature.is_sequence():
    transformed_element, element_mutated = transform_type_postorder(
        type_signature.element, transform_fn)
    if element_mutated:
      type_signature = computation_types.SequenceType(transformed_element)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or element_mutated
  elif type_signature.is_function():
    # A function type may have no parameter; there is nothing to transform
    # in that case.
    if type_signature.parameter is not None:
      transformed_parameter, parameter_mutated = transform_type_postorder(
          type_signature.parameter, transform_fn)
    else:
      transformed_parameter, parameter_mutated = (None, False)
    transformed_result, result_mutated = transform_type_postorder(
        type_signature.result, transform_fn)
    if parameter_mutated or result_mutated:
      type_signature = computation_types.FunctionType(transformed_parameter,
                                                      transformed_result)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, (
        type_signature_mutated or parameter_mutated or result_mutated)
  elif type_signature.is_struct():
    elements = []
    elements_mutated = False
    for name, element_type in structure.iter_elements(type_signature):
      transformed_element, element_mutated = transform_type_postorder(
          element_type, transform_fn)
      elements_mutated = elements_mutated or element_mutated
      elements.append((name, transformed_element))
    if elements_mutated:
      # Keep the Python container when the struct carries one.
      if type_signature.is_struct_with_python():
        type_signature = computation_types.StructWithPythonType(
            elements, type_signature.python_container)
      else:
        type_signature = computation_types.StructType(elements)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or elements_mutated
  elif type_signature.is_abstract() or type_signature.is_placement(
  ) or type_signature.is_tensor():
    # Leaf types have no children to visit.
    return transform_fn(type_signature)
  else:
    # Fail loudly instead of implicitly returning `None`, which callers
    # would trip over when unpacking the result.
    raise TypeError(
        'Unexpected type signature {}.'.format(type_signature))
Ejemplo n.º 9
0
 def _is_compatible(t: computation_types.Type) -> bool:
     """Checks `t` with `contains_only` against struct/tensor/federated types.

     Delegates to `type_analysis.contains_only`, using a predicate that
     accepts only struct, tensor, and federated types.
     """
     def _is_allowed_kind(node):
         return node.is_struct() or node.is_tensor() or node.is_federated()

     return type_analysis.contains_only(t, _is_allowed_kind)