def test_returns_federated_broadcast(self):
    """Checks the repr and type signature of a broadcast of a SERVER value.

    The argument is a `Data` block named 'v' typed as an int32 placed at
    `placements.SERVER`; the resulting computation is expected to render as
    `federated_broadcast(v)` with type signature `int32@CLIENTS`.
    """
    # Arrange: a SERVER-placed int32 reference (third positional argument is
    # passed as True, exactly as before).
    server_int = computation_types.FederatedType(
        tf.int32, placements.SERVER, True)
    argument = computation_building_blocks.Data('v', server_int)

    # Act.
    result = computation_constructing_utils.create_federated_broadcast(argument)

    # Assert.
    self.assertEqual(result.tff_repr, 'federated_broadcast(v)')
    self.assertEqual(str(result.type_signature), 'int32@CLIENTS')
def test_returns_correct_example_of_broadcast_dependent_on_aggregate(self):
    """Expects a `ValueError` for a broadcast built on top of an aggregate.

    A dummy called federated aggregate is wrapped in a federated broadcast and
    handed to `tree_analysis.check_broadcast_not_dependent_on_aggregate`; the
    raised error message must match 'accumulate_parameter'.
    """
    called_aggregate = (
        computation_test_utils.create_dummy_called_federated_aggregate(
            'accumulate_parameter', 'merge_parameter', 'report_parameter'))
    dependent_broadcast = (
        computation_constructing_utils.create_federated_broadcast(
            called_aggregate))

    checker = tree_analysis.check_broadcast_not_dependent_on_aggregate
    with self.assertRaisesRegex(ValueError, 'accumulate_parameter'):
        checker(dependent_broadcast)
def create_dummy_called_federated_broadcast(value_type=tf.int32):
    r"""Builds a placeholder call to the federated broadcast intrinsic.

    The argument is a `Data` block named 'data' whose type is `value_type`
    placed at `placements.SERVER`:

                    Call
                   /    \
    federated_broadcast  data

    Args:
      value_type: The member type of the SERVER-placed argument. Defaults to
        `tf.int32`.

    Returns:
      The result of `computation_constructing_utils.create_federated_broadcast`
      applied to that `Data` block.
    """
    server_type = computation_types.FederatedType(value_type, placements.SERVER)
    data_block = computation_building_blocks.Data('data', server_type)
    return computation_constructing_utils.create_federated_broadcast(data_block)
def federated_broadcast(self, value):
    """Implements `federated_broadcast` as defined in `api/intrinsics.py`.

    The input is converted with `value_impl.to_value`, verified to be placed at
    `placements.SERVER`, and required to have an `all_equal` type signature
    before the broadcast computation is constructed.

    Args:
      value: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`; concretely, a `value_impl.ValueImpl` wrapping
      the constructed broadcast computation in this object's context stack.

    Raises:
      TypeError: As in `api/intrinsics.py`; in particular when the value's type
        signature is not `all_equal`.
    """
    server_value = value_impl.to_value(value, None, self._context_stack)
    type_utils.check_federated_value_placement(
        server_value, placements.SERVER, 'value to be broadcasted')

    # Guard: only values that are identical at every location may be broadcast.
    if not server_value.type_signature.all_equal:
        raise TypeError('The broadcasted value should be equal at all locations.')

    broadcast = computation_constructing_utils.create_federated_broadcast(
        value_impl.ValueImpl.get_comp(server_value))
    return value_impl.ValueImpl(broadcast, self._context_stack)
def test_raises_type_error_with_none_value(self):
    """Passing `None` to `create_federated_broadcast` must raise `TypeError`."""
    # Callable form of assertRaises: invokes the function with `None` and
    # fails the test unless a TypeError is raised.
    self.assertRaises(
        TypeError,
        computation_constructing_utils.create_federated_broadcast,
        None)