Example #1
0
 def test_returns_federated_broadcast(self):
     """Broadcasting an all-equal int32@SERVER value yields int32@CLIENTS."""
     server_type = computation_types.FederatedType(
         tf.int32, placements.SERVER, True)
     data = computation_building_blocks.Data('v', server_type)
     broadcast = computation_constructing_utils.create_federated_broadcast(data)
     self.assertEqual(broadcast.tff_repr, 'federated_broadcast(v)')
     self.assertEqual(str(broadcast.type_signature), 'int32@CLIENTS')
 def test_returns_correct_example_of_broadcast_dependent_on_aggregate(self):
     """A broadcast of an aggregate's result must be flagged as dependent.

     The dependency check is expected to raise `ValueError` with a message
     mentioning the aggregate's accumulate parameter.
     """
     dummy_aggregate = (
         computation_test_utils.create_dummy_called_federated_aggregate(
             'accumulate_parameter', 'merge_parameter', 'report_parameter'))
     broadcast = computation_constructing_utils.create_federated_broadcast(
         dummy_aggregate)
     with self.assertRaisesRegex(ValueError, 'accumulate_parameter'):
         tree_analysis.check_broadcast_not_dependent_on_aggregate(broadcast)
def create_dummy_called_federated_broadcast(value_type=tf.int32):
    r"""Returns a dummy called federated broadcast.

    The returned building block broadcasts a `Data` block named `data`
    placed at SERVER:

                     Call
                    /    \
      federated_broadcast  data

    Args:
      value_type: The member type of the server-placed `data` value that is
        broadcast.

    Returns:
      The building block produced by
      `computation_constructing_utils.create_federated_broadcast` for the
      `data` value (whose repr has the form `federated_broadcast(data)`).
    """
    # NOTE(review): `all_equal` is left at FederatedType's default here;
    # presumably that resolves to True for SERVER placement, which
    # broadcasting requires. Confirm against FederatedType.
    federated_type = computation_types.FederatedType(value_type,
                                                     placements.SERVER)
    value = computation_building_blocks.Data('data', federated_type)
    return computation_constructing_utils.create_federated_broadcast(value)
  def federated_broadcast(self, value):
    """Broadcasts a server-placed `value`, per `api/intrinsics.py`.

    The argument is first converted to a value in this object's context
    stack. It is then checked to be placed at `placements.SERVER` and to have
    an `all_equal` type. The result wraps a federated-broadcast building block.

    Args:
      value: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`; in particular when the value's
        type signature is not `all_equal`.
    """
    server_value = value_impl.to_value(value, None, self._context_stack)
    type_utils.check_federated_value_placement(server_value, placements.SERVER,
                                               'value to be broadcasted')
    # Only a single value shared across the placement can be broadcast.
    if not server_value.type_signature.all_equal:
      raise TypeError('The broadcasted value should be equal at all locations.')
    broadcast_comp = computation_constructing_utils.create_federated_broadcast(
        value_impl.ValueImpl.get_comp(server_value))
    return value_impl.ValueImpl(broadcast_comp, self._context_stack)
Example #5
0
 def test_raises_type_error_with_none_value(self):
     """`create_federated_broadcast(None)` must raise `TypeError`."""
     self.assertRaises(
         TypeError,
         computation_constructing_utils.create_federated_broadcast,
         None)