def is_sum_compatible(type_spec: computation_types.Type) -> bool:
  """Checks whether values of `type_spec` can be added to one another.

  A type is sum-compatible when each of its leaves is a tensor type whose dtype
  satisfies `is_numeric_dtype` and whose shape is fully defined. Such leaves
  may be nested inside struct types, and the whole type may be federated.
  Every other kind of type (for example sequences, functions, abstract types
  and placements) is reported as not sum-compatible.

  Args:
    type_spec: The `computation_types.Type` to inspect.

  Returns:
    `True` if `type_spec` is sum-compatible, `False` otherwise.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    dtype_ok = is_numeric_dtype(type_spec.dtype)
    # Shape is only consulted when the dtype is numeric (short-circuit `and`).
    return dtype_ok and type_spec.shape.is_fully_defined()
  if type_spec.is_struct():
    # Every element of the struct must itself be sum-compatible.
    for _, member_type in structure.iter_elements(type_spec):
      if not is_sum_compatible(member_type):
        return False
    return True
  if type_spec.is_federated():
    return is_sum_compatible(type_spec.member)
  return False
def ingest_value(
    self, value: Any, type_signature: computation_types.Type
) -> executor_value_base.ExecutorValue:
  """Wraps `value` in a `FederatedResolvingStrategyValue`.

  If `type_signature` is federated, or is a function type whose result is
  federated, the relevant placement is first passed to
  `self._check_strategy_compatible_with_placement`.

  Args:
    value: The value to ingest.
    type_signature: The `computation_types.Type` of `value`, or `None`, in
      which case no placement check is performed.

  Returns:
    A `FederatedResolvingStrategyValue` built from `value` and
    `type_signature`.
  """
  if type_signature is None:
    return FederatedResolvingStrategyValue(value, type_signature)
  if type_signature.is_federated():
    self._check_strategy_compatible_with_placement(type_signature.placement)
  elif type_signature.is_function():
    result_type = type_signature.result
    # Functions producing federated values must also match this strategy.
    if result_type.is_federated():
      self._check_strategy_compatible_with_placement(result_type.placement)
  return FederatedResolvingStrategyValue(value, type_signature)
def is_structure_of_integers(type_spec: computation_types.Type) -> bool:
  """Checks whether every leaf of `type_spec` is an integer tensor type.

  Struct types are checked element by element, and federated types by their
  member type. Any other non-tensor type yields `False`. A struct type with no
  elements yields `True`, since it has no non-integer leaves.

  Args:
    type_spec: The `computation_types.Type` to inspect.

  Returns:
    `True` if `type_spec` is a structure of integers, `False` otherwise.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    dtype = type_spec.dtype
    py_typecheck.check_type(dtype, tf.DType)
    return dtype.is_integer
  if type_spec.is_struct():
    element_types = (
        element_type for _, element_type in structure.iter_elements(type_spec))
    return all(map(is_structure_of_integers, element_types))
  if type_spec.is_federated():
    return is_structure_of_integers(type_spec.member)
  return False
def is_average_compatible(type_spec: computation_types.Type) -> bool:
  """Checks whether values of `type_spec` can be averaged.

  Leaves must be tensor types with a floating-point or complex dtype. They may
  be nested inside struct types, and the whole type may be federated. Any
  other kind of type is not average-compatible.

  Args:
    type_spec: The `computation_types.Type` to inspect.

  Returns:
    `True` if `type_spec` is average-compatible, `False` otherwise.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    dtype = type_spec.dtype
    return dtype.is_floating or dtype.is_complex
  if type_spec.is_struct():
    element_checks = (
        is_average_compatible(element_type)
        for _, element_type in structure.iter_elements(type_spec))
    return all(element_checks)
  if type_spec.is_federated():
    return is_average_compatible(type_spec.member)
  return False
def is_structure_of_floats(type_spec: computation_types.Type) -> bool:
  """Checks whether every leaf of `type_spec` is a floating-point tensor type.

  Struct types are checked element by element, and federated types by their
  member type. Any other non-tensor type yields `False`. An empty
  `computation_types.StructType` yields `True`, since it contains no
  non-floating leaves.

  Args:
    type_spec: The `computation_types.Type` to inspect.

  Returns:
    `True` if `type_spec` is a structure of floats, `False` otherwise.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    tensor_dtype = type_spec.dtype
    py_typecheck.check_type(tensor_dtype, tf.DType)
    return tensor_dtype.is_floating
  if type_spec.is_struct():
    for _, child_type in structure.iter_elements(type_spec):
      if not is_structure_of_floats(child_type):
        return False
    return True
  if type_spec.is_federated():
    return is_structure_of_floats(type_spec.member)
  return False
def _federated_same_placement(x: computation_types.Type,
                              y: computation_types.Type) -> bool:
  """Returns whether `x` and `y` are both federated with equal placements."""
  if not (x.is_federated() and y.is_federated()):
    return False
  return x.placement == y.placement
def _concretize_abstract_types(
    abstract_type_spec: computation_types.Type,
    concrete_type_spec: computation_types.Type) -> computation_types.Type:
  """Recursive helper function to construct concrete type spec.

  Walks `abstract_type_spec` and `concrete_type_spec` in lockstep. Each
  abstract type in `abstract_type_spec` is replaced with the type found at
  the same position in `concrete_type_spec`. The first concrete type seen for
  a given abstract label is recorded in `bound_abstract_types`, keyed by
  `str(label)`, and is reused for every later occurrence of that label.

  NOTE(review): `bound_abstract_types` and `type_error_string` are not defined
  in this function; presumably they come from an enclosing scope. Confirm
  against the caller.

  Args:
    abstract_type_spec: A `computation_types.Type`, possibly containing
      abstract types.
    concrete_type_spec: A `computation_types.Type` expected to have the same
      structure as `abstract_type_spec`.

  Returns:
    A `computation_types.Type` shaped like `abstract_type_spec` with its
    abstract types replaced by their bound concrete types.

  Raises:
    TypeError: If the two types disagree in kind or structure at a struct,
      sequence, function, placement or federated node (including a function
      parameter present in one type but not the other), or if
      `abstract_type_spec` is of an unexpected kind.
  """
  if abstract_type_spec.is_abstract():
    label = str(abstract_type_spec.label)
    bound_type = bound_abstract_types.get(label)
    # Compare with `None` rather than relying on truthiness, so that a bound
    # type which happens to be falsy (e.g. an empty struct type) is still
    # reused instead of being silently rebound.
    if bound_type is not None:
      return bound_type
    bound_abstract_types[label] = concrete_type_spec
    return concrete_type_spec
  elif abstract_type_spec.is_tensor():
    return abstract_type_spec
  elif abstract_type_spec.is_struct():
    if not concrete_type_spec.is_struct():
      raise TypeError(type_error_string)
    abstract_elements = structure.to_elements(abstract_type_spec)
    concrete_elements = structure.to_elements(concrete_type_spec)
    if len(abstract_elements) != len(concrete_elements):
      raise TypeError(type_error_string)
    concretized_tuple_elements = []
    # Lengths are equal (checked above), so `zip` pairs every element.
    for (abstract_name, abstract_element), (concrete_name,
                                            concrete_element) in zip(
                                                abstract_elements,
                                                concrete_elements):
      if abstract_name != concrete_name:
        raise TypeError(type_error_string)
      concretized_tuple_elements.append(
          (abstract_name,
           _concretize_abstract_types(abstract_element, concrete_element)))
    return computation_types.StructType(concretized_tuple_elements)
  elif abstract_type_spec.is_sequence():
    if not concrete_type_spec.is_sequence():
      raise TypeError(type_error_string)
    return computation_types.SequenceType(
        _concretize_abstract_types(abstract_type_spec.element,
                                   concrete_type_spec.element))
  elif abstract_type_spec.is_function():
    if not concrete_type_spec.is_function():
      raise TypeError(type_error_string)
    if abstract_type_spec.parameter is None:
      if concrete_type_spec.parameter is not None:
        # Must be raised; returning the exception object would hide the
        # mismatch and leak an exception where a type is expected.
        raise TypeError(type_error_string)
      concretized_param = None
    else:
      if concrete_type_spec.parameter is None:
        # Symmetric to the check above: recursing with `None` would either
        # bind an abstract label to `None` or fail with `AttributeError`.
        raise TypeError(type_error_string)
      concretized_param = _concretize_abstract_types(
          abstract_type_spec.parameter, concrete_type_spec.parameter)
    concretized_result = _concretize_abstract_types(abstract_type_spec.result,
                                                    concrete_type_spec.result)
    return computation_types.FunctionType(concretized_param,
                                          concretized_result)
  elif abstract_type_spec.is_placement():
    if not concrete_type_spec.is_placement():
      raise TypeError(type_error_string)
    return abstract_type_spec
  elif abstract_type_spec.is_federated():
    if not concrete_type_spec.is_federated():
      raise TypeError(type_error_string)
    new_member = _concretize_abstract_types(abstract_type_spec.member,
                                            concrete_type_spec.member)
    return computation_types.FederatedType(new_member,
                                           abstract_type_spec.placement,
                                           abstract_type_spec.all_equal)
  else:
    raise TypeError(
        'Unexpected abstract typespec {}.'.format(abstract_type_spec))
def transform_type_postorder(
    type_signature: computation_types.Type,
    transform_fn: Callable[[computation_types.Type],
                           Tuple[computation_types.Type, bool]]
) -> Tuple[computation_types.Type, bool]:
  """Walks type tree of `type_signature` postorder, calling `transform_fn`.

  Children are transformed before their parent. If any child of a node
  reports a mutation, the node is rebuilt from the transformed children, and
  then `transform_fn` is applied to the (possibly rebuilt) node itself.

  Args:
    type_signature: Instance of `computation_types.Type` to transform
      recursively.
    transform_fn: Callable applied to each node in the type tree of
      `type_signature`. It receives a `computation_types.Type` and must return
      a tuple `(transformed_type, mutated)`, where `mutated` is a `bool`
      reporting whether it changed the node.

  Returns:
    A tuple `(transformed_type, mutated)`. `transformed_type` is a possibly
    transformed version of `type_signature`, with each node in its tree the
    result of applying `transform_fn` to the corresponding node in
    `type_signature`. `mutated` is `True` if `transform_fn` reported a
    mutation anywhere in the tree.

  Raises:
    TypeError: If the arguments don't match the specification above, or if
      `type_signature` is a kind of type this function does not recognize.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)
  py_typecheck.check_callable(transform_fn)
  if type_signature.is_federated():
    transformed_member, member_mutated = transform_type_postorder(
        type_signature.member, transform_fn)
    if member_mutated:
      type_signature = computation_types.FederatedType(
          transformed_member, type_signature.placement,
          type_signature.all_equal)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or member_mutated
  elif type_signature.is_sequence():
    transformed_element, element_mutated = transform_type_postorder(
        type_signature.element, transform_fn)
    if element_mutated:
      type_signature = computation_types.SequenceType(transformed_element)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or element_mutated
  elif type_signature.is_function():
    if type_signature.parameter is not None:
      transformed_parameter, parameter_mutated = transform_type_postorder(
          type_signature.parameter, transform_fn)
    else:
      transformed_parameter, parameter_mutated = (None, False)
    transformed_result, result_mutated = transform_type_postorder(
        type_signature.result, transform_fn)
    if parameter_mutated or result_mutated:
      type_signature = computation_types.FunctionType(transformed_parameter,
                                                      transformed_result)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, (
        type_signature_mutated or parameter_mutated or result_mutated)
  elif type_signature.is_struct():
    elements = []
    elements_mutated = False
    for name, element_type in structure.iter_elements(type_signature):
      transformed_element, element_mutated = transform_type_postorder(
          element_type, transform_fn)
      elements_mutated = elements_mutated or element_mutated
      elements.append((name, transformed_element))
    if elements_mutated:
      # Preserve the Python container association when rebuilding.
      if type_signature.is_struct_with_python():
        type_signature = computation_types.StructWithPythonType(
            elements, type_signature.python_container)
      else:
        type_signature = computation_types.StructType(elements)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or elements_mutated
  elif (type_signature.is_abstract() or type_signature.is_placement() or
        type_signature.is_tensor()):
    # Leaf types have no children; only the node itself is transformed.
    return transform_fn(type_signature)
  else:
    # Fail loudly instead of implicitly returning `None`, which callers that
    # unpack a `(type, mutated)` pair cannot handle.
    raise TypeError('Unexpected type signature {} of Python type {}.'.format(
        type_signature, type(type_signature)))
def _is_compatible(t: computation_types.Type) -> bool:
  """Returns whether `t` contains only struct, tensor or federated nodes.

  The check is delegated to `type_analysis.contains_only` with a predicate
  that accepts exactly those three kinds of type.
  """

  def _is_allowed_node(node):
    return node.is_struct() or node.is_tensor() or node.is_federated()

  return type_analysis.contains_only(t, _is_allowed_node)