Example #1
0
def reconcile_value_type_with_type_spec(
    value_type: computation_types.Type,
    type_spec: Optional[computation_types.Type]) -> computation_types.Type:
  """Reconciles a pair of types.

  Args:
    value_type: An instance of `tff.Type`.
    type_spec: An instance of `tff.Type`, or `None`.

  Returns:
    Either `value_type` if `type_spec` is `None`, or `type_spec` if `type_spec`
    is not `None` and equivalent to `value_type`.

  Raises:
    TypeError: If arguments are of incompatible types, i.e. when `type_spec`
      is given and `value_type.is_equivalent_to(type_spec)` is false.
      Argument validation is delegated to `py_typecheck.check_type`.
  """
  py_typecheck.check_type(value_type, computation_types.Type)
  if type_spec is None:
    return value_type
  # Validate `type_spec` itself before using it, so a non-`tff.Type` argument
  # is rejected by the type check rather than reaching `is_equivalent_to`.
  py_typecheck.check_type(type_spec, computation_types.Type)
  if not value_type.is_equivalent_to(type_spec):
    raise TypeError('Expected a value of type {}, found {}.'.format(
        type_spec, value_type))
  return type_spec
  def _check_helper(generic_type_member: computation_types.Type,
                    concrete_type_member: computation_types.Type,
                    defining: bool):
    """Recursively matches `concrete_type_member` against `generic_type_member`.

    Walks both types in lockstep. An abstract member of the generic type is
    either bound to the concrete member at the same position (defining
    positions) or recorded as a usage of its label (non-defining positions).
    Every other kind of member must agree structurally between the two types.

    Args:
      generic_type_member: The part of `generic_type` currently examined.
      concrete_type_member: The part of `concrete_type` at the same position.
      defining: Whether the current position is a defining one. The flag is
        inverted each time the walk descends into a function parameter.

    Raises:
      MismatchedStructureError: If the two members differ in kind, length,
        element names, placement, `all_equal`, or presence of a function
        parameter.
      MismatchedConcreteTypesError: If a label already bound in a defining
        position is met again with a non-equivalent concrete type.
      TypeError: If `generic_type_member` is of an unrecognized kind.
    """

    def _raise_structural(mismatch):
      raise MismatchedStructureError(concrete_type, generic_type,
                                     concrete_type_member, generic_type_member,
                                     mismatch)

    def _both_are(predicate):
      # Returns True when both members satisfy `predicate`, and False when the
      # generic member does not. If only the generic member matches, this is a
      # kind mismatch and raises.
      if predicate(generic_type_member):
        if predicate(concrete_type_member):
          return True
        _raise_structural('kind')
      return False

    if generic_type_member.is_abstract():
      label = str(generic_type_member.label)
      if not defining:
        # Only recorded here; presumably the enclosing function checks these
        # usages against `type_bindings` after the walk — confirm.
        non_defining_usages[label].append(concrete_type_member)
      else:
        bound_type = type_bindings.get(label)
        if bound_type is None:
          type_bindings[label] = concrete_type_member
        elif not concrete_type_member.is_equivalent_to(bound_type):
          raise MismatchedConcreteTypesError(concrete_type, generic_type,
                                             label, bound_type,
                                             concrete_type_member)
    elif _both_are(lambda t: t.is_tensor()):
      if generic_type_member != concrete_type_member:
        _raise_structural('tensor types')
    elif _both_are(lambda t: t.is_placement()):
      if generic_type_member != concrete_type_member:
        _raise_structural('placements')
    elif _both_are(lambda t: t.is_struct()):
      generic_elements = structure.to_elements(generic_type_member)
      concrete_elements = structure.to_elements(concrete_type_member)
      if len(generic_elements) != len(concrete_elements):
        _raise_structural('length')
      # Lengths are equal here, so `zip` pairs every element.
      for generic_element, concrete_element in zip(generic_elements,
                                                   concrete_elements):
        # Each element is a (name, type) pair.
        if generic_element[0] != concrete_element[0]:
          _raise_structural('element names')
        _check_helper(generic_element[1], concrete_element[1], defining)
    elif _both_are(lambda t: t.is_sequence()):
      _check_helper(generic_type_member.element, concrete_type_member.element,
                    defining)
    elif _both_are(lambda t: t.is_function()):
      generic_parameter = generic_type_member.parameter
      concrete_parameter = concrete_type_member.parameter
      # Both functions must agree on whether they take a parameter at all.
      # Without this check, a missing concrete parameter would recurse with
      # `None` and fail with an AttributeError, not a structural mismatch.
      if (generic_parameter is None) != (concrete_parameter is None):
        _raise_structural('parameter')
      if generic_parameter is not None:
        # The defining/non-defining role is inverted for parameters.
        _check_helper(generic_parameter, concrete_parameter, not defining)
      _check_helper(generic_type_member.result, concrete_type_member.result,
                    defining)
    elif _both_are(lambda t: t.is_federated()):
      if generic_type_member.placement != concrete_type_member.placement:
        _raise_structural('placement')
      if generic_type_member.all_equal != concrete_type_member.all_equal:
        _raise_structural('all equal')
      _check_helper(generic_type_member.member, concrete_type_member.member,
                    defining)
    else:
      # Report the member whose kind is unrecognized, plus the full type for
      # context.
      raise TypeError(f'Unexpected type kind {generic_type_member} in '
                      f'{generic_type}.')