def federated_map_all_equal(self, fn, arg):
    """Maps `fn` over the members of a `CLIENTS` value with `all_equal` set.

    Builds the computation with
    `building_block_factory.create_federated_map_all_equal`.

    Args:
      fn: The mapping function, or anything `value_impl.to_value` can turn
        into a value whose type signature is a
        `computation_types.FunctionType`.
      arg: The value to be mapped, or anything `value_impl.to_value` can
        convert. It is passed through `value_utils.ensure_federated_value`
        with `placements.CLIENTS`.

    Returns:
      A `value_impl.ValueImpl` wrapping the constructed computation, bound to
      `self._context_stack`.

    Raises:
      TypeError: If the parameter type of `fn` is not assignable from the
        member type of `arg`. (The `py_typecheck.check_type` and
        `ensure_federated_value` calls may raise their own errors too.)
    """
    # TODO(b/113112108): Possibly lift the restriction that the mapped value
    # must be placed at the clients after adding support for placement labels
    # in the federated types, and expanding the type specification of the
    # intrinsic this is based on to work with federated values of arbitrary
    # placement.
    # NOTE(review): the `all_equal` bit itself is not checked here; presumably
    # `create_federated_map_all_equal` enforces it — confirm.
    stack = self._context_stack
    mapped = value_utils.ensure_federated_value(
        value_impl.to_value(arg, None, stack), placements.CLIENTS,
        'value to be mapped')

    # TODO(b/113112108): Add support for polymorphic templates auto-instantiated
    # here based on the actual type of the argument.
    mapper = value_impl.to_value(fn, None, stack)
    py_typecheck.check_type(mapper, value_base.Value)
    mapper_type = mapper.type_signature
    py_typecheck.check_type(mapper_type, computation_types.FunctionType)

    expected_type = mapper_type.parameter
    member_type = mapped.type_signature.member
    if not type_utils.is_assignable_from(expected_type, member_type):
        raise TypeError(
            'The mapping function expects a parameter of type {}, but member '
            'constituents of the mapped value are of incompatible type {}.'
            .format(expected_type, member_type))

    fn_comp = value_impl.ValueImpl.get_comp(mapper)
    arg_comp = value_impl.ValueImpl.get_comp(mapped)
    return value_impl.ValueImpl(
        building_block_factory.create_federated_map_all_equal(
            fn_comp, arg_comp), stack)
# Example #2
0
def federated_map_all_equal(fn, arg):
    """Maps `fn` over the members of a `CLIENTS` value with `all_equal` set.

    Builds the computation with
    `building_block_factory.create_federated_map_all_equal` and binds it via
    `_bind_comp_as_reference`.

    Args:
      fn: The mapping function, or anything `value_impl.to_value` can turn
        into a value whose type signature is a
        `computation_types.FunctionType`. The member type of `arg` is passed
        to `to_value` as `parameter_type_hint`.
      arg: The value to be mapped, or anything `value_impl.to_value` can
        convert. It is passed through `value_utils.ensure_federated_value`
        with `placements.CLIENTS`.

    Returns:
      A `value_impl.Value` wrapping the bound computation.

    Raises:
      TypeError: If `fn` has no parameter type, or if its parameter type is not
        assignable from the member type of `arg`. (The
        `py_typecheck.check_type` and `ensure_federated_value` calls may raise
        their own errors too.)
    """
    # TODO(b/113112108): Possibly lift the restriction that the mapped value
    # must be placed at the clients after adding support for placement labels
    # in the federated types, and expanding the type specification of the
    # intrinsic this is based on to work with federated values of arbitrary
    # placement.
    arg = value_impl.to_value(arg, None)
    arg = value_utils.ensure_federated_value(arg, placements.CLIENTS,
                                             'value to be mapped')

    # The member type is supplied as a hint for converting `fn` (presumably so
    # polymorphic functions can be instantiated — confirm in `to_value`).
    fn = value_impl.to_value(fn,
                             None,
                             parameter_type_hint=arg.type_signature.member)

    py_typecheck.check_type(fn, value_impl.Value)
    py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
    parameter_type = fn.type_signature.parameter
    member_type = arg.type_signature.member
    # `parameter` is presumably `None` for a no-argument function; guard it so
    # the caller gets the documented TypeError instead of an AttributeError
    # from calling `is_assignable_from` on `None`.
    if (parameter_type is None or
            not parameter_type.is_assignable_from(member_type)):
        raise TypeError(
            'The mapping function expects a parameter of type {}, but member '
            'constituents of the mapped value are of incompatible type {}.'.
            format(parameter_type, member_type))

    comp = building_block_factory.create_federated_map_all_equal(
        fn.comp, arg.comp)
    comp = _bind_comp_as_reference(comp)
    return value_impl.Value(comp)
# Example #3
0
    def federated_map_all_equal(self, fn, arg):
        """Implements `federated_map` for an argument with `all_equal` set.

        Behaves like `federated_map` as defined in `api/intrinsics.py`, for an
        argument whose `all_equal` bit is set.

        Args:
          fn: As in `api/intrinsics.py`.
          arg: As in `api/intrinsics.py`, with the `all_equal` bit set. A named
            tuple with two or more elements is first passed through
            `self.federated_zip`.

        Returns:
          As in `api/intrinsics.py`.

        Raises:
          TypeError: As in `api/intrinsics.py`.
        """
        # TODO(b/113112108): Possibly lift the restriction that the mapped value
        # must be placed at the clients after adding support for placement labels
        # in the federated types, and expanding the type specification of the
        # intrinsic this is based on to work with federated values of arbitrary
        # placement.
        stack = self._context_stack
        value = value_impl.to_value(arg, None, stack)
        value_type = value.type_signature
        # A named tuple with two or more elements is taken to be a value the
        # user expects to be zipped.
        # NOTE(review): whether the zipped result still has the `all_equal`
        # bit set is not checked here — confirm.
        if (isinstance(value_type, computation_types.NamedTupleType) and
                len(anonymous_tuple.to_elements(value_type)) >= 2):
            value = self.federated_zip(value)
        value_utils.check_federated_value_placement(value, placements.CLIENTS,
                                                    'value to be mapped')

        # TODO(b/113112108): Add support for polymorphic templates auto-instantiated
        # here based on the actual type of the argument.
        mapper = value_impl.to_value(fn, None, stack)
        py_typecheck.check_type(mapper, value_base.Value)
        mapper_type = mapper.type_signature
        py_typecheck.check_type(mapper_type, computation_types.FunctionType)

        expected_type = mapper_type.parameter
        member_type = value.type_signature.member
        if not type_utils.is_assignable_from(expected_type, member_type):
            raise TypeError(
                'The mapping function expects a parameter of type {}, but member '
                'constituents of the mapped value are of incompatible type {}.'
                .format(expected_type, member_type))

        fn_comp = value_impl.ValueImpl.get_comp(mapper)
        arg_comp = value_impl.ValueImpl.get_comp(value)
        return value_impl.ValueImpl(
            building_block_factory.create_federated_map_all_equal(
                fn_comp, arg_comp), stack)
# Example #4
0
 def test_converts_federated_map_all_equal_to_federated_map(self):
     """`normalize_all_equal_bit` rewrites map-all-equal into `federated_map`.

     Builds a `federated_map_all_equal` call applying an identity lambda to a
     `CLIENTS`-placed reference with `all_equal=True`, then normalizes it. The
     result must be a `federated_map` intrinsic call whose type equals
     `FederatedType(tf.int32, placements.CLIENTS)` constructed without
     `all_equal`.
     """
     all_equal_type = computation_types.FederatedType(
         tf.int32, placements.CLIENTS, all_equal=True)
     expected_type = computation_types.FederatedType(tf.int32,
                                                     placements.CLIENTS)
     identity = building_blocks.Lambda(
         'x', tf.int32, building_blocks.Reference('x', tf.int32))
     before = building_block_factory.create_federated_map_all_equal(
         identity, building_blocks.Reference('y', all_equal_type))
     after = mapreduce_transformations.normalize_all_equal_bit(before)

     # The input call's intrinsic URI is checked after normalization has run.
     self.assertEqual(before.function.uri,
                      intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri)
     self.assertIsInstance(after, building_blocks.Call)
     self.assertIsInstance(after.function, building_blocks.Intrinsic)
     self.assertEqual(after.function.uri, intrinsic_defs.FEDERATED_MAP.uri)
     self.assertEqual(after.type_signature, expected_type)
# Example #5
0
def create_dummy_called_federated_map_all_equal(parameter_name,
                                                parameter_type=tf.int32):
  r"""Returns a dummy called `federated_map_all_equal`.

  The constructed computation has this shape, where `x` stands for
  `parameter_name`:

                          Call
                         /    \
  federated_map_all_equal      Tuple
                               |
                               [Lambda(x), data]
                                |
                                Ref(x)

  Args:
    parameter_name: The name of the identity function's parameter.
    parameter_type: The type of that parameter, also used as the member type
      of the `CLIENTS`-placed `data` argument.

  Returns:
    The result of `building_block_factory.create_federated_map_all_equal`
    applied to `create_identity_function(parameter_name, parameter_type)` and
    a `building_blocks.Data` named `'data'` whose type is a `CLIENTS`
    `FederatedType` with `all_equal=True`.
  """
  return building_block_factory.create_federated_map_all_equal(
      create_identity_function(parameter_name, parameter_type),
      building_blocks.Data(
          'data',
          computation_types.FederatedType(
              parameter_type, placements.CLIENTS, all_equal=True)))