示例#1
0
  def federated_map_all_equal(self, fn, arg):
    """Implements `federated_map` for an argument with the `all_equal` bit set.

    Behaves like `federated_map` as defined in `api/intrinsics.py`, except that
    the mapped argument is expected to carry the `all_equal` bit.

    Args:
      fn: The mapping function, as in `api/intrinsics.py`.
      arg: The value to be mapped, as in `api/intrinsics.py`, with the
        `all_equal` bit set. A named tuple with two or more elements is first
        combined via `self.federated_zip`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`; in particular when the parameter
        type of `fn` is not assignable from the member type of `arg`.
    """
    # TODO(b/113112108): Possibly lift the restriction that the mapped value
    # must be placed at the clients after adding support for placement labels
    # in the federated types, and expanding the type specification of the
    # intrinsic this is based on to work with federated values of arbitrary
    # placement.
    # NOTE(review): this method does not itself check the `all_equal` bit of
    # `arg`; presumably `create_federated_map_all_equal` enforces it — confirm.
    value = value_impl.to_value(arg, None, self._context_stack)
    value_type = value.type_signature
    if (isinstance(value_type, computation_types.NamedTupleType) and
        len(anonymous_tuple.to_elements(value_type)) >= 2):
      # Several values passed together are interpreted as a request to zip.
      value = self.federated_zip(value)
    value_utils.check_federated_value_placement(value, placements.CLIENTS,
                                                'value to be mapped')

    # TODO(b/113112108): Add support for polymorphic templates auto-instantiated
    # here based on the actual type of the argument.
    mapping_fn = value_impl.to_value(fn, None, self._context_stack)
    py_typecheck.check_type(mapping_fn, value_base.Value)
    py_typecheck.check_type(mapping_fn.type_signature,
                            computation_types.FunctionType)

    expected_type = mapping_fn.type_signature.parameter
    member_type = value.type_signature.member
    if not type_utils.is_assignable_from(expected_type, member_type):
      raise TypeError(
          'The mapping function expects a parameter of type {}, but member '
          'constituents of the mapped value are of incompatible type {}.'
          .format(expected_type, member_type))

    fn_comp = value_impl.ValueImpl.get_comp(mapping_fn)
    arg_comp = value_impl.ValueImpl.get_comp(value)
    return value_impl.ValueImpl(
        computation_constructing_utils.create_federated_map_all_equal(
            fn_comp, arg_comp), self._context_stack)
def create_dummy_called_federated_map_all_equal(parameter_name,
                                                parameter_type=tf.int32):
  r"""Builds a called `federated_map_all_equal` over an identity function.

  The resulting building block has this shape:

                          Call
                         /    \
  federated_map_all_equal      Tuple
                               |
                               [Lambda(x), data]
                                |
                                Ref(x)

  The mapped value is a `Data` block named 'data' whose type is
  `parameter_type` placed at `placements.CLIENTS` with `all_equal=True`.

  Args:
    parameter_name: The name of the identity function's parameter.
    parameter_type: The type of the identity function's parameter, also used
      as the member type of the mapped data. Defaults to `tf.int32`.

  Returns:
    The building block produced by
    `computation_constructing_utils.create_federated_map_all_equal`.
  """
  return computation_constructing_utils.create_federated_map_all_equal(
      create_identity_function(parameter_name, parameter_type),
      computation_building_blocks.Data(
          'data',
          computation_types.FederatedType(
              parameter_type, placements.CLIENTS, all_equal=True)))
示例#3
0
 def test_converts_federated_map_all_equal_to_federated_map(self):
     """Checks that `normalize_all_equal_bit` rewrites the map intrinsic.

     A `federated_map_all_equal` call on an all-equal `int32@CLIENTS` reference
     should be rewritten into a `federated_map` call whose type signature equals
     `FederatedType(tf.int32, placements.CLIENTS)` built without `all_equal`.
     """
     fed_type_all_equal = computation_types.FederatedType(
         tf.int32, placements.CLIENTS, all_equal=True)
     normalized_fed_type = computation_types.FederatedType(
         tf.int32, placements.CLIENTS)
     int_ref = building_blocks.Reference('x', tf.int32)
     int_identity = building_blocks.Lambda('x', tf.int32, int_ref)
     federated_int_ref = building_blocks.Reference('y', fed_type_all_equal)
     called_federated_map_all_equal = (
         computation_constructing_utils.create_federated_map_all_equal(
             int_identity, federated_int_ref))
     # Precondition on the test input. This is asserted before normalizing so
     # that a failure here points at test setup, not at the code under test.
     self.assertEqual(called_federated_map_all_equal.function.uri,
                      intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri)

     normalized_federated_map = (
         mapreduce_transformations.normalize_all_equal_bit(
             called_federated_map_all_equal))

     self.assertIsInstance(normalized_federated_map, building_blocks.Call)
     self.assertIsInstance(normalized_federated_map.function,
                           building_blocks.Intrinsic)
     self.assertEqual(normalized_federated_map.function.uri,
                      intrinsic_defs.FEDERATED_MAP.uri)
     self.assertEqual(normalized_federated_map.type_signature,
                      normalized_fed_type)