示例#1
0
    def _check_top_level_compatibility_with_generic_operators(x, y, op_name):
        """Performs a non-recursive check on the types of `x` and `y`.

        First verifies that both operands' type trees contain only federated,
        tuple and tensor types, then compares the two type signatures at the
        top level only (nested structure is not inspected here).

        Args:
          x: First operand; only its `type_signature` attribute is read.
          y: Second operand; only its `type_signature` attribute is read.
          op_name: Name of the generic operator, used in error messages.

        Returns:
          None. The function is called for its validation side effect.

        Raises:
          TypeError: If either type tree contains a type other than federated,
            tuple or tensor; if `x` is federated and `y` is not federated, is
            placed differently, or has a member that does not form an
            upcast-compatible pair with `x`'s member; or if `x` is a tuple that
            is not upcast-compatible with `y` and `y` is either not a tuple or
            has a different `dir()` listing.
        """
        allowed_types = (computation_types.NamedTupleType,
                         computation_types.TensorType,
                         computation_types.FederatedType)
        x_type = x.type_signature
        y_type = y.type_signature
        x_compatible = type_utils.type_tree_contains_only(x_type, allowed_types)
        y_compatible = type_utils.type_tree_contains_only(y_type, allowed_types)

        def _make_bad_type_tree_string(index, type_spec):
            # Message for the case where exactly one argument is unsupported.
            return ('{} is only implemented for pairs of '
                    'arguments both containing only federated, tuple and '
                    'tensor types; you have passed argument at index {} of '
                    'type {}.'.format(op_name, index, type_spec))

        def _top_level_mismatch_error():
            # Built lazily so the type signatures are only formatted on failure.
            return TypeError(
                '{} does not accept arguments of type {} and '
                '{}, as they are mismatched at the top level.'.format(
                    op_name, x_type, y_type))

        if not (x_compatible and y_compatible):
            if y_compatible:
                raise TypeError(_make_bad_type_tree_string(0, x_type))
            if x_compatible:
                raise TypeError(_make_bad_type_tree_string(1, y_type))
            raise TypeError(
                '{} is only implemented for pairs of '
                'arguments both containing only federated, tuple and '
                'tensor types; both your arguments fail this condition. '
                'You have passed first argument of type {} '
                'and second argument of type {}.'.format(
                    op_name, x_type, y_type))

        if isinstance(x_type, computation_types.FederatedType):
            # Federated operands must share a placement, and their members must
            # form an upcast-compatible pair.
            if (not isinstance(y_type, computation_types.FederatedType)
                    or x_type.placement != y_type.placement
                    or not type_utils.is_binary_op_with_upcast_compatible_pair(
                        x_type.member, y_type.member)):
                raise _top_level_mismatch_error()
        elif isinstance(x_type, computation_types.NamedTupleType):
            # An upcast-compatible pair (e.g. tuple op scalar) is accepted
            # as-is. Otherwise `y` must also be a tuple whose `dir()` matches;
            # NOTE(review): `dir` on these tuple types is presumably the list of
            # element names — confirm in `anonymous_tuple`.
            if not type_utils.is_binary_op_with_upcast_compatible_pair(
                    x_type, y_type):
                if (not isinstance(y_type, computation_types.NamedTupleType)
                        or dir(x_type) != dir(y_type)):
                    raise _top_level_mismatch_error()
        # When `x` is a tensor, no top-level check is performed here.
        return None
示例#2
0
def _check_generic_operator_type(type_spec):
  """Validates `type_spec` as the argument signature of a generic operator.

  Three conditions are checked in order, and the first failure is reported:
    1. the type tree contains only federated, tuple and tensor types;
    2. `type_spec` is a tuple type with exactly two elements;
    3. the two elements form an upcast-compatible pair according to
       `type_utils.is_binary_op_with_upcast_compatible_pair`.

  Args:
    type_spec: The candidate argument type signature.

  Raises:
    TypeError: If any of the conditions above does not hold.
  """
  permitted_types = (computation_types.FederatedType,
                     computation_types.NamedTupleType,
                     computation_types.TensorType)
  if not type_utils.type_tree_contains_only(type_spec, permitted_types):
    raise TypeError(
        'Generic operators are only implemented for '
        'arguments both containing only federated, tuple and '
        'tensor types; you have passed an argument of type {} '.format(
            type_spec))

  is_pair = (isinstance(type_spec, computation_types.NamedTupleType) and
             len(type_spec) == 2)
  if not is_pair:
    raise TypeError(
        'We are trying to construct a generic operator declaring argument that '
        'is not a two-tuple, the type {}.'.format(type_spec))

  left_operand, right_operand = type_spec[0], type_spec[1]
  if type_utils.is_binary_op_with_upcast_compatible_pair(
      left_operand, right_operand):
    return
  raise TypeError('The two-tuple you have passed in is incompatible with '
                  'upcasted binary operators. You have passed the tuple '
                  'type {}, which fails the check that the two members of '
                  'the tuple are either the same type, or the second is a '
                  'scalar with the same dtype as the leaves of the first. '
                  'See `type_utils.is_binary_op_with_upcast_compatible_pair` '
                  'for more details.'.format(type_spec))
示例#3
0
def construct_generic_constant(type_spec, scalar_value):
  """Creates constant for a combination of federated, tuple and tensor types.

  Args:
    type_spec: Instance of `computation_types.Type` containing only federated,
      tuple or tensor types for which we wish to construct a generic constant.
      May also be something convertible to a `computation_types.Type` via
      `computation_types.to_type`.
    scalar_value: The scalar value we wish this constant to have.

  Returns:
    Instance of `computation_building_blocks.ComputationBuildingBlock`
    representing `scalar_value` packed into `type_spec`.

  Raises:
    TypeError: If `scalar_value` does not infer to a scalar `TensorType`, if
      `type_spec` contains anything other than federated, tuple or tensor
      types, or if a federated type in `type_spec` is placed somewhere other
      than `CLIENTS` or `SERVER`. Notice validation of consistency of
      `type_spec` with `scalar_value` is not the responsibility of this
      function.
    ValueError: If `type_spec` falls through every supported case (not
      expected once the checks above have passed).
  """
  type_spec = computation_types.to_type(type_spec)
  py_typecheck.check_type(type_spec, computation_types.Type)
  inferred_scalar_value_type = type_utils.infer_type(scalar_value)
  if (not isinstance(inferred_scalar_value_type, computation_types.TensorType)
      or inferred_scalar_value_type.shape != tf.TensorShape(())):
    raise TypeError('Must pass a scalar value to '
                    '`construct_generic_constant`; encountered a value '
                    '{}'.format(scalar_value))
  if not type_utils.type_tree_contains_only(
      type_spec,
      (computation_types.FederatedType, computation_types.NamedTupleType,
       computation_types.TensorType)):
    raise TypeError(
        '`construct_generic_constant` only supports types containing only '
        'federated, tuple and tensor types; encountered type {}.'.format(
            type_spec))
  if type_utils.type_tree_contains_only(
      type_spec,
      (computation_types.NamedTupleType, computation_types.TensorType)):
    # No federated types anywhere: a single TensorFlow constant suffices.
    return computation_constructing_utils.construct_tensorflow_constant(
        type_spec, scalar_value)
  elif isinstance(type_spec, computation_types.FederatedType):
    unplaced_value = (
        computation_constructing_utils.construct_tensorflow_constant(
            type_spec.member, scalar_value))
    if type_spec.placement == placement_literals.CLIENTS:
      placement_uri = intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS.uri
    elif type_spec.placement == placement_literals.SERVER:
      placement_uri = intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri
    else:
      # Only CLIENTS and SERVER have a placing intrinsic here; fail with a
      # clear message instead of an unbound local further down.
      raise TypeError(
          '`construct_generic_constant` only supports federated types placed '
          'at CLIENTS or SERVER; encountered placement {} in type {}.'.format(
              type_spec.placement, type_spec))
    # Both placements share the same signature: the member type mapped to an
    # all-equal federated value of that member at the target placement.
    placement_fn_type = computation_types.FunctionType(
        type_spec.member,
        computation_types.FederatedType(
            type_spec.member, type_spec.placement, all_equal=True))
    placement_function = computation_building_blocks.Intrinsic(
        placement_uri, placement_fn_type)
    return computation_building_blocks.Call(placement_function, unplaced_value)
  elif isinstance(type_spec, computation_types.NamedTupleType):
    # Reaching here means some element contains a federated type, so every
    # element is built recursively and the element names are re-attached.
    named_element_types = list(anonymous_tuple.to_elements(type_spec))
    elements = [
        construct_generic_constant(element_type, scalar_value)
        for _, element_type in named_element_types
    ]
    names = [name for name, _ in named_element_types]
    packed_elements = computation_building_blocks.Tuple(elements)
    return computation_constructing_utils.create_named_tuple(
        packed_elements, names)
  else:
    raise ValueError(
        'The type_spec {} has slipped through all our '
        'generic constant cases, and failed to raise.'.format(type_spec))
示例#4
0
 def _only_tuple_or_tensor(value):
     """Reports whether `value.type_signature` holds only tuple/tensor types.

     Args:
       value: Any object exposing a `type_signature` attribute.

     Returns:
       The result of `type_utils.type_tree_contains_only` restricted to
       `NamedTupleType` and `TensorType`.
     """
     permitted_types = (computation_types.NamedTupleType,
                        computation_types.TensorType)
     signature = value.type_signature
     return type_utils.type_tree_contains_only(signature, permitted_types)