Ejemplo n.º 1
0
def federated_broadcast(value):
    """Broadcasts a federated value from the `tff.SERVER` to the `tff.CLIENTS`.

    The argument is converted to a TFF value and validated as a federated
    value placed at `tff.SERVER`. The broadcast computation built from it is
    passed through `_bind_comp_as_reference` before being wrapped in a
    `value_impl.Value`.

    Args:
      value: A value of a TFF federated type placed at the `tff.SERVER`, all
        members of which are equal (the `tff.FederatedType.all_equal` property
        of `value` is `True`).

    Returns:
      A representation of the result of broadcasting: a value of a TFF
      federated type placed at the `tff.CLIENTS`, all members of which are
      equal.

    Raises:
      TypeError: If the argument is not a federated TFF value placed at the
        `tff.SERVER`, or if its type is not `all_equal`.
    """
    server_value = value_utils.ensure_federated_value(
        value_impl.to_value(value, None), placements.SERVER,
        'value to be broadcasted')

    # Only a single, all-equal server value can be replicated to clients.
    if not server_value.type_signature.all_equal:
        raise TypeError(
            'The broadcasted value should be equal at all locations.')

    broadcast_comp = building_block_factory.create_federated_broadcast(
        server_value.comp)
    return value_impl.Value(_bind_comp_as_reference(broadcast_comp))
Ejemplo n.º 2
0
    def test_handles_federated_broadcasts_nested_in_tuple(self):
        # The inner broadcast is packed next to a server-placed datum; the
        # outer broadcast consumes only element 0 (the datum) of that struct.
        inner = compiler_test_utils.create_whimsy_called_federated_broadcast()
        server_int = computation_types.FederatedType(
            computation_types.TensorType(tf.int32), placements.SERVER)
        packed = building_blocks.Struct(
            [building_blocks.Data('a', server_int), inner])
        outer = building_block_factory.create_federated_broadcast(
            building_blocks.Selection(packed, index=0))
        call_dominant, _ = compiler_transformations.transform_to_call_dominant(
            outer)
        comp = building_blocks.Lambda('a', tf.int32, call_dominant)
        broadcast_uris = [intrinsic_defs.FEDERATED_BROADCAST.uri]

        before, after = transformations.force_align_and_split_by_intrinsics(
            comp, broadcast_uris)

        # Neither half of the split may still contain a called broadcast.
        for part in (before, after):
            self.assertIsInstance(part, building_blocks.Lambda)
            self.assertFalse(
                tree_analysis.contains_called_intrinsic(part, broadcast_uris))
Ejemplo n.º 3
0
 def test_returns_correct_example_of_broadcast_dependent_on_aggregate(self):
   # Broadcasting an aggregate's output must be rejected, and the raised
   # error is expected to mention 'acc_param'.
   broadcast = building_block_factory.create_federated_broadcast(
       test_utils.create_dummy_called_federated_aggregate())
   with self.assertRaisesRegex(ValueError, 'acc_param'):
     tree_analysis.check_broadcast_not_dependent_on_aggregate(broadcast)
Ejemplo n.º 4
0
 def test_finds_broadcast_dependent_on_aggregate(self):
   # A broadcast fed directly by an aggregate is the dependency the check
   # is expected to reject with a ValueError.
   aggregate_then_broadcast = building_block_factory.create_federated_broadcast(
       test_utils.create_dummy_called_federated_aggregate())
   with self.assertRaises(ValueError):
     tree_analysis.check_broadcast_not_dependent_on_aggregate(
         aggregate_then_broadcast)
Ejemplo n.º 5
0
def _create_complex_computation():
    """Returns a lambda that runs two broadcast -> map -> mean rounds.

    Every intermediate call is bound to a named local (`broadcast_i`,
    `map_i`, `mean_i` for i in 0 and 1, in that order) of a single `Block`.
    The block's result is a `Struct` holding references to `mean_0` and
    `mean_1`.
    """
    identity = building_block_factory.create_compiled_identity(
        computation_types.TensorType(tf.int32), 'a')
    server_int_type = computation_types.FederatedType(tf.int32,
                                                      placements.SERVER)
    param = building_blocks.Reference('arg', server_int_type)
    block_locals = []
    mean_refs = []

    def bind_local(label, comp):
        # Records `comp` as a block local named `label`; returns a reference.
        block_locals.append((label, comp))
        return building_blocks.Reference(label, comp.type_signature)

    for round_index in range(2):
        broadcast_ref = bind_local(
            f'broadcast_{round_index}',
            building_block_factory.create_federated_broadcast(param))
        map_ref = bind_local(
            f'map_{round_index}',
            building_block_factory.create_federated_map(identity,
                                                        broadcast_ref))
        mean_refs.append(
            bind_local(
                f'mean_{round_index}',
                building_block_factory.create_federated_mean(map_ref, None)))

    body = building_blocks.Block(block_locals,
                                 building_blocks.Struct(mean_refs))
    # NOTE(review): the lambda parameter is declared as plain `tf.int32`, but
    # `param` refers to 'arg' as a SERVER-placed int32. Kept as-is; confirm
    # whether callers depend on this type signature.
    return building_blocks.Lambda('arg', tf.int32, body)
Ejemplo n.º 6
0
 def test_returns_correct_example_of_broadcast_dependent_on_aggregate(self):
     # The raised error is expected to name the accumulate parameter that is
     # passed to the aggregate helper.
     broadcast = building_block_factory.create_federated_broadcast(
         computation_test_utils.create_dummy_called_federated_aggregate(
             'accumulate_parameter', 'merge_parameter', 'report_parameter'))
     with self.assertRaisesRegex(ValueError, 'accumulate_parameter'):
         tree_analysis.check_broadcast_not_dependent_on_aggregate(broadcast)
Ejemplo n.º 7
0
  def federated_broadcast(self, value):
    """Implements `federated_broadcast` as defined in `api/intrinsics.py`.

    Args:
      value: A value convertible to a federated value placed at `SERVER`.

    Returns:
      A `value_impl.ValueImpl` wrapping the broadcast computation, bound to
      this object's context stack.

    Raises:
      TypeError: If the value's type is not `all_equal`.
    """
    context_stack = self._context_stack
    server_value = value_utils.ensure_federated_value(
        value_impl.to_value(value, None, context_stack), placements.SERVER,
        'value to be broadcasted')

    # Broadcasting requires a single value shared across all locations.
    if not server_value.type_signature.all_equal:
      raise TypeError('The broadcasted value should be equal at all locations.')

    broadcast = building_block_factory.create_federated_broadcast(
        value_impl.ValueImpl.get_comp(server_value))
    return value_impl.ValueImpl(broadcast, context_stack)
Ejemplo n.º 8
0
def _create_complex_computation():
    """Returns a lambda whose result is a pair of identical federated means.

    The mean is computed over `federated_map(identity, federated_broadcast(b))`.
    Both tuple elements are the very same building-block object.
    """
    identity = building_block_factory.create_compiled_identity(tf.int32, 'a')
    server_ref = building_blocks.Reference(
        'b', computation_types.FederatedType(tf.int32, placements.SERVER))
    broadcast = building_block_factory.create_federated_broadcast(server_ref)
    mean = building_block_factory.create_federated_mean(
        building_block_factory.create_federated_map(identity, broadcast), None)
    # NOTE(review): the lambda parameter is declared as plain `tf.int32`, but
    # `server_ref` refers to 'b' as a SERVER-placed int32. Kept as-is; confirm
    # whether callers depend on this type signature.
    return building_blocks.Lambda('b', tf.int32,
                                  building_blocks.Tuple([mean, mean]))
Ejemplo n.º 9
0
 def test_splits_on_selected_intrinsic_nested_in_tuple_broadcast(self):
     # Only element 0 (the server datum) of the packed struct feeds the outer
     # broadcast; the inner broadcast is merely a sibling element.
     inner = (
         building_block_test_utils.create_whimsy_called_federated_broadcast())
     packed = building_blocks.Struct([
         building_blocks.Data('a', computation_types.at_server(tf.int32)),
         inner,
     ])
     outer = building_block_factory.create_federated_broadcast(
         building_blocks.Selection(packed, index=0))
     comp = building_blocks.Lambda('a', tf.int32,
                                   transformations.to_call_dominant(outer))
     self.assert_splits_on(
         comp, building_block_factory.create_null_federated_broadcast())
Ejemplo n.º 10
0
def create_whimsy_called_federated_broadcast(value_type=tf.int32):
    r"""Returns a whimsy called federated broadcast.

                      Call
                     /    \
  federated_broadcast      data

    Args:
      value_type: The member type of the server-placed value to broadcast.

    Returns:
      The result of `building_block_factory.create_federated_broadcast`
      applied to a `building_blocks.Data` named 'data' whose type is
      `value_type` placed at `SERVER`.
    """
    server_data = building_blocks.Data(
        'data', computation_types.FederatedType(value_type, placements.SERVER))
    return building_block_factory.create_federated_broadcast(server_data)