def _concretize_abstract_types(
    abstract_type_spec: computation_types.Type,
    concrete_type_spec: computation_types.Type) -> computation_types.Type:
  """Recursive helper function to construct concrete type spec.

  Walks `abstract_type_spec` and `concrete_type_spec` in lockstep. The first
  time an abstract label is seen, it is bound in the outer-scope
  `bound_abstract_types` mapping to the matching part of
  `concrete_type_spec`. Every later occurrence of the same label is replaced
  by that first binding.

  Args:
    abstract_type_spec: The (possibly) abstract type to concretize.
    concrete_type_spec: A type expected to have the same structure as
      `abstract_type_spec`, supplying the concrete types for abstract labels.

  Returns:
    A `computation_types.Type` mirroring `abstract_type_spec` with its
    abstract types replaced. Tensor and placement nodes are taken from
    `abstract_type_spec` unchanged. Struct nodes are rebuilt as plain
    `computation_types.StructType`.

  Raises:
    TypeError: If the two specs differ in kind, struct length, element names,
      or presence of a function parameter, or if `abstract_type_spec` is of an
      unexpected kind.
  """
  if abstract_type_spec.is_abstract():
    label = str(abstract_type_spec.label)
    bound_type = bound_abstract_types.get(label)
    # Test for `None` explicitly (as the sibling checker does). A bound type
    # that happens to be falsy must not be silently rebound.
    if bound_type is not None:
      return bound_type
    bound_abstract_types[label] = concrete_type_spec
    return concrete_type_spec
  elif abstract_type_spec.is_tensor():
    # The concrete counterpart is not inspected for tensors.
    return abstract_type_spec
  elif abstract_type_spec.is_struct():
    if not concrete_type_spec.is_struct():
      raise TypeError(type_error_string)
    abstract_elements = structure.to_elements(abstract_type_spec)
    concrete_elements = structure.to_elements(concrete_type_spec)
    if len(abstract_elements) != len(concrete_elements):
      raise TypeError(type_error_string)
    concretized_tuple_elements = []
    for (abstract_name, abstract_element), (concrete_name, concrete_element) in zip(
        abstract_elements, concrete_elements):
      if abstract_name != concrete_name:
        raise TypeError(type_error_string)
      concretized_tuple_elements.append(
          (abstract_name,
           _concretize_abstract_types(abstract_element, concrete_element)))
    return computation_types.StructType(concretized_tuple_elements)
  elif abstract_type_spec.is_sequence():
    if not concrete_type_spec.is_sequence():
      raise TypeError(type_error_string)
    return computation_types.SequenceType(
        _concretize_abstract_types(abstract_type_spec.element,
                                   concrete_type_spec.element))
  elif abstract_type_spec.is_function():
    if not concrete_type_spec.is_function():
      raise TypeError(type_error_string)
    abstract_param = abstract_type_spec.parameter
    concrete_param = concrete_type_spec.parameter
    # Exactly one side has a parameter: the structures differ. This must be
    # raised, not returned, and must cover both directions. Otherwise the
    # recursion below would receive `None` as a concrete type.
    if (abstract_param is None) != (concrete_param is None):
      raise TypeError(type_error_string)
    if abstract_param is None:
      concretized_param = None
    else:
      concretized_param = _concretize_abstract_types(abstract_param,
                                                     concrete_param)
    concretized_result = _concretize_abstract_types(abstract_type_spec.result,
                                                    concrete_type_spec.result)
    return computation_types.FunctionType(concretized_param,
                                          concretized_result)
  elif abstract_type_spec.is_placement():
    if not concrete_type_spec.is_placement():
      raise TypeError(type_error_string)
    return abstract_type_spec
  elif abstract_type_spec.is_federated():
    if not concrete_type_spec.is_federated():
      raise TypeError(type_error_string)
    new_member = _concretize_abstract_types(abstract_type_spec.member,
                                            concrete_type_spec.member)
    return computation_types.FederatedType(new_member,
                                           abstract_type_spec.placement,
                                           abstract_type_spec.all_equal)
  else:
    raise TypeError(
        'Unexpected abstract typespec {}.'.format(abstract_type_spec))
def _check_helper(generic_type_member: computation_types.Type,
                  concrete_type_member: computation_types.Type,
                  defining: bool):
  """Recursive helper function.

  Structurally compares `generic_type_member` against `concrete_type_member`.
  Abstract-type occurrences are handled as follows:
  - When `defining` is True, they are recorded in the outer-scope
    `type_bindings`, and any repeated label is checked for equivalence.
  - Otherwise, the concrete member is appended to the outer-scope
    `non_defining_usages`.

  Args:
    generic_type_member: Presumably a sub-type of the outer-scope
      `generic_type` (both are passed to the raised errors).
    concrete_type_member: The corresponding sub-type of the outer-scope
      `concrete_type`.
    defining: Whether abstract labels found here bind a concrete type. The
      flag is inverted when descending into a function parameter.

  Raises:
    MismatchedStructureError: If the two members differ in kind, tensor type,
      placement, struct length or element names, presence of a function
      parameter, or federated placement / `all_equal`.
    MismatchedConcreteTypesError: If one label is bound to two non-equivalent
      concrete types.
    TypeError: If `generic_type_member` is of an unexpected kind.
  """

  def _raise_structural(mismatch):
    raise MismatchedStructureError(concrete_type, generic_type,
                                   concrete_type_member, generic_type_member,
                                   mismatch)

  def _both_are(predicate):
    # True if both members satisfy `predicate`, False if the generic one does
    # not, and a structural error if only the generic one does.
    if predicate(generic_type_member):
      if predicate(concrete_type_member):
        return True
      else:
        _raise_structural('kind')
    else:
      return False

  if generic_type_member.is_abstract():
    label = str(generic_type_member.label)
    if not defining:
      # NOTE(review): assumes `non_defining_usages` is a defaultdict(list);
      # confirm in the enclosing function.
      non_defining_usages[label].append(concrete_type_member)
    else:
      bound_type = type_bindings.get(label)
      if bound_type is not None:
        if not concrete_type_member.is_equivalent_to(bound_type):
          raise MismatchedConcreteTypesError(concrete_type, generic_type,
                                             label, bound_type,
                                             concrete_type_member)
      else:
        type_bindings[label] = concrete_type_member
  elif _both_are(lambda t: t.is_tensor()):
    if generic_type_member != concrete_type_member:
      _raise_structural('tensor types')
  elif _both_are(lambda t: t.is_placement()):
    if generic_type_member != concrete_type_member:
      _raise_structural('placements')
  elif _both_are(lambda t: t.is_struct()):
    generic_elements = structure.to_elements(generic_type_member)
    concrete_elements = structure.to_elements(concrete_type_member)
    if len(generic_elements) != len(concrete_elements):
      _raise_structural('length')
    for (generic_name, generic_element), (concrete_name, concrete_element) in zip(
        generic_elements, concrete_elements):
      if generic_name != concrete_name:
        _raise_structural('element names')
      _check_helper(generic_element, concrete_element, defining)
  elif _both_are(lambda t: t.is_sequence()):
    _check_helper(generic_type_member.element, concrete_type_member.element,
                  defining)
  elif _both_are(lambda t: t.is_function()):
    generic_param = generic_type_member.parameter
    concrete_param = concrete_type_member.parameter
    # Require a parameter on both sides or on neither. Otherwise the recursion
    # would receive `None` as a type.
    if (generic_param is None) != (concrete_param is None):
      _raise_structural('parameter')
    if generic_param is not None:
      _check_helper(generic_param, concrete_param, not defining)
    _check_helper(generic_type_member.result, concrete_type_member.result,
                  defining)
  elif _both_are(lambda t: t.is_federated()):
    if generic_type_member.placement != concrete_type_member.placement:
      _raise_structural('placement')
    if generic_type_member.all_equal != concrete_type_member.all_equal:
      _raise_structural('all equal')
    _check_helper(generic_type_member.member, concrete_type_member.member,
                  defining)
  else:
    raise TypeError(f'Unexpected type kind {generic_type}.')
def transform_type_postorder(
    type_signature: computation_types.Type,
    transform_fn: Callable[[computation_types.Type],
                           Tuple[computation_types.Type, bool]]
) -> Tuple[computation_types.Type, bool]:
  """Walks type tree of `type_signature` postorder, calling `transform_fn`.

  Children are transformed before their parent. If any child reports a
  mutation, the parent node is rebuilt from the transformed children before
  `transform_fn` is applied to it.

  Args:
    type_signature: Instance of `computation_types.Type` to transform
      recursively.
    transform_fn: Callable applied to each node in the type tree of
      `type_signature`. It takes a `computation_types.Type` and returns a
      tuple `(transformed_type, mutated)`, where `mutated` is truthy if
      `transformed_type` differs from the input.

  Returns:
    A tuple `(transformed_type, mutated)`:
    - `transformed_type` is a possibly transformed version of
      `type_signature`, with each node in its tree the result of applying
      `transform_fn` to the corresponding node in `type_signature`.
    - `mutated` is truthy if `transform_fn` reported a mutation anywhere in
      the tree.

  Raises:
    TypeError: If the arguments don't match the specification above, or if a
      node of an unrecognized type kind is encountered.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)
  py_typecheck.check_callable(transform_fn)
  if type_signature.is_federated():
    transformed_member, children_mutated = transform_type_postorder(
        type_signature.member, transform_fn)
    if children_mutated:
      type_signature = computation_types.FederatedType(
          transformed_member, type_signature.placement,
          type_signature.all_equal)
  elif type_signature.is_sequence():
    transformed_element, children_mutated = transform_type_postorder(
        type_signature.element, transform_fn)
    if children_mutated:
      type_signature = computation_types.SequenceType(transformed_element)
  elif type_signature.is_function():
    if type_signature.parameter is not None:
      transformed_parameter, parameter_mutated = transform_type_postorder(
          type_signature.parameter, transform_fn)
    else:
      transformed_parameter, parameter_mutated = None, False
    transformed_result, result_mutated = transform_type_postorder(
        type_signature.result, transform_fn)
    children_mutated = parameter_mutated or result_mutated
    if children_mutated:
      type_signature = computation_types.FunctionType(transformed_parameter,
                                                      transformed_result)
  elif type_signature.is_struct():
    elements = []
    children_mutated = False
    for name, element_type in structure.iter_elements(type_signature):
      transformed_element, element_mutated = transform_type_postorder(
          element_type, transform_fn)
      children_mutated = children_mutated or element_mutated
      elements.append((name, transformed_element))
    if children_mutated:
      # Preserve the Python container when the struct carries one.
      if type_signature.is_struct_with_python():
        type_signature = computation_types.StructWithPythonType(
            elements, type_signature.python_container)
      else:
        type_signature = computation_types.StructType(elements)
  elif (type_signature.is_abstract() or type_signature.is_placement() or
        type_signature.is_tensor()):
    # Leaf nodes have no children; `transform_fn`'s result is returned as-is.
    return transform_fn(type_signature)
  else:
    # Without this, an unknown kind would fall through and return `None`,
    # which callers cannot unpack.
    raise TypeError(f'Unexpected type kind {type_signature}.')
  type_signature, type_signature_mutated = transform_fn(type_signature)
  return type_signature, type_signature_mutated or children_mutated