def _concretize_abstract_types(
    abstract_type_spec: computation_types.Type,
    concrete_type_spec: computation_types.Type) -> computation_types.Type:
  """Recursive helper function to construct concrete type spec.

  Walks `abstract_type_spec` and `concrete_type_spec` in parallel. The first
  time an abstract type label is seen, it is bound to the concrete type at the
  same position. Later occurrences of that label resolve to the first binding.
  Bindings live in the enclosing-scope dict `bound_abstract_types`, keyed by
  `str(label)`.

  Args:
    abstract_type_spec: The (possibly abstract) type to concretize.
    concrete_type_spec: The type aligned position-by-position with
      `abstract_type_spec`, supplying concrete replacements.

  Returns:
    A `computation_types.Type` mirroring `abstract_type_spec`, with each
    abstract type replaced by its bound concrete type.

  Raises:
    TypeError: If the two specs disagree in kind, struct length, struct element
      names, or function-parameter presence (message: the enclosing-scope
      `type_error_string`). Also raised if `abstract_type_spec` is of an
      unexpected kind.
  """
  if abstract_type_spec.is_abstract():
    label = str(abstract_type_spec.label)
    bound_type = bound_abstract_types.get(label)
    # `.get` signals "unbound" with `None`. Test for that explicitly rather
    # than truthiness, so a bound type whose truth value is False (e.g. one
    # defining `__len__` that returns 0) is not mistaken for unbound.
    if bound_type is not None:
      return bound_type
    bound_abstract_types[label] = concrete_type_spec
    return concrete_type_spec
  elif abstract_type_spec.is_tensor():
    return abstract_type_spec
  elif abstract_type_spec.is_struct():
    if not concrete_type_spec.is_struct():
      raise TypeError(type_error_string)
    abstract_elements = structure.to_elements(abstract_type_spec)
    concrete_elements = structure.to_elements(concrete_type_spec)
    if len(abstract_elements) != len(concrete_elements):
      raise TypeError(type_error_string)
    concretized_tuple_elements = []
    for (abstract_name, abstract_value), (concrete_name, concrete_value) in zip(
        abstract_elements, concrete_elements):
      # Element names must agree position by position.
      if abstract_name != concrete_name:
        raise TypeError(type_error_string)
      concretized_tuple_elements.append(
          (abstract_name,
           _concretize_abstract_types(abstract_value, concrete_value)))
    return computation_types.StructType(concretized_tuple_elements)
  elif abstract_type_spec.is_sequence():
    if not concrete_type_spec.is_sequence():
      raise TypeError(type_error_string)
    return computation_types.SequenceType(
        _concretize_abstract_types(abstract_type_spec.element,
                                   concrete_type_spec.element))
  elif abstract_type_spec.is_function():
    if not concrete_type_spec.is_function():
      raise TypeError(type_error_string)
    # Parameter presence must match on both sides. Otherwise `None` would be
    # passed down the recursion as if it were a concrete type.
    if (abstract_type_spec.parameter is None) != (
        concrete_type_spec.parameter is None):
      raise TypeError(type_error_string)
    if abstract_type_spec.parameter is None:
      concretized_param = None
    else:
      concretized_param = _concretize_abstract_types(
          abstract_type_spec.parameter, concrete_type_spec.parameter)
    concretized_result = _concretize_abstract_types(abstract_type_spec.result,
                                                    concrete_type_spec.result)
    return computation_types.FunctionType(concretized_param,
                                          concretized_result)
  elif abstract_type_spec.is_placement():
    if not concrete_type_spec.is_placement():
      raise TypeError(type_error_string)
    return abstract_type_spec
  elif abstract_type_spec.is_federated():
    if not concrete_type_spec.is_federated():
      raise TypeError(type_error_string)
    new_member = _concretize_abstract_types(abstract_type_spec.member,
                                            concrete_type_spec.member)
    # Placement and all_equal are taken from the abstract spec.
    return computation_types.FederatedType(new_member,
                                           abstract_type_spec.placement,
                                           abstract_type_spec.all_equal)
  else:
    raise TypeError(
        'Unexpected abstract typespec {}.'.format(abstract_type_spec))
def transform_type_postorder(
    type_signature: computation_types.Type,
    transform_fn: Callable[[computation_types.Type],
                           Tuple[computation_types.Type, bool]]
) -> Tuple[computation_types.Type, bool]:
  """Walks type tree of `type_signature` postorder, calling `transform_fn`.

  Children are transformed before their parent. A compound node (federated,
  sequence, function, struct) is rebuilt from its transformed children only if
  at least one child reported a mutation. `transform_fn` is then applied to the
  (possibly rebuilt) node itself.

  Args:
    type_signature: Instance of `computation_types.Type` to transform
      recursively.
    transform_fn: Callable applied to each node in the type tree of
      `type_signature`. It receives a `computation_types.Type` and must return
      a tuple `(transformed_type, mutated)`, where `mutated` is a `bool`
      reporting whether the node was changed.

  Returns:
    A tuple `(transformed_type, mutated)`. `transformed_type` is a possibly
    transformed version of `type_signature`, with each node in its tree the
    result of applying `transform_fn` to the corresponding node. `mutated` is
    True if `transform_fn` reported a mutation at any node in the tree.

  Raises:
    TypeError: If the arguments don't match the specification above (as
      checked by `py_typecheck`), or if a node in the tree is of a type kind
      this function does not handle.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)
  py_typecheck.check_callable(transform_fn)
  if type_signature.is_federated():
    transformed_member, member_mutated = transform_type_postorder(
        type_signature.member, transform_fn)
    if member_mutated:
      type_signature = computation_types.FederatedType(transformed_member,
                                                       type_signature.placement,
                                                       type_signature.all_equal)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or member_mutated
  elif type_signature.is_sequence():
    transformed_element, element_mutated = transform_type_postorder(
        type_signature.element, transform_fn)
    if element_mutated:
      type_signature = computation_types.SequenceType(transformed_element)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or element_mutated
  elif type_signature.is_function():
    if type_signature.parameter is not None:
      transformed_parameter, parameter_mutated = transform_type_postorder(
          type_signature.parameter, transform_fn)
    else:
      transformed_parameter, parameter_mutated = (None, False)
    transformed_result, result_mutated = transform_type_postorder(
        type_signature.result, transform_fn)
    if parameter_mutated or result_mutated:
      type_signature = computation_types.FunctionType(transformed_parameter,
                                                      transformed_result)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, (
        type_signature_mutated or parameter_mutated or result_mutated)
  elif type_signature.is_struct():
    elements = []
    elements_mutated = False
    for name, element_type in structure.iter_elements(type_signature):
      transformed_element, element_mutated = transform_type_postorder(
          element_type, transform_fn)
      elements_mutated = elements_mutated or element_mutated
      elements.append((name, transformed_element))
    if elements_mutated:
      # Preserve the Python container when rebuilding, if there is one.
      if type_signature.is_struct_with_python():
        type_signature = computation_types.StructWithPythonType(
            elements, type_signature.python_container)
      else:
        type_signature = computation_types.StructType(elements)
    type_signature, type_signature_mutated = transform_fn(type_signature)
    return type_signature, type_signature_mutated or elements_mutated
  elif (type_signature.is_abstract() or type_signature.is_placement() or
        type_signature.is_tensor()):
    # Leaf nodes: no children to recurse into.
    return transform_fn(type_signature)
  else:
    # Fail loudly rather than implicitly returning `None`, which callers would
    # otherwise try to unpack as a `(type, mutated)` pair.
    raise TypeError(
        'Unexpected type signature {} in transform_type_postorder.'.format(
            type_signature))