def reconcile_value_type_with_type_spec(
    value_type: computation_types.Type,
    type_spec: Optional[computation_types.Type]) -> computation_types.Type:
  """Reconciles a pair of types.

  Args:
    value_type: An instance of `tff.Type`.
    type_spec: An instance of `tff.Type`, or `None`.

  Returns:
    Either `value_type` if `type_spec` is `None`, or `type_spec` if `type_spec`
    is not `None` and equivalent with `value_type`.

  Raises:
    TypeError: If arguments are of incompatible types.
  """
  py_typecheck.check_type(value_type, computation_types.Type)
  if type_spec is None:
    return value_type
  # Validate `type_spec` itself before comparing it against `value_type`.
  py_typecheck.check_type(type_spec, computation_types.Type)
  if not value_type.is_equivalent_to(type_spec):
    raise TypeError('Expected a value of type {}, found {}.'.format(
        type_spec, value_type))
  return type_spec
def _check_helper(generic_type_member: computation_types.Type,
                  concrete_type_member: computation_types.Type,
                  defining: bool):
  """Recursively matches `concrete_type_member` against `generic_type_member`.

  This is a closure: it reads `concrete_type`, `generic_type`, `type_bindings`
  and `non_defining_usages` from the enclosing scope. `non_defining_usages` is
  indexed by label and appended to, so it is presumably a
  `collections.defaultdict(list)` (verify against the enclosing function).

  Args:
    generic_type_member: A (sub)type of the generic type being matched.
    concrete_type_member: The corresponding (sub)type of the concrete type.
    defining: If `True`, an abstract type at this position binds its label in
      `type_bindings` (or is checked for equivalence with an existing binding).
      If `False`, the concrete member is only recorded in
      `non_defining_usages`. The flag is flipped when descending into a
      function parameter.

  Raises:
    MismatchedStructureError: If the two members differ in kind or structure.
    MismatchedConcreteTypesError: If an abstract label, in a defining position,
      is matched to a concrete type not equivalent to its existing binding.
    TypeError: If `generic_type_member` is of an unexpected kind.
  """

  def _raise_structural(mismatch):
    raise MismatchedStructureError(concrete_type, generic_type,
                                   concrete_type_member, generic_type_member,
                                   mismatch)

  def _both_are(predicate):
    # False if the generic member does not satisfy `predicate`; a structural
    # 'kind' error if only the generic member does; True if both do.
    if not predicate(generic_type_member):
      return False
    if not predicate(concrete_type_member):
      _raise_structural('kind')
    return True

  if generic_type_member.is_abstract():
    label = str(generic_type_member.label)
    if not defining:
      non_defining_usages[label].append(concrete_type_member)
    else:
      bound_type = type_bindings.get(label)
      if bound_type is None:
        type_bindings[label] = concrete_type_member
      elif not concrete_type_member.is_equivalent_to(bound_type):
        raise MismatchedConcreteTypesError(concrete_type, generic_type, label,
                                           bound_type, concrete_type_member)
  elif _both_are(lambda t: t.is_tensor()):
    if generic_type_member != concrete_type_member:
      _raise_structural('tensor types')
  elif _both_are(lambda t: t.is_placement()):
    if generic_type_member != concrete_type_member:
      _raise_structural('placements')
  elif _both_are(lambda t: t.is_struct()):
    generic_elements = structure.to_elements(generic_type_member)
    concrete_elements = structure.to_elements(concrete_type_member)
    if len(generic_elements) != len(concrete_elements):
      _raise_structural('length')
    # Elements are (name, type) pairs; names and types must match pairwise.
    for (generic_name, generic_element), (concrete_name, concrete_element) in zip(
        generic_elements, concrete_elements):
      if generic_name != concrete_name:
        _raise_structural('element names')
      _check_helper(generic_element, concrete_element, defining)
  elif _both_are(lambda t: t.is_sequence()):
    _check_helper(generic_type_member.element, concrete_type_member.element,
                  defining)
  elif _both_are(lambda t: t.is_function()):
    generic_parameter = generic_type_member.parameter
    concrete_parameter = concrete_type_member.parameter
    # A parameter present on only one side is a structural mismatch. Without
    # this symmetric check, a missing concrete parameter would be recursed
    # into as `None` (crashing on `None.is_tensor()`, or recording `None` as
    # an abstract-type usage).
    if (generic_parameter is None) != (concrete_parameter is None):
      _raise_structural('parameter')
    if generic_parameter is not None:
      # Parameter positions are checked with the opposite `defining` flag.
      _check_helper(generic_parameter, concrete_parameter, not defining)
    _check_helper(generic_type_member.result, concrete_type_member.result,
                  defining)
  elif _both_are(lambda t: t.is_federated()):
    if generic_type_member.placement != concrete_type_member.placement:
      _raise_structural('placement')
    if generic_type_member.all_equal != concrete_type_member.all_equal:
      _raise_structural('all equal')
    _check_helper(generic_type_member.member, concrete_type_member.member,
                  defining)
  else:
    raise TypeError(
        f'Unexpected type kind {generic_type_member} in {generic_type}.')