def federated_broadcast(value):
    """Broadcasts a federated value from the `tff.SERVER` to the `tff.CLIENTS`.

    Args:
      value: A value of a TFF federated type placed at the `tff.SERVER`, all
        members of which are equal (the `tff.FederatedType.all_equal`
        property of `value` is `True`).

    Returns:
      A representation of the result of broadcasting: a value of a TFF
      federated type placed at the `tff.CLIENTS`, all members of which are
      equal.

    Raises:
      TypeError: If the argument is not a federated TFF value placed at the
        `tff.SERVER`, or if its type signature is not `all_equal`.
    """
    # Normalize the argument into a TFF value and verify its placement.
    server_value = value_impl.to_value(value, None)
    server_value = value_utils.ensure_federated_value(
        server_value,
        placements.SERVER,
        'value to be broadcasted',
    )

    # Only all-equal values may be broadcast.
    if not server_value.type_signature.all_equal:
        raise TypeError(
            'The broadcasted value should be equal at all locations.')

    # Build the broadcast call, bind it as a reference and wrap the result.
    broadcast_comp = building_block_factory.create_federated_broadcast(
        server_value.comp)
    bound_comp = _bind_comp_as_reference(broadcast_comp)
    return value_impl.Value(bound_comp)
def test_handles_federated_broadcasts_nested_in_tuple(self):
    """Checks splitting on a broadcast whose argument is selected from a struct.

    The struct holds server-placed data at index 0 and another called
    broadcast at index 1; the outer broadcast is applied to the selection of
    index 0. After splitting, neither returned lambda may contain a called
    broadcast intrinsic.
    """
    inner_broadcast = (
        compiler_test_utils.create_whimsy_called_federated_broadcast())
    server_int_type = computation_types.FederatedType(
        computation_types.TensorType(tf.int32), placements.SERVER)
    server_data = building_blocks.Data('a', server_int_type)
    struct_with_broadcast = building_blocks.Struct(
        [server_data, inner_broadcast])
    selected = building_blocks.Selection(struct_with_broadcast, index=0)
    outer_broadcast = building_block_factory.create_federated_broadcast(
        selected)

    call_dominant, _ = compiler_transformations.transform_to_call_dominant(
        outer_broadcast)
    lambda_comp = building_blocks.Lambda('a', tf.int32, call_dominant)
    broadcast_uris = [intrinsic_defs.FEDERATED_BROADCAST.uri]

    before, after = transformations.force_align_and_split_by_intrinsics(
        lambda_comp, broadcast_uris)

    # Both halves must be lambdas with the broadcast intrinsic removed.
    self.assertIsInstance(before, building_blocks.Lambda)
    self.assertFalse(
        tree_analysis.contains_called_intrinsic(before, broadcast_uris))
    self.assertIsInstance(after, building_blocks.Lambda)
    self.assertFalse(
        tree_analysis.contains_called_intrinsic(after, broadcast_uris))
def test_returns_correct_example_of_broadcast_dependent_on_aggregate(self):
    """Expects a `ValueError` matching 'acc_param' for a broadcast of an aggregate."""
    dummy_aggregate = test_utils.create_dummy_called_federated_aggregate()
    broadcast_of_aggregate = (
        building_block_factory.create_federated_broadcast(dummy_aggregate))

    # The raised error message is expected to match 'acc_param'.
    with self.assertRaisesRegex(ValueError, 'acc_param'):
        tree_analysis.check_broadcast_not_dependent_on_aggregate(
            broadcast_of_aggregate)
def test_finds_broadcast_dependent_on_aggregate(self):
    """Expects a `ValueError` when a broadcast is applied to an aggregate."""
    dummy_aggregate = test_utils.create_dummy_called_federated_aggregate()
    broadcast_of_aggregate = (
        building_block_factory.create_federated_broadcast(dummy_aggregate))

    with self.assertRaises(ValueError):
        tree_analysis.check_broadcast_not_dependent_on_aggregate(
            broadcast_of_aggregate)
def _create_complex_computation():
    """Builds a lambda whose body is a block of two broadcast/map/mean chains.

    Each chain broadcasts the `arg` reference, maps a compiled identity over
    the broadcast and takes the federated mean of the map. Every intermediate
    call is bound to a local (`broadcast_{i}`, `map_{i}`, `mean_{i}`) and the
    block result is a struct of the two mean references.

    Returns:
      A `building_blocks.Lambda` with parameter name 'arg'.
    """
    int_tensor_type = computation_types.TensorType(tf.int32)
    identity_fn = building_block_factory.create_compiled_identity(
        int_tensor_type, 'a')
    server_int_type = computation_types.FederatedType(
        tf.int32, placements.SERVER)
    arg_reference = building_blocks.Reference('arg', server_int_type)

    local_bindings = []
    mean_references = []

    def _bind_local(local_name, local_value):
        # Record the binding and hand back a reference to the new local.
        local_bindings.append((local_name, local_value))
        return building_blocks.Reference(
            local_name, local_value.type_signature)

    for index in range(2):
        broadcast_call = building_block_factory.create_federated_broadcast(
            arg_reference)
        broadcast_ref = _bind_local(f'broadcast_{index}', broadcast_call)
        map_call = building_block_factory.create_federated_map(
            identity_fn, broadcast_ref)
        map_ref = _bind_local(f'map_{index}', map_call)
        mean_call = building_block_factory.create_federated_mean(
            map_ref, None)
        mean_references.append(_bind_local(f'mean_{index}', mean_call))

    block_result = building_blocks.Struct(mean_references)
    body = building_blocks.Block(local_bindings, block_result)
    # NOTE(review): the lambda declares 'arg' as tf.int32 while
    # `arg_reference` is typed as int32 placed at SERVER -- confirm this
    # mismatch is intended before relying on the parameter type.
    return building_blocks.Lambda('arg', tf.int32, body)
def test_returns_correct_example_of_broadcast_dependent_on_aggregate(self):
    """Expects a `ValueError` matching 'accumulate_parameter' for a broadcast of an aggregate."""
    dummy_aggregate = (
        computation_test_utils.create_dummy_called_federated_aggregate(
            'accumulate_parameter',
            'merge_parameter',
            'report_parameter',
        ))
    broadcast_of_aggregate = (
        building_block_factory.create_federated_broadcast(dummy_aggregate))

    # The raised error message is expected to match the first name passed
    # to the dummy aggregate above.
    with self.assertRaisesRegex(ValueError, 'accumulate_parameter'):
        tree_analysis.check_broadcast_not_dependent_on_aggregate(
            broadcast_of_aggregate)
def federated_broadcast(self, value):
    """Implements `federated_broadcast` as defined in `api/intrinsics.py`.

    Raises:
      TypeError: If the type signature of the server-placed value is not
        `all_equal`.
    """
    # Normalize the argument using this object's context stack, then verify
    # that it is placed at the server.
    server_value = value_impl.to_value(value, None, self._context_stack)
    server_value = value_utils.ensure_federated_value(
        server_value,
        placements.SERVER,
        'value to be broadcasted',
    )

    if not server_value.type_signature.all_equal:
        raise TypeError('The broadcasted value should be equal at all locations.')

    # Unwrap to the underlying computation, build the broadcast call and
    # wrap it back up with the same context stack.
    server_comp = value_impl.ValueImpl.get_comp(server_value)
    broadcast_comp = building_block_factory.create_federated_broadcast(
        server_comp)
    return value_impl.ValueImpl(broadcast_comp, self._context_stack)
def _create_complex_computation():
    """Builds a lambda returning a tuple with the same mean-of-map-of-broadcast twice.

    The chain broadcasts the reference 'b', maps a compiled identity over the
    broadcast and takes the federated mean of the map; the resulting call is
    placed in both slots of the tuple.

    Returns:
      A `building_blocks.Lambda` with parameter name 'b'.
    """
    identity_fn = building_block_factory.create_compiled_identity(
        tf.int32, 'a')
    server_int_type = computation_types.FederatedType(
        tf.int32, placements.SERVER)
    server_ref = building_blocks.Reference('b', server_int_type)

    broadcast_call = building_block_factory.create_federated_broadcast(
        server_ref)
    map_call = building_block_factory.create_federated_map(
        identity_fn, broadcast_call)
    mean_call = building_block_factory.create_federated_mean(map_call, None)

    # The very same mean call object is used for both tuple elements.
    duplicated_means = building_blocks.Tuple([mean_call, mean_call])
    # NOTE(review): the lambda declares 'b' as tf.int32 while `server_ref`
    # is typed as int32 placed at SERVER -- confirm this mismatch is
    # intended before relying on the parameter type.
    return building_blocks.Lambda('b', tf.int32, duplicated_means)
def test_splits_on_selected_intrinsic_nested_in_tuple_broadcast(self):
    """Checks splitting on a broadcast whose argument is selected from a struct.

    The struct holds server-placed data at index 0 and another called
    broadcast at index 1; the outer broadcast is applied to the selection of
    index 0 and the call-dominant form is wrapped in a lambda before being
    handed to `assert_splits_on` together with a null broadcast call.
    """
    inner_broadcast = (
        building_block_test_utils.create_whimsy_called_federated_broadcast())
    server_data = building_blocks.Data(
        'a', computation_types.at_server(tf.int32))
    struct_with_broadcast = building_blocks.Struct(
        [server_data, inner_broadcast])
    selected = building_blocks.Selection(struct_with_broadcast, index=0)
    outer_broadcast = building_block_factory.create_federated_broadcast(
        selected)

    call_dominant = transformations.to_call_dominant(outer_broadcast)
    lambda_comp = building_blocks.Lambda('a', tf.int32, call_dominant)
    null_broadcast = building_block_factory.create_null_federated_broadcast()

    self.assert_splits_on(lambda_comp, null_broadcast)
def create_whimsy_called_federated_broadcast(value_type=tf.int32):
    r"""Returns a whimsy called federated broadcast.

                  Call
                 /    \
    federated_broadcast data

    Args:
      value_type: The type of the value.
    """
    # The broadcast argument is a `Data` block named 'data', typed as
    # `value_type` placed at the server.
    server_type = computation_types.FederatedType(
        value_type, placements.SERVER)
    server_data = building_blocks.Data('data', server_type)
    return building_block_factory.create_federated_broadcast(server_data)