def construct_map_or_apply(func, arg):
  """Builds the intrinsic that allows applying `func` to federated `arg`.

  The `federated_apply` URI is used when `arg` is placed at the server and the
  `federated_map` URI when it is placed at the clients.

  Args:
    func: `computation_building_blocks.ComputationBuildingBlock` whose
      `type_signature` exposes a `result`; the function to be called on `arg`.
    arg: `computation_building_blocks.ComputationBuildingBlock` of federated
      type on which `func` is to be called.

  Returns:
    A `computation_building_blocks.Intrinsic` of functional type
    `(<func_type, arg_type> -> result_type)`, where `result_type` is a
    federated type built from the result of `func` with the same placement and
    `all_equal` bit as `arg`.

  Raises:
    TypeError: If `arg` is placed anywhere other than the server or the
      clients.
  """
  py_typecheck.check_type(func,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(arg,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(arg.type_signature, computation_types.FederatedType)
  placement = arg.type_signature.placement
  result_type = computation_types.FederatedType(func.type_signature.result,
                                                placement,
                                                arg.type_signature.all_equal)
  if placement == placement_literals.SERVER:
    uri = intrinsic_defs.FEDERATED_APPLY.uri
  elif placement == placement_literals.CLIENTS:
    uri = intrinsic_defs.FEDERATED_MAP.uri
  else:
    # Without this branch the function would fail later with an
    # UnboundLocalError that hides the real cause.
    raise TypeError('Unsupported placement {}.'.format(placement))
  intrinsic_type = computation_types.FunctionType(
      [func.type_signature, arg.type_signature], result_type)
  return computation_building_blocks.Intrinsic(uri, intrinsic_type)
Ejemplo n.º 2
0
def construct_map_or_apply(fn, arg):
  """Builds a call applying `fn` to the federated `arg`.

  For a server-placed `arg` a call to the `federated_apply` intrinsic is
  constructed here; for a clients-placed `arg` the work is delegated to
  `create_federated_map`.

  Args:
    fn: An instance of `computation_building_blocks.ComputationBuildingBlock` of
      functional type to be called on `arg`.
    arg: `computation_building_blocks.ComputationBuildingBlock` instance of
      federated type. The `member` of its `type_signature` is checked with
      `type_utils.check_assignable_from` against the `parameter` of the
      `type_signature` of `fn`.

  Returns:
    For a server-placed `arg`, a `computation_building_blocks.Call` of the
    `federated_apply` intrinsic on the tuple `<fn, arg>`; for a clients-placed
    `arg`, whatever `create_federated_map(fn, arg)` returns.

  Raises:
    TypeError: If `arg` is placed anywhere other than the server or the
      clients.
  """
  py_typecheck.check_type(fn,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(arg,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
  py_typecheck.check_type(arg.type_signature, computation_types.FederatedType)
  type_utils.check_assignable_from(fn.type_signature.parameter,
                                   arg.type_signature.member)
  placement = arg.type_signature.placement
  if placement == placement_literals.SERVER:
    result_type = computation_types.FederatedType(fn.type_signature.result,
                                                  placement,
                                                  arg.type_signature.all_equal)
    intrinsic_type = computation_types.FunctionType(
        [fn.type_signature, arg.type_signature], result_type)
    intrinsic = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_APPLY.uri, intrinsic_type)
    tup = computation_building_blocks.Tuple((fn, arg))
    return computation_building_blocks.Call(intrinsic, tup)
  elif placement == placement_literals.CLIENTS:
    return create_federated_map(fn, arg)
  else:
    # Previously an unsupported placement fell through and returned `None`.
    raise TypeError('Unsupported placement {}.'.format(placement))
Ejemplo n.º 3
0
 def test_propogates_dependence_up_through_lambda(self):
   """Checks that a lambda whose result is the intrinsic is extracted."""
   intrinsic = computation_building_blocks.Intrinsic('dummy_intrinsic',
                                                     tf.int32)
   wrapping_lambda = computation_building_blocks.Lambda(
       'x', tf.int32, intrinsic)
   consumers = tree_analysis.extract_nodes_consuming(
       wrapping_lambda, dummy_intrinsic_predicate)
   self.assertIn(wrapping_lambda, consumers)
Ejemplo n.º 4
0
 def test_propogates_dependence_up_through_selection(self):
   """Checks that a selection from the intrinsic is extracted."""
   intrinsic = computation_building_blocks.Intrinsic('dummy_intrinsic',
                                                     [tf.int32])
   first_element = computation_building_blocks.Selection(intrinsic, index=0)
   consumers = tree_analysis.extract_nodes_consuming(
       first_element, dummy_intrinsic_predicate)
   self.assertIn(first_element, consumers)
 def test_intrinsic_class_succeeds_simple_federated_map(self):
      """Builds a concrete `federated_map` intrinsic and inspects it.

      Verifies the instance type, the string form of the type signature, the
      URI, the compact representation and the proto contents, then runs the
      serialize/deserialize round-trip helper.
      """
      fn_type = computation_types.FunctionType(tf.int32, tf.float32)
      arg_type = computation_types.FederatedType(fn_type.parameter,
                                                 placements.CLIENTS)
      result_type = computation_types.FederatedType(fn_type.result,
                                                    placements.CLIENTS)
      map_type = computation_types.FunctionType([fn_type, arg_type],
                                                result_type)
      federated_map = computation_building_blocks.Intrinsic(
          intrinsic_defs.FEDERATED_MAP.uri, map_type)

      self.assertIsInstance(federated_map,
                            computation_building_blocks.Intrinsic)
      signature_string = str(federated_map.type_signature)
      self.assertEqual(
          signature_string,
          '(<(int32 -> float32),{int32}@CLIENTS> -> {float32}@CLIENTS)')
      self.assertEqual(federated_map.uri, 'federated_map')
      self.assertEqual(federated_map.compact_representation(),
                       'federated_map')

      proto = federated_map.proto
      self.assertEqual(type_serialization.deserialize_type(proto.type),
                       federated_map.type_signature)
      self.assertEqual(proto.WhichOneof('computation'), 'intrinsic')
      self.assertEqual(proto.intrinsic.uri, federated_map.uri)
      self._serialize_deserialize_roundtrip_test(federated_map)
Ejemplo n.º 6
0
def _create_call_to_federated_map(fn, arg):
    r"""Builds a call of the federated map intrinsic on `<fn, arg>`.

              Call
             /    \
    Intrinsic      Tuple
                  /     \
                fn       arg

    The intrinsic is given the functional type
    `(<fn_type, arg_type> -> result_type)`, where `result_type` is the result
    type of `fn` placed at `placements.CLIENTS`.

    Args:
      fn: A `computation_building_blocks.ComputationBuildingBlock` to use as
        the map function; its `type_signature` must expose a `result`.
      arg: A `computation_building_blocks.ComputationBuildingBlock` to use as
        the map argument.

    Returns:
      A `computation_building_blocks.Call` of the federated map intrinsic on
      the tuple `<fn, arg>`.
    """
    for building_block in (fn, arg):
        py_typecheck.check_type(
            building_block,
            computation_building_blocks.ComputationBuildingBlock)
    result_type = computation_types.FederatedType(fn.type_signature.result,
                                                  placements.CLIENTS)
    map_type = computation_types.FunctionType(
        [fn.type_signature, arg.type_signature], result_type)
    federated_map = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_MAP.uri, map_type)
    return computation_building_blocks.Call(
        federated_map, computation_building_blocks.Tuple((fn, arg)))
Ejemplo n.º 7
0
def create_sequence_map(fn, arg):
    r"""Builds a call of the sequence map intrinsic on `<fn, arg>`.

              Call
             /    \
    Intrinsic      Tuple
                   |
                   [Comp, Comp]

    The result type of the intrinsic is a sequence of the result type of
    `fn`.

    Args:
      fn: A `computation_building_blocks.ComputationBuildingBlock` to use as
        the function.
      arg: A `computation_building_blocks.ComputationBuildingBlock` to use as
        the argument.

    Returns:
      A `computation_building_blocks.Call`.

    Raises:
      TypeError: If any of the types do not match.
    """
    for building_block in (fn, arg):
        py_typecheck.check_type(
            building_block,
            computation_building_blocks.ComputationBuildingBlock)
    mapped_type = computation_types.SequenceType(fn.type_signature.result)
    sequence_map_type = computation_types.FunctionType(
        (fn.type_signature, arg.type_signature), mapped_type)
    sequence_map = computation_building_blocks.Intrinsic(
        intrinsic_defs.SEQUENCE_MAP.uri, sequence_map_type)
    return computation_building_blocks.Call(
        sequence_map, computation_building_blocks.Tuple((fn, arg)))
Ejemplo n.º 8
0
  def federated_sum(self, value):
    """Implements `federated_sum` as defined in `api/intrinsics.py`.

    The argument is converted to a TFF value, required to be placed at the
    clients and to have a sum-compatible type, and then passed to a
    `federated_sum` intrinsic whose result is an all-equal, server-placed
    value of the same member type.

    Args:
      value: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`.
    """
    value = value_impl.to_value(value, None, self._context_stack)
    type_utils.check_federated_value_placement(value, placements.CLIENTS,
                                               'value to be summed')

    argument_type = value.type_signature
    if not type_utils.is_sum_compatible(argument_type):
      message = 'The value type {} is not compatible with the sum operator.'
      raise TypeError(message.format(str(value.type_signature)))

    # TODO(b/113112108): Replace this as noted above.
    summed_type = computation_types.FederatedType(value.type_signature.member,
                                                  placements.SERVER, True)
    sum_type = computation_types.FunctionType(value.type_signature,
                                              summed_type)
    sum_building_block = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_SUM.uri, sum_type)
    sum_intrinsic = value_impl.ValueImpl(sum_building_block,
                                         self._context_stack)
    return sum_intrinsic(value)
Ejemplo n.º 9
0
    def sequence_reduce(self, value, zero, op):
        """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

        A plain sequence is reduced through
        `computation_constructing_utils.create_sequence_reduce`. A federated
        sequence is reduced by wrapping a `sequence_reduce` call in a lambda
        and passing that lambda to `federated_apply` (server placement) or
        `federated_map` (clients placement).

        Args:
          value: As in `api/intrinsics.py`.
          zero: As in `api/intrinsics.py`.
          op: As in `api/intrinsics.py`.

        Returns:
          As in `api/intrinsics.py`.

        Raises:
          TypeError: As in `api/intrinsics.py`.
        """
        value = value_impl.to_value(value, None, self._context_stack)
        zero = value_impl.to_value(zero, None, self._context_stack)
        op = value_impl.to_value(op, None, self._context_stack)

        # Determine the element type of the (possibly federated) sequence.
        if isinstance(value.type_signature, computation_types.SequenceType):
            element_type = value.type_signature.element
        else:
            py_typecheck.check_type(value.type_signature,
                                    computation_types.FederatedType)
            py_typecheck.check_type(value.type_signature.member,
                                    computation_types.SequenceType)
            element_type = value.type_signature.member.element

        expected_op_type = type_constructors.reduction_op(
            zero.type_signature, element_type)
        op_is_acceptable = type_utils.is_assignable_from(
            expected_op_type, op.type_signature)
        if not op_is_acceptable:
            raise TypeError('Expected an operator of type {}, got {}.'.format(
                expected_op_type, op.type_signature))

        value_comp = value_impl.ValueImpl.get_comp(value)
        zero_comp = value_impl.ValueImpl.get_comp(zero)
        op_comp = value_impl.ValueImpl.get_comp(op)
        if isinstance(value_comp.type_signature,
                      computation_types.SequenceType):
            return computation_constructing_utils.create_sequence_reduce(
                value_comp, zero_comp, op_comp)

        # Federated case: build `arg -> sequence_reduce(<arg, zero, op>)`.
        sequence_type = computation_types.SequenceType(element_type)
        reduce_type = computation_types.FunctionType(
            (sequence_type, zero_comp.type_signature, op_comp.type_signature),
            op_comp.type_signature.result)
        reduce_intrinsic = computation_building_blocks.Intrinsic(
            intrinsic_defs.SEQUENCE_REDUCE.uri, reduce_type)
        arg_ref = computation_building_blocks.Reference('arg', sequence_type)
        reduce_call = computation_building_blocks.Call(
            reduce_intrinsic,
            computation_building_blocks.Tuple((arg_ref, zero_comp, op_comp)))
        reduce_lambda = computation_building_blocks.Lambda(
            arg_ref.name, arg_ref.type_signature, reduce_call)
        mapping_fn = value_impl.ValueImpl(reduce_lambda, self._context_stack)

        if value_comp.type_signature.placement is placements.SERVER:
            return self.federated_apply(mapping_fn, value_comp)
        if value_comp.type_signature.placement is placements.CLIENTS:
            return self.federated_map(mapping_fn, value_comp)
        raise TypeError('Unsupported placement {}.'.format(
            value_comp.type_signature.placement))
Ejemplo n.º 10
0
  def sequence_reduce(self, value, zero, op):
    """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

    A plain sequence is reduced by calling the `sequence_reduce` intrinsic
    directly. A federated sequence is reduced by wrapping that intrinsic call
    in a lambda and passing the lambda to `federated_apply` (server placement)
    or `federated_map` (clients placement).

    Args:
      value: As in `api/intrinsics.py`.
      zero: As in `api/intrinsics.py`.
      op: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`.
    """
    value = value_impl.to_value(value, None, self._context_stack)
    zero = value_impl.to_value(zero, None, self._context_stack)
    op = value_impl.to_value(op, None, self._context_stack)

    # Determine the element type of the (possibly federated) sequence.
    if isinstance(value.type_signature, computation_types.SequenceType):
      element_type = value.type_signature.element
    else:
      py_typecheck.check_type(value.type_signature,
                              computation_types.FederatedType)
      py_typecheck.check_type(value.type_signature.member,
                              computation_types.SequenceType)
      element_type = value.type_signature.member.element

    expected_op_type = type_constructors.reduction_op(zero.type_signature,
                                                      element_type)
    op_is_acceptable = type_utils.is_assignable_from(expected_op_type,
                                                     op.type_signature)
    if not op_is_acceptable:
      raise TypeError('Expected an operator of type {}, got {}.'.format(
          str(expected_op_type), str(op.type_signature)))

    reduce_type = computation_types.FunctionType([
        computation_types.SequenceType(element_type),
        zero.type_signature,
        op.type_signature,
    ], zero.type_signature)
    reduce_building_block = computation_building_blocks.Intrinsic(
        intrinsic_defs.SEQUENCE_REDUCE.uri, reduce_type)

    if isinstance(value.type_signature, computation_types.SequenceType):
      reduce_intrinsic = value_impl.ValueImpl(reduce_building_block,
                                              self._context_stack)
      return reduce_intrinsic(value, zero, op)

    # Federated case: build `arg -> sequence_reduce(<arg, zero, op>)`.
    parameter_type = computation_types.SequenceType(element_type)
    arg_ref = computation_building_blocks.Reference(
        'arg', computation_types.SequenceType(element_type))
    reduce_arguments = computation_building_blocks.Tuple([
        arg_ref,
        value_impl.ValueImpl.get_comp(zero),
        value_impl.ValueImpl.get_comp(op),
    ])
    reduce_call = computation_building_blocks.Call(reduce_building_block,
                                                   reduce_arguments)
    mapping_fn_building_block = computation_building_blocks.Lambda(
        'arg', parameter_type, reduce_call)
    mapping_fn = value_impl.ValueImpl(mapping_fn_building_block,
                                      self._context_stack)

    if value.type_signature.placement is placements.SERVER:
      return self.federated_apply(mapping_fn, value)
    if value.type_signature.placement is placements.CLIENTS:
      return self.federated_map(mapping_fn, value)
    raise TypeError('Unsupported placement {}.'.format(
        str(value.type_signature.placement)))
Ejemplo n.º 11
0
  def federated_broadcast(self, value):
    """Implements `federated_broadcast` as defined in `api/intrinsics.py`.

    The argument is converted to a TFF value, required to be an all-equal
    value placed at the server, and then passed to a `federated_broadcast`
    intrinsic whose result is an all-equal, clients-placed value of the same
    member type.

    Args:
      value: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`.
    """
    value = value_impl.to_value(value, None, self._context_stack)
    type_utils.check_federated_value_placement(value, placements.SERVER,
                                               'value to be broadcasted')

    is_all_equal = value.type_signature.all_equal
    if not is_all_equal:
      raise TypeError('The broadcasted value should be equal at all locations.')

    # TODO(b/113112108): Replace this hand-crafted logic here and below with
    # a call to a helper function that handles it in a uniform manner after
    # implementing support for correctly typechecking federated template types
    # and instantiating template types on concrete arguments.
    broadcast_result_type = computation_types.FederatedType(
        value.type_signature.member, placements.CLIENTS, True)
    broadcast_type = computation_types.FunctionType(value.type_signature,
                                                    broadcast_result_type)
    broadcast_building_block = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_BROADCAST.uri, broadcast_type)
    broadcast_intrinsic = value_impl.ValueImpl(broadcast_building_block,
                                               self._context_stack)
    return broadcast_intrinsic(value)
Ejemplo n.º 12
0
def _create_lambda_to_add_one(dtype):
    r"""Builds a lambda that adds `1` to its argument.

    Lambda
          \
           Call
          /    \
    Intrinsic   Tuple
               /     \
      Reference       Computation

    The addition is a call of the `generic_plus` intrinsic on the tuple of the
    lambda parameter and a computation obtained from `_create_call_to_py_fn`
    for the constant `1` cast to `dtype`.

    Args:
      dtype: The type of the argument, either a `tf.dtypes.DType` or a
        `computation_types.TensorType`, from which the `dtype` is taken.

    Returns:
      A `computation_building_blocks.Lambda` wrapping a function that adds 1
      to an argument.
    """
    scalar_dtype = (
        dtype.dtype
        if isinstance(dtype, computation_types.TensorType) else dtype)
    py_typecheck.check_type(scalar_dtype, tf.dtypes.DType)
    plus_type = computation_types.FunctionType([scalar_dtype, scalar_dtype],
                                               scalar_dtype)
    plus = computation_building_blocks.Intrinsic(
        intrinsic_defs.GENERIC_PLUS.uri, plus_type)
    arg_ref = computation_building_blocks.Reference('arg', scalar_dtype)
    one = _create_call_to_py_fn(
        lambda: tf.cast(tf.constant(1), scalar_dtype))
    summed = computation_building_blocks.Call(
        plus, computation_building_blocks.Tuple([arg_ref, one]))
    return computation_building_blocks.Lambda(arg_ref.name,
                                              arg_ref.type_signature, summed)
Ejemplo n.º 13
0
def create_federated_map(fn, arg):
    r"""Builds a call of the federated map intrinsic on `<fn, arg>`.

              Call
             /    \
    Intrinsic      Tuple
                   |
                   [Comp, Comp]

    The result type of the intrinsic is the result type of `fn` placed at the
    clients with the `all_equal` bit set to `False`.

    Args:
      fn: A `computation_building_blocks.ComputationBuildingBlock` to use as
        the function.
      arg: A `computation_building_blocks.ComputationBuildingBlock` to use as
        the argument.

    Returns:
      A `computation_building_blocks.Call`.

    Raises:
      TypeError: If any of the types do not match.
    """
    for building_block in (fn, arg):
        py_typecheck.check_type(
            building_block,
            computation_building_blocks.ComputationBuildingBlock)
    mapped_type = computation_types.FederatedType(
        fn.type_signature.result, placement_literals.CLIENTS, False)
    federated_map_type = computation_types.FunctionType(
        (fn.type_signature, arg.type_signature), mapped_type)
    federated_map = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_MAP.uri, federated_map_type)
    return computation_building_blocks.Call(
        federated_map, computation_building_blocks.Tuple((fn, arg)))
Ejemplo n.º 14
0
def create_federated_sum(value):
    r"""Builds a call of the federated sum intrinsic on `value`.

              Call
             /    \
    Intrinsic      Comp

    The result type of the intrinsic is the member type of `value` placed at
    the server with the `all_equal` bit set to `True`.

    Args:
      value: A `computation_building_blocks.ComputationBuildingBlock` to use
        as the value.

    Returns:
      A `computation_building_blocks.Call`.

    Raises:
      TypeError: If any of the types do not match.
    """
    py_typecheck.check_type(
        value, computation_building_blocks.ComputationBuildingBlock)
    argument_type = value.type_signature
    summed_type = computation_types.FederatedType(
        argument_type.member, placement_literals.SERVER, True)
    federated_sum = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_SUM.uri,
        computation_types.FunctionType(value.type_signature, summed_type))
    return computation_building_blocks.Call(federated_sum, value)
Ejemplo n.º 15
0
def _create_lambda_to_dummy_intrinsic(type_spec, uri='dummy'):
    r"""Builds a lambda that calls an intrinsic named `uri` on its parameter.

    Lambda
          \
           Call
          /    \
    Intrinsic   Ref(arg)

    The intrinsic is typed as a function from `type_spec` to `type_spec`.

    Args:
      type_spec: The type of the argument.
      uri: The URI of the intrinsic.

    Returns:
      A `computation_building_blocks.Lambda`.

    Raises:
      TypeError: If `type_spec` is not a `tf.dtypes.DType`.
    """
    py_typecheck.check_type(type_spec, tf.dtypes.DType)
    identity_type = computation_types.FunctionType(type_spec, type_spec)
    dummy = computation_building_blocks.Intrinsic(uri, identity_type)
    parameter = computation_building_blocks.Reference('arg', type_spec)
    return computation_building_blocks.Lambda(
        parameter.name, parameter.type_signature,
        computation_building_blocks.Call(dummy, parameter))
Ejemplo n.º 16
0
def create_federated_value(value, placement):
    r"""Builds a call of a federated value intrinsic on `value`.

              Call
             /    \
    Intrinsic      Comp

    The intrinsic URI depends on `placement`; the result type is the type of
    `value` at `placement` with the `all_equal` bit set to `True`.

    Args:
      value: A `computation_building_blocks.ComputationBuildingBlock` to use
        as the value.
      placement: A `placement_literals.PlacementLiteral` to use as the
        placement.

    Returns:
      A `computation_building_blocks.Call`.

    Raises:
      TypeError: If any of the types do not match, or if `placement` is
        neither the clients nor the server placement.
    """
    py_typecheck.check_type(
        value, computation_building_blocks.ComputationBuildingBlock)
    if placement is placement_literals.SERVER:
        intrinsic_uri = intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri
    elif placement is placement_literals.CLIENTS:
        intrinsic_uri = intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS.uri
    else:
        raise TypeError('Unsupported placement {}.'.format(placement))
    federated_type = computation_types.FederatedType(value.type_signature,
                                                     placement, True)
    federated_value = computation_building_blocks.Intrinsic(
        intrinsic_uri,
        computation_types.FunctionType(value.type_signature, federated_type))
    return computation_building_blocks.Call(federated_value, value)
Ejemplo n.º 17
0
  def federated_value(self, value, placement):
    """Implements `federated_value` as defined in `api/intrinsics.py`.

    The argument is converted to a TFF value and passed to the federated value
    intrinsic matching `placement`; the result is an all-equal federated value
    at that placement.

    Args:
      value: As in `api/intrinsics.py`.
      placement: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`.
    """
    value = value_impl.to_value(value, None, self._context_stack)

    # TODO(b/113112108): Verify that neither the value, nor any of its parts
    # are of a federated type.

    if placement is placements.SERVER:
      intrinsic_uri = intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri
    elif placement is placements.CLIENTS:
      intrinsic_uri = intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS.uri
    else:
      raise TypeError('The placement must be either CLIENTS or SERVER.')

    federated_type = computation_types.FederatedType(value.type_signature,
                                                     placement, True)
    federated_value_type = computation_types.FunctionType(
        value.type_signature, federated_type)
    building_block = computation_building_blocks.Intrinsic(
        intrinsic_uri, federated_value_type)
    federated_value_intrinsic = value_impl.ValueImpl(building_block,
                                                     self._context_stack)
    return federated_value_intrinsic(value)
Ejemplo n.º 18
0
  def federated_collect(self, value):
    """Implements `federated_collect` as defined in `api/intrinsics.py`.

    The argument is converted to a TFF value, required to be placed at the
    clients, and passed to a `federated_collect` intrinsic whose result is an
    all-equal, server-placed sequence of the member type.

    Args:
      value: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`.
    """
    value = value_impl.to_value(value, None, self._context_stack)
    type_utils.check_federated_value_placement(value, placements.CLIENTS,
                                               'value to be collected')

    collected_member_type = computation_types.SequenceType(
        value.type_signature.member)
    collected_type = computation_types.FederatedType(collected_member_type,
                                                     placements.SERVER, True)
    collect_type = computation_types.FunctionType(value.type_signature,
                                                  collected_type)
    collect_building_block = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_COLLECT.uri, collect_type)
    collect_intrinsic = value_impl.ValueImpl(collect_building_block,
                                             self._context_stack)
    return collect_intrinsic(value)
def _create_zip_two_values(value):
    r"""Creates a called federated zip with two values.

              Call
             /    \
    Intrinsic      Tuple
                   |
                   [Comp1, Comp2]

    Notice that this function will drop any names associated to the two-tuple
    it is processing. This is necessary due to the type signature of the
    underlying federated zip intrinsic, `<T@P,U@P>-><T,U>@P`. Keeping names
    here would violate this type signature. The names are cached at a higher
    level than this function, and appended to the resulting tuple in a single
    call to `federated_map` or `federated_apply` before the resulting
    structure is sent back to the caller.

    The placement is read from the first element of `value`. Clients placement
    selects the zip-at-clients intrinsic with `all_equal=False`; server
    placement selects the zip-at-server intrinsic with `all_equal=True`.

    Args:
      value: A `computation_building_blocks.ComputationBuildingBlock` with a
        `type_signature` of type `computation_types.NamedTupleType` containing
        exactly two elements.

    Returns:
      A `computation_building_blocks.Call`.

    Raises:
      TypeError: If any of the types do not match, or if the placement is
        neither the clients nor the server placement.
      ValueError: If `value` does not contain exactly two elements.
    """
    py_typecheck.check_type(
        value, computation_building_blocks.ComputationBuildingBlock)
    named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
    length = len(named_type_signatures)
    if length != 2:
        # Report the element count; the message previously interpolated the
        # whole element list into "received {} elements".
        raise ValueError(
            'Expected a value with exactly two elements, received {} elements.'
            .format(length))
    placement = value[0].type_signature.placement
    if placement is placement_literals.CLIENTS:
        uri = intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS.uri
        all_equal = False
    elif placement is placement_literals.SERVER:
        uri = intrinsic_defs.FEDERATED_ZIP_AT_SERVER.uri
        all_equal = True
    else:
        raise TypeError('Unsupported placement {}.'.format(placement))
    # Names are dropped on purpose (see the docstring): every element of the
    # parameter and result types is unnamed.
    elements = []
    for _, type_signature in named_type_signatures:
        federated_type = computation_types.FederatedType(
            type_signature.member, placement, all_equal)
        elements.append((None, federated_type))
    parameter_type = computation_types.NamedTupleType(elements)
    result_type = computation_types.FederatedType(
        [(None, e.member) for _, e in named_type_signatures], placement,
        all_equal)
    intrinsic_type = computation_types.FunctionType(parameter_type,
                                                    result_type)
    intrinsic = computation_building_blocks.Intrinsic(uri, intrinsic_type)
    return computation_building_blocks.Call(intrinsic, value)
Ejemplo n.º 20
0
def zip_two_tuple(input_val, context_stack):
    """Zips a 2-tuple of federated values into a federated 2-tuple.

    The placement is read from the first element and every element is checked
    against it. Clients placement yields a result with the `all_equal` bit
    unset, server placement a result with it set; the parameter type of the
    zip intrinsic carries the same bit on each of its elements.

    Args:
      input_val: 2-tuple TFF `Value` of `NamedTuple` type, whose elements must
        be `FederatedTypes` with the same placement.
      context_stack: The context stack to use, as in
        `impl.value_impl.to_value`.

    Returns:
      TFF `Value` of `FederatedType` with member of 2-tuple `NamedTuple` type.

    Raises:
      TypeError: If the first element is placed neither at the server nor at
        the clients.
      ValueError: If `input_val` does not have exactly two elements.
    """
    py_typecheck.check_type(input_val, value_base.Value)
    py_typecheck.check_type(input_val.type_signature,
                            computation_types.NamedTupleType)
    py_typecheck.check_type(input_val[0].type_signature,
                            computation_types.FederatedType)

    # Maps each supported placement to (zip intrinsic URI, all_equal bit).
    zip_settings = {
        placements.CLIENTS:
            (intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS.uri, False),
        placements.SERVER:
            (intrinsic_defs.FEDERATED_ZIP_AT_SERVER.uri, True),
    }
    placement = input_val[0].type_signature.placement
    if placement not in zip_settings:
        raise TypeError(
            'The argument must have components placed at SERVER or CLIENTS')
    zip_uri, all_equal = zip_settings[placement]

    for component in input_val:
        type_utils.check_federated_value_placement(component, placement)

    named_types = anonymous_tuple.to_elements(input_val.type_signature)
    if len(named_types) != 2:
        raise ValueError(
            'The argument of zip_two_tuple must be a 2-tuple, not an {}-tuple'
            .format(len(named_types)))

    zipped_type = computation_types.FederatedType(
        [(name, member_type.member) for name, member_type in named_types],
        placement, all_equal)

    def _with_all_equal(federated_type):
        return computation_types.FederatedType(federated_type.member,
                                               federated_type.placement,
                                               all_equal)

    # Named elements stay (name, type) pairs; unnamed ones are bare types.
    parameter_elements = []
    for name, element_type in named_types:
        adjusted = _with_all_equal(element_type)
        parameter_elements.append((name, adjusted) if name else adjusted)
    parameter_type = computation_types.NamedTupleType(parameter_elements)

    zip_building_block = computation_building_blocks.Intrinsic(
        zip_uri, computation_types.FunctionType(parameter_type, zipped_type))
    zip_intrinsic = value_impl.ValueImpl(zip_building_block, context_stack)
    return zip_intrinsic(input_val)
Ejemplo n.º 21
0
 def test_passes_with_federated_map(self):
   """Runs the whitelist check on a `federated_map` intrinsic.

   The test passes as long as the check does not raise.
   """
   federated_map_uri = intrinsic_defs.FEDERATED_MAP.uri
   fn_type = computation_types.FunctionType(tf.int32, tf.float32)
   arg_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
   result_type = computation_types.FederatedType(tf.float32,
                                                 placements.CLIENTS)
   federated_map = computation_building_blocks.Intrinsic(
       federated_map_uri,
       computation_types.FunctionType([fn_type, arg_type], result_type))
   tree_analysis.check_intrinsics_whitelisted_for_reduction(federated_map)
 def test_returns_string_for_intrinsic(self):
      """Checks that every representation of an intrinsic is its URI."""
      comp = computation_building_blocks.Intrinsic('intrinsic', tf.int32)
      representation_methods = (
          'compact_representation',
          'formatted_representation',
          'structural_representation',
      )
      # Each representation is produced and asserted before the next one is
      # requested.
      for method_name in representation_methods:
          rendered = getattr(comp, method_name)()
          self.assertEqual(rendered, 'intrinsic')
Ejemplo n.º 23
0
    def _transform(comp):
        """Returns a `(computation, modified)` pair for `comp`.

        When `_should_transform(comp)` is false, `comp` is returned unchanged
        with `False`. Otherwise the two functions found at
        `comp.argument[1].argument[0]` and `comp.argument[0]` are chained into
        a single block, and a new call of an intrinsic with the URI of
        `comp.function` is built on the tuple of that block and
        `comp.argument[1].argument[1]`; it is returned with `True`.
        """
        if not _should_transform(comp):
            return comp, False

        def _create_block_to_chained_calls(comps):
            r"""Builds a block that applies the functions in `comps` in order.

                       Block
                      /     \
            [fn=Tuple]       Lambda(arg)
                |                       \
        [Comp(y), Comp(x)]               Call
                                        /    \
                                  Sel(1)      Call
                                 /           /    \
                          Ref(fn)      Sel(0)      Ref(arg)
                                      /
                               Ref(fn)

            (let fn=<y, x> in (arg -> fn[1](fn[0](arg)))

            Args:
              comps: a Python list of computations.

            Returns:
              A `computation_building_blocks.Block`.
            """
            fn_tuple = computation_building_blocks.Tuple(comps)
            fn_ref = computation_building_blocks.Reference(
                'fn', fn_tuple.type_signature)
            parameter_type = comps[0].type_signature.parameter
            arg_ref = computation_building_blocks.Reference(
                'arg', parameter_type)
            # Start from the lambda parameter and wrap it in one call per
            # function, innermost first.
            chained = arg_ref
            for index in range(len(comps)):
                selected_fn = computation_building_blocks.Selection(
                    fn_ref, index=index)
                chained = computation_building_blocks.Call(
                    selected_fn, chained)
            chained_lambda = computation_building_blocks.Lambda(
                arg_ref.name, arg_ref.type_signature, chained)
            return computation_building_blocks.Block([('fn', fn_tuple)],
                                                     chained_lambda)

        inner_fn = comp.argument[1].argument[0]
        outer_fn = comp.argument[0]
        chained_block = _create_block_to_chained_calls((inner_fn, outer_fn))
        new_argument = computation_building_blocks.Tuple(
            [chained_block, comp.argument[1].argument[1]])
        new_intrinsic_type = computation_types.FunctionType(
            new_argument.type_signature, comp.function.type_signature.result)
        new_intrinsic = computation_building_blocks.Intrinsic(
            comp.function.uri, new_intrinsic_type)
        return computation_building_blocks.Call(new_intrinsic,
                                                new_argument), True
Ejemplo n.º 24
0
 def test_propogates_dependence_up_through_block_locals(self):
   """Checks a block binding the dummy intrinsic as a local is reported."""
   intrinsic = computation_building_blocks.Intrinsic(
       'dummy_intrinsic', tf.int32)
   int_ref = computation_building_blocks.Reference('int', tf.int32)
   # The intrinsic appears only in the block's locals; the result is an
   # unrelated reference.
   block_with_local = computation_building_blocks.Block(
       [('x', intrinsic)], int_ref)
   consumers = tree_analysis.extract_nodes_consuming(
       block_with_local, dummy_intrinsic_predicate)
   self.assertIn(block_with_local, consumers)
Ejemplo n.º 25
0
 def test_propogates_dependence_up_through_tuple(self):
   """Checks a tuple holding the dummy intrinsic is reported as a consumer."""
   intrinsic = computation_building_blocks.Intrinsic(
       'dummy_intrinsic', tf.int32)
   int_ref = computation_building_blocks.Reference('int', tf.int32)
   # The intrinsic is the second element, next to an unrelated reference.
   pair = computation_building_blocks.Tuple([int_ref, intrinsic])
   consumers = tree_analysis.extract_nodes_consuming(
       pair, dummy_intrinsic_predicate)
   self.assertIn(pair, consumers)
Ejemplo n.º 26
0
    def test_raises_with_federated_mean(self):
        """Checks that a `FEDERATED_MEAN` intrinsic fails the whitelist check."""
        clients_int = computation_types.FederatedType(
            tf.int32, placements.CLIENTS)
        server_int = computation_types.FederatedType(
            tf.int32, placements.SERVER)
        mean_intrinsic = computation_building_blocks.Intrinsic(
            intrinsic_defs.FEDERATED_MEAN.uri,
            computation_types.FunctionType(clients_int, server_int))

        # The raised message must match the intrinsic's compact representation.
        expected_pattern = mean_intrinsic.compact_representation()
        with self.assertRaisesRegex(ValueError, expected_pattern):
            tree_analysis.check_intrinsics_whitelisted_for_reduction(
                mean_intrinsic)
Ejemplo n.º 27
0
  def federated_map(self, mapping_fn, value):
    """Implements `federated_map` as defined in `api/intrinsics.py`.

    Both arguments are first converted with `value_impl.to_value`. A `value`
    whose type is a named tuple with two or more elements is passed through
    `federated_zip` before being mapped. The value must then be placed at the
    clients, and the parameter type of `mapping_fn` must be assignable from
    the member type of the value.

    Args:
      mapping_fn: As in `api/intrinsics.py`.
      value: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`.
    """

    # TODO(b/113112108): Possibly lift the restriction that the mapped value
    # must be placed at the clients after adding support for placement labels
    # in the federated types, and expanding the type specification of the
    # intrinsic this is based on to work with federated values of arbitrary
    # placement.

    fed_value = value_impl.to_value(value, None, self._context_stack)
    raw_type = fed_value.type_signature
    # A tuple of two or more elements means the caller expects it to be zipped.
    needs_zip = (
        isinstance(raw_type, computation_types.NamedTupleType) and
        len(anonymous_tuple.to_elements(raw_type)) >= 2)
    if needs_zip:
      fed_value = self.federated_zip(fed_value)
    type_utils.check_federated_value_placement(fed_value, placements.CLIENTS,
                                               'value to be mapped')

    # TODO(b/113112108): Add support for polymorphic templates auto-instantiated
    # here based on the actual type of the argument.
    fn_value = value_impl.to_value(mapping_fn, None, self._context_stack)

    py_typecheck.check_type(fn_value, value_base.Value)
    fn_type = fn_value.type_signature
    py_typecheck.check_type(fn_type, computation_types.FunctionType)

    # Re-read the type here: zipping above may have replaced the value.
    value_type = fed_value.type_signature
    if not type_utils.is_assignable_from(fn_type.parameter, value_type.member):
      raise TypeError(
          'The mapping function expects a parameter of type {}, but member '
          'constituents of the mapped value are of incompatible type {}.'
          .format(str(fn_type.parameter), str(value_type.member)))

    # TODO(b/113112108): Replace this as noted above.
    mapped_type = computation_types.FederatedType(
        fn_type.result, placements.CLIENTS, value_type.all_equal)
    map_signature = computation_types.FunctionType([fn_type, value_type],
                                                   mapped_type)
    map_intrinsic = value_impl.ValueImpl(
        computation_building_blocks.Intrinsic(intrinsic_defs.FEDERATED_MAP.uri,
                                              map_signature),
        self._context_stack)
    return map_intrinsic(fn_value, fed_value)
Ejemplo n.º 28
0
def create_federated_mean(value, weight):
    r"""Creates a called federated mean, weighted when `weight` is given.

              Call
             /    \
    Intrinsic      Tuple
                   |
                   [Comp, Comp]

    The tree above shows the weighted form. When `weight` is `None`, the
    `FEDERATED_MEAN` intrinsic is instead called directly on `value`, with no
    enclosing tuple.

    Args:
      value: A `computation_building_blocks.ComputationBuildingBlock` to use
        as the value.
      weight: A `computation_building_blocks.ComputationBuildingBlock` to use
        as the weight or `None`.

    Returns:
      A `computation_building_blocks.Call`.

    Raises:
      TypeError: If any of the types do not match.
    """
    block_cls = computation_building_blocks.ComputationBuildingBlock
    weighted = weight is not None

    py_typecheck.check_type(value, block_cls)
    if weighted:
        py_typecheck.check_type(weight, block_cls)

    # The mean is always an all-equal value placed at the server.
    mean_type = computation_types.FederatedType(
        value.type_signature.member, placement_literals.SERVER, True)

    if not weighted:
        unweighted_signature = computation_types.FunctionType(
            value.type_signature, mean_type)
        unweighted_mean = computation_building_blocks.Intrinsic(
            intrinsic_defs.FEDERATED_MEAN.uri, unweighted_signature)
        return computation_building_blocks.Call(unweighted_mean, value)

    weighted_signature = computation_types.FunctionType(
        (value.type_signature, weight.type_signature), mean_type)
    weighted_mean = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_WEIGHTED_MEAN.uri, weighted_signature)
    operands = computation_building_blocks.Tuple((value, weight))
    return computation_building_blocks.Call(weighted_mean, operands)
Ejemplo n.º 29
0
 def test_propogates_dependence_up_through_call(self):
     """Checks a call taking the dummy intrinsic as argument is reported."""
     intrinsic = computation_building_blocks.Intrinsic(
         'dummy_intrinsic', tf.int32)
     # Identity function `x -> x`, applied to the intrinsic.
     identity = computation_building_blocks.Lambda(
         'x', tf.int32,
         computation_building_blocks.Reference('x', tf.int32))
     applied = computation_building_blocks.Call(identity, intrinsic)
     consumers = tree_analysis.extract_nodes_consuming(
         applied, dummy_intrinsic_predicate)
     self.assertIn(applied, consumers)
Ejemplo n.º 30
0
 def _make_sequence_sum_for(type_spec):
   """Wraps a `SEQUENCE_SUM` intrinsic typed for the sequence `type_spec`.

   Args:
     type_spec: A `computation_types.SequenceType` whose element type must be
       compatible with the sum operator.

   Returns:
     A `value_impl.ValueImpl` around an intrinsic of functional type
     `type_spec -> type_spec.element`, bound to the enclosing scope's
     `self._context_stack`.

   Raises:
     TypeError: If `type_spec` is not a sequence type, or its element type is
       not sum-compatible.
   """
   py_typecheck.check_type(type_spec, computation_types.SequenceType)
   element_type = type_spec.element
   if not type_utils.is_sum_compatible(element_type):
     raise TypeError(
         'The value type {} is not compatible with the sum operator.'.format(
             str(type_spec)))
   sum_uri = intrinsic_defs.SEQUENCE_SUM.uri
   sum_signature = computation_types.FunctionType(type_spec, element_type)
   sum_intrinsic = computation_building_blocks.Intrinsic(sum_uri, sum_signature)
   return value_impl.ValueImpl(sum_intrinsic, self._context_stack)