def _instantiate_aggregation_process(
    aggregation_factory, model_weights_type,
    client_weight_fn) -> tff.templates.AggregationProcess:
  """Constructs an aggregation process from a factory, checking compatibility.

  A `None` factory defaults to weighted mean aggregation. Unweighted
  factories take precedence over weighted ones; supplying `client_weight_fn`
  alongside an unweighted factory only triggers a warning.
  """
  if aggregation_factory is None:
    default_factory = tff.aggregators.MeanFactory()
    return default_factory.create(model_weights_type.trainable,
                                  tff.TensorType(tf.float32))

  # Unweighted aggregation is checked first, so it takes precedence.
  if isinstance(aggregation_factory,
                tff.aggregators.UnweightedAggregationFactory):
    if client_weight_fn is not None:
      logging.warning(
          'When using an unweighted aggregation, '
          '`client_weight_fn` should not be specified; found '
          '`client_weight_fn` %s', client_weight_fn)
    return aggregation_factory.create(model_weights_type.trainable)

  if isinstance(aggregation_factory,
                tff.aggregators.WeightedAggregationFactory):
    return aggregation_factory.create(model_weights_type.trainable,
                                      tff.TensorType(tf.float32))

  raise ValueError('Unknown type of aggregation factory: {}'.format(
      type(aggregation_factory)))
def test_build_jax_federated_averaging_process(self):
  """Smoke test: builds the JAX FedAvg process and runs a single round."""
  batch_type = collections.OrderedDict(
      pixels=tff.TensorType(np.float32, (50, 784)),
      labels=tff.TensorType(np.int32, (50,)))

  def random_batch():
    # One synthetic batch of 50 flattened 28x28 examples with random labels.
    return collections.OrderedDict(
        pixels=np.random.uniform(
            low=0.0, high=1.0, size=(50, 784)).astype(np.float32),
        labels=np.random.randint(low=0, high=9, size=(50,), dtype=np.int32))

  model_type = collections.OrderedDict(
      weights=tff.TensorType(np.float32, (784, 10)),
      bias=tff.TensorType(np.float32, (10,)))

  def loss(model, batch):
    # Softmax cross-entropy of a single dense layer over 10 classes.
    logits = jax.numpy.add(
        jax.numpy.matmul(batch['pixels'], model['weights']), model['bias'])
    y = jax.nn.softmax(logits)
    targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
    return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

  trainer = jax_components.build_jax_federated_averaging_process(
      batch_type, model_type, loss, step_size=0.001)
  trainer.next(trainer.initialize(), [[random_batch()]])
def test_federated_mean_masked(self):
  """Checks the type signature and zero-masked averaging behavior."""
  value_type = tff.StructType([
      ('a', tff.TensorType(tf.float32, shape=[3])),
      ('b', tff.TensorType(tf.float32, shape=[2, 1])),
  ])
  weight_type = tff.TensorType(tf.float32)
  mean_fn = fed_pa_schedule.build_federated_mean_masked(
      value_type, weight_type)

  # The computation must map CLIENTS-placed (value, weight) pairs to a
  # SERVER-placed value of the same structure.
  expected_type = tff.FunctionType(
      parameter=collections.OrderedDict(
          value=tff.FederatedType(value_type, tff.CLIENTS),
          weight=tff.FederatedType(weight_type, tff.CLIENTS)),
      result=tff.FederatedType(value_type, tff.SERVER))
  self.assertTrue(
      mean_fn.type_signature.is_equivalent_to(expected_type),
      msg='{s}\n!={t}'.format(s=mean_fn.type_signature, t=expected_type))

  # Zero entries are excluded from the weighted mean rather than averaged in.
  client_values = [
      collections.OrderedDict(
          a=tf.constant([0.0, 1.0, 1.0]), b=tf.constant([[1.0], [2.0]])),
      collections.OrderedDict(
          a=tf.constant([1.0, 3.0, 0.0]), b=tf.constant([[3.0], [0.0]])),
  ]
  client_weights = [tf.constant(1.0), tf.constant(3.0)]
  result = mean_fn(client_values, client_weights)
  self.assertAllClose(
      result,
      collections.OrderedDict(
          a=tf.constant([1.0, 2.5, 1.0]), b=tf.constant([[2.5], [2.0]])))
def test_iterative_process_type_signature(self):
  """Verifies `initialize`/`next` type signatures of the adaptive FedAvg process."""
  # decay_factor=1.0 means the callbacks never change the learning rate;
  # only their state structure matters for the type signature.
  client_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.1,
      min_delta=0.5,
      window_size=2,
      decay_factor=1.0,
      cooldown=0)
  server_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=0.1,
      min_delta=0.5,
      window_size=2,
      decay_factor=1.0,
      cooldown=0)
  iterative_process = adaptive_fed_avg.build_fed_avg_process(
      _uncompiled_model_builder,
      client_lr_callback,
      server_lr_callback,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)

  lr_callback_type = tff.framework.type_from_tensors(client_lr_callback)

  # Server state: a 1x1 linear model (kernel + bias) plus both callbacks'
  # states and the server optimizer state.
  server_state_type = tff.FederatedType(
      adaptive_fed_avg.ServerState(
          model=tff.learning.ModelWeights(
              trainable=(tff.TensorType(tf.float32, [1, 1]),
                         tff.TensorType(tf.float32, [1])),
              non_trainable=()),
          optimizer_state=[tf.int64],
          client_lr_callback=lr_callback_type,
          server_lr_callback=lr_callback_type), tff.SERVER)

  self.assertEqual(
      iterative_process.initialize.type_signature,
      tff.FunctionType(parameter=None, result=server_state_type))

  dataset_type = tff.FederatedType(
      tff.SequenceType(
          collections.OrderedDict(
              x=tff.TensorType(tf.float32, [None, 1]),
              y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)
  metrics_type = tff.FederatedType(
      collections.OrderedDict(loss=tff.TensorType(tf.float32)), tff.SERVER)
  # `next` reports metrics both before and during training.
  output_type = collections.OrderedDict(
      before_training=metrics_type, during_training=metrics_type)

  expected_result_type = (server_state_type, output_type)
  expected_type = tff.FunctionType(
      parameter=collections.OrderedDict(
          server_state=server_state_type, federated_dataset=dataset_type),
      result=expected_result_type)

  actual_type = iterative_process.next.type_signature
  self.assertEqual(
      actual_type,
      expected_type,
      msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
def create_preprocess_fn(
    num_epochs: int,
    batch_size: int,
    shuffle_buffer_size: int = MAX_CLIENT_DATASET_SIZE,
    emnist_task: str = 'digit_recognition',
    num_parallel_calls: tf.Tensor = tf.data.experimental.AUTOTUNE
) -> tff.Computation:
  """Builds a `tff.Computation` that preprocesses EMNIST client datasets.

  Client datasets are shuffled, repeated, batched, and then reshaped, using
  the `shuffle`, `repeat`, `batch`, and `map` attributes of a
  `tf.data.Dataset`, in that order.

  Args:
    num_epochs: Number of epochs to repeat the client datasets; must be >= 1.
    batch_size: Batch size used on clients.
    shuffle_buffer_size: Shuffle buffer size on clients; values <= 1 disable
      shuffling.
    emnist_task: The EMNIST task being performed. Must be one of
      'digit_recognition' (elements become (pixels, label) tuples) or
      'autoencoder' (elements become (pixels, pixels) tuples).
    num_parallel_calls: Parallelism used for `tf.data.Dataset.map`.

  Returns:
    A `tff.Computation` performing the preprocessing described above.

  Raises:
    ValueError: If `num_epochs` is not positive or `emnist_task` is unknown.
  """
  if num_epochs < 1:
    raise ValueError('num_epochs must be a positive integer.')
  shuffle_buffer_size = max(shuffle_buffer_size, 1)

  task_to_mapping_fn = {
      'digit_recognition': _reshape_for_digit_recognition,
      'autoencoder': _reshape_for_autoencoder,
  }
  if emnist_task not in task_to_mapping_fn:
    raise ValueError('emnist_task must be one of "digit_recognition" or '
                     '"autoencoder".')
  mapping_fn = task_to_mapping_fn[emnist_task]

  # Feature keys are sorted lexicographically for consistency across datasets.
  feature_dtypes = collections.OrderedDict(
      label=tff.TensorType(tf.int32),
      pixels=tff.TensorType(tf.float32, shape=(28, 28)))

  @tff.tf_computation(tff.SequenceType(feature_dtypes))
  def preprocess_fn(dataset):
    shuffled = dataset.shuffle(shuffle_buffer_size)
    repeated = shuffled.repeat(num_epochs)
    batched = repeated.batch(batch_size, drop_remainder=False)
    return batched.map(mapping_fn, num_parallel_calls=num_parallel_calls)

  return preprocess_fn
def create_preprocess_fn(
    num_epochs: int,
    batch_size: int,
    shuffle_buffer_size: int = NUM_EXAMPLES_PER_CLIENT,
    crop_shape: Tuple[int, int, int] = CIFAR_SHAPE,
    distort_image=False,
    num_parallel_calls: int = tf.data.experimental.AUTOTUNE) -> tff.Computation:
  """Builds a `tff.Computation` that preprocesses CIFAR-100 client datasets.

  Args:
    num_epochs: Number of epochs to repeat the client datasets; must be >= 1.
    batch_size: Batch size used on clients.
    shuffle_buffer_size: Shuffle buffer size on clients; values <= 1 disable
      shuffling.
    crop_shape: A (crop_height, crop_width, num_channels) tuple giving the
      desired crop shape. Elements must not exceed (32, 32, 3) element-wise;
      the last element should be 3 to preserve the RGB structure.
    distort_image: Whether preprocessing should distort images with random
      crops and flips.
    num_parallel_calls: Parallelism used for `tf.data.Dataset.map`.

  Returns:
    A `tff.Computation` performing the preprocessing described above.

  Raises:
    ValueError: If `num_epochs` is not a positive integer.
  """
  if num_epochs < 1:
    raise ValueError('num_epochs must be a positive integer.')
  shuffle_buffer_size = max(shuffle_buffer_size, 1)

  # Feature keys are sorted lexicographically for consistency across datasets.
  feature_dtypes = collections.OrderedDict(
      coarse_label=tff.TensorType(tf.int64),
      image=tff.TensorType(tf.uint8, shape=(32, 32, 3)),
      label=tff.TensorType(tf.int64))

  image_map_fn = build_image_map(crop_shape, distort_image)

  @tff.tf_computation(tff.SequenceType(feature_dtypes))
  def preprocess_fn(dataset):
    shuffled = dataset.shuffle(shuffle_buffer_size).repeat(num_epochs)
    # Map before batching so that cropping is applied per image rather than
    # applying one identical crop to every image within a batch.
    mapped = shuffled.map(image_map_fn, num_parallel_calls=num_parallel_calls)
    return mapped.batch(batch_size)

  return preprocess_fn
def test_dynamic_lookup_table(self):
  """A hash table built from a runtime tensor maps misses to the default."""

  @tff.tf_computation(
      tff.TensorType(shape=[None], dtype=tf.string),
      tff.TensorType(shape=[None], dtype=tf.string))
  def lookup_comp(table_args, to_lookup):
    # Each key maps to its index in `table_args`.
    index_values = tf.range(tf.shape(table_args)[0])
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(table_args, index_values),
        default_value=101)
    return table.lookup(to_lookup)

  result = lookup_comp(
      tf.constant(['a', 'b', 'c']), tf.constant(['a', 'z', 'c']))
  # 'z' is not a key, so it resolves to the default value 101.
  self.assertAllEqual(result, [0, 101, 2])
def test_lookup_table(self):
  """Keys map to their index in the key tensor; default value is 100."""

  @tff.tf_computation(
      tff.TensorType(shape=[None], dtype=tf.string),
      tff.TensorType(shape=[], dtype=tf.string))
  def foo(table_args, to_lookup):
    num_keys = tf.shape(table_args)[0]
    initializer = tf.lookup.KeyValueTensorInitializer(
        table_args, tf.range(num_keys))
    return tf.lookup.StaticHashTable(initializer, 100).lookup(to_lookup)

  # Tables of several sizes, each queried for a present key.
  for keys, query, expected in [
      (['a', 'b'], 'a', 0),
      (['d', 'e', 'f'], 'f', 2),
      (['d', 'e', 'f', 'g', 'h', 'i'], 'i', 5),
  ]:
    self.assertEqual(foo(tf.constant(keys), query), expected)
def test_reinitialize_dynamic_lookup_table(self):
  """The table is rebuilt per invocation, so new keys resolve correctly."""

  @tff.tf_computation(
      tff.TensorType(shape=[None], dtype=tf.string),
      tff.TensorType(shape=[], dtype=tf.string))
  def comp(table_args, to_lookup):
    # Each key maps to its index; unknown keys map to 101.
    indices = tf.range(tf.shape(table_args)[0])
    initializer = tf.lookup.KeyValueTensorInitializer(table_args, indices)
    table = tf.lookup.StaticHashTable(initializer, default_value=101)
    return table.lookup(to_lookup)

  self.assertEqual(comp(tf.constant(['a', 'b', 'c']), tf.constant('a')), 0)
  # A second call with a longer key tensor must see the new key 'd'.
  self.assertEqual(
      comp(tf.constant(['a', 'b', 'c', 'd']), tf.constant('d')), 3)
def test_bad_type_coercion_raises(self):
  """A tensor that type-checks but fails reshape raises at every placement."""
  tensor_type = tff.TensorType(shape=[None], dtype=tf.float32)

  @tff.tf_computation(tensor_type)
  def scalar_reshape(x):
    # Passes the TFF type check, but the reshape fails for any length != 1.
    return tf.reshape(x, [])

  @tff.federated_computation(tff.type_at_clients(tensor_type))
  def clients_map(x):
    return tff.federated_map(scalar_reshape, x)

  @tff.federated_computation(tff.type_at_server(tensor_type))
  def server_map(x):
    return tff.federated_map(scalar_reshape, x)

  ten_elements = tf.constant([1.] * 10, dtype=tf.float32)
  one_element = tf.constant([1.], dtype=tf.float32)

  # The failure must surface unplaced, at SERVER, and at CLIENTS alike.
  with self.assertRaises(Exception):
    scalar_reshape(ten_elements)
  with self.assertRaises(Exception):
    server_map(ten_elements)
  with self.assertRaises(Exception):
    clients_map([ten_elements] * 10)

  # A final successful call gives the distributed runtime a chance to clean
  # up; otherwise workers may receive SIGABRT while handling another
  # exception, crashing the test infra.
  clients_map([one_element] * 10)
def create_preprocess_fn(
    num_epochs: int,
    batch_size: int,
    shuffle_buffer_size: int,
    mapping_fn: Callable[[Any], Any],
    num_parallel_calls: tf.Tensor = tf.data.experimental.AUTOTUNE
) -> tff.Computation:
  """Builds a `tff.Computation` that preprocesses EMNIST client datasets.

  Client datasets are shuffled, repeated, batched, and then reshaped, using
  the `shuffle`, `repeat`, `batch`, and `map` attributes of a
  `tf.data.Dataset`, in that order.

  Args:
    num_epochs: Number of epochs to repeat the client datasets; must be >= 1.
    batch_size: Batch size used on clients.
    shuffle_buffer_size: Shuffle buffer size on clients; values <= 1 disable
      shuffling.
    mapping_fn: Mapping applied to EMNIST elements after shuffling,
      repeating, and batching.
    num_parallel_calls: Parallelism used for `tf.data.Dataset.map`.

  Returns:
    A `tff.Computation` performing the preprocessing described above.

  Raises:
    ValueError: If `num_epochs` is not a positive integer.
  """
  if num_epochs < 1:
    raise ValueError('num_epochs must be a positive integer.')
  shuffle_buffer_size = max(shuffle_buffer_size, 1)

  element_spec = collections.OrderedDict(
      pixels=tff.TensorType(tf.float32, shape=(28, 28)),
      label=tff.TensorType(tf.int32))

  @tff.tf_computation(tff.SequenceType(element_spec))
  def preprocess_fn(dataset):
    shuffled = dataset.shuffle(shuffle_buffer_size)
    repeated = shuffled.repeat(num_epochs)
    batched = repeated.batch(batch_size, drop_remainder=False)
    return batched.map(mapping_fn, num_parallel_calls=num_parallel_calls)

  return preprocess_fn
def make_integer_secure_sum(input_shape):
  """Builds a federated computation that securely sums int32 client values.

  Args:
    input_shape: Shape of each client's contribution, or `None` for a scalar.

  Returns:
    A `tff.federated_computation` summing CLIENTS-placed int32 values.
  """
  member_type = (
      tf.int32 if input_shape is None
      else tff.TensorType(tf.int32, input_shape))

  @tff.federated_computation(tff.FederatedType(member_type, tff.CLIENTS))
  def secure_paillier_addition(x):
    # 64-bit budget for the secure sum.
    return tff.federated_secure_sum(x, 64)

  return secure_paillier_addition
def get_preprocessing_fn(image_size: int, batch_size: int, num_epochs: int,
                         max_elements: int,
                         shuffle_buffer_size: int) -> tff.Computation:
  """Builds a preprocessing computation for federated training.

  Args:
    image_size: Height and width of images after preprocessing.
    batch_size: Batch size used for training; must be positive.
    num_epochs: Number of training epochs; must be positive.
    max_elements: Maximum number of elements taken from the dataset; must be
      positive, or -1 to take all elements.
    shuffle_buffer_size: Buffer size used in shuffling; must be positive.

  Returns:
    A `tff.Computation` that transforms a client's raw `tf.data.Dataset`
    into a `tf.data.Dataset` ready for training.

  Raises:
    ValueError: If any argument is out of range.
  """
  for name, value in (('image_size', image_size),
                      ('batch_size', batch_size),
                      ('num_epochs', num_epochs),
                      ('shuffle_buffer_size', shuffle_buffer_size)):
    _check_positive(name, value)
  if max_elements <= 0 and max_elements != -1:
    raise ValueError('Expected a positive value or -1 for `max_elements`, '
                     f'found {max_elements}.')

  # Keys ordered: decoded image first, then its class label.
  feature_dtypes = collections.OrderedDict()
  feature_dtypes['image/decoded'] = tff.TensorType(
      dtype=tf.uint8, shape=[None, None, None])
  feature_dtypes['class'] = tff.TensorType(dtype=tf.int64, shape=[1])

  @tff.tf_computation(tff.SequenceType(feature_dtypes))
  def preprocessing_fn(dataset: tf.data.Dataset) -> tf.data.Dataset:
    train_map = functools.partial(
        _map_fn, is_training=True, image_size=image_size)
    mapped = dataset.map(
        train_map, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return mapped.shuffle(shuffle_buffer_size).take(max_elements).repeat(
        num_epochs).batch(batch_size)

  return preprocessing_fn
def test_eval_fn_has_correct_type_signature(self):
  """Centralized evaluation maps SERVER-placed (model, dataset) to metrics."""
  metrics_builder = lambda: [tf.keras.metrics.MeanSquaredError()]
  eval_fn = evaluation.build_centralized_evaluation(
      tff_model_fn, metrics_builder)

  # A 1x1 linear model: kernel plus bias, no non-trainable weights.
  weights_struct = tff.learning.ModelWeights(
      trainable=(tff.TensorType(tf.float32, [1, 1]),
                 tff.TensorType(tf.float32, [1])),
      non_trainable=())
  model_type = tff.FederatedType(weights_struct, tff.SERVER)
  dataset_type = tff.FederatedType(
      tff.SequenceType(
          collections.OrderedDict(
              x=tff.TensorType(tf.float32, [None, 1]),
              y=tff.TensorType(tf.float32, [None, 1]))), tff.SERVER)
  metrics_type = tff.FederatedType(
      collections.OrderedDict(
          mean_squared_error=tff.TensorType(tf.float32),
          num_examples=tff.TensorType(tf.float32)), tff.SERVER)
  expected_type = tff.FunctionType(
      parameter=collections.OrderedDict(
          model_weights=model_type, centralized_dataset=dataset_type),
      result=metrics_type)

  eval_fn.type_signature.check_assignable_from(expected_type)
def test_secure_sum_larger_matrices(self, first_dim, second_dim):
  """Secure sum of all-ones int32 matrices across five clients."""
  num_clients = 5
  matrix_shape = (first_dim, second_dim)
  input_tensor = np.ones(matrix_shape, dtype=np.int32)
  member_type = tff.TensorType(tf.int32, matrix_shape)

  @tff.federated_computation(tff.FederatedType(member_type, tff.CLIENTS))
  def secure_paillier_addition(x):
    # 64-bit budget for the secure sum.
    return tff.federated_secure_sum(x, 64)

  with _install_executor(factory.local_paillier_executor_factory()):
    result = secure_paillier_addition([input_tensor] * num_clients)

  np.testing.assert_almost_equal(result, input_tensor * num_clients)
def test_iterative_process_type_signature(self):
  """Checks `initialize`/`next` type signatures of the flag-built process."""
  iterative_process = decay_iterative_process_builder.from_flags(
      input_spec=get_input_spec(),
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  # Only the callback's state structure matters here; building it from the
  # same flags as `from_flags` keeps the inferred type in sync.
  dummy_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.client_learning_rate,
      decay_factor=FLAGS.client_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)
  lr_callback_type = tff.framework.type_from_tensors(dummy_lr_callback)

  # Server state: a 1x1 linear model plus client/server callback states.
  server_state_type = tff.FederatedType(
      adaptive_fed_avg.ServerState(
          model=tff.learning.ModelWeights(
              trainable=(tff.TensorType(tf.float32, [1, 1]),
                         tff.TensorType(tf.float32, [1])),
              non_trainable=()),
          optimizer_state=[tf.int64],
          client_lr_callback=lr_callback_type,
          server_lr_callback=lr_callback_type), tff.SERVER)

  self.assertEqual(
      iterative_process.initialize.type_signature,
      tff.FunctionType(parameter=None, result=server_state_type))

  dataset_type = tff.FederatedType(
      tff.SequenceType(
          collections.OrderedDict(
              x=tff.TensorType(tf.float32, [None, 1]),
              y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)
  metrics_type = tff.FederatedType(
      collections.OrderedDict(
          mean_squared_error=tff.TensorType(tf.float32),
          loss=tff.TensorType(tf.float32)), tff.SERVER)
  # Metrics are reported both before and during training.
  output_type = collections.OrderedDict(
      before_training=metrics_type, during_training=metrics_type)

  expected_result_type = (server_state_type, output_type)
  # Note: `next` here takes a positional (state, data) tuple rather than a
  # named parameter struct.
  expected_type = tff.FunctionType(
      parameter=(server_state_type, dataset_type),
      result=expected_result_type)

  actual_type = iterative_process.next.type_signature
  self.assertTrue(actual_type.is_equivalent_to(expected_type))
def iterative_process_builder(
    model_fn: Callable[[], tff.learning.Model],
    client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
) -> tff.templates.IterativeProcess:
  """Builds a FedAvg process whose aggregation uses importance sampling.

  NOTE(review): `client_weight_fn` is accepted but not used in this body —
  confirm whether it is intentionally ignored.
  """
  sampling_factory = importance_aggregation_factory.ImportanceSamplingFactory(
      FLAGS.clients_per_round)
  weights_type = importance_aggregation_factory.weights_type_from_model_fn(
      model_fn)
  sampling_aggregation_process = sampling_factory.create(
      value_type=weights_type, weight_type=tff.TensorType(tf.float32))

  return importance_schedule.build_fed_avg_process(
      model_fn=model_fn,
      client_optimizer_fn=client_optimizer_fn,
      client_lr=client_lr_schedule,
      server_optimizer_fn=server_optimizer_fn,
      server_lr=server_lr_schedule,
      aggregation_process=sampling_aggregation_process)
async def _compute_reshape_on_tensor(self, tensor, output_shape):
  """Reshapes a federated tensor by invoking a reshape comp on each child.

  Args:
    tensor: A federated value whose `type_signature.member` is a tensor type,
      with one internal representation per child executor.
    output_shape: The target shape passed to the reshape computation.

  Returns:
    A `FederatedResolvingStrategyValue` holding the per-child reshaped
    tensors under the updated federated type.
  """
  tensor_type = tensor.type_signature.member
  # NOTE(review): `shape_type` is computed but not used below — confirm
  # whether it is still needed.
  shape_type = type_conversions.infer_type(output_shape)
  # Builds (or fetches from the per-instance cache) the reshape computation
  # proto specialized to this tensor type and output shape.
  reshaper_proto, reshaper_type = utils.materialize_computation_from_cache(
      paillier_comp.make_reshape_tensor,
      self._reshape_function_cache,
      arg_spec=(tensor_type,),
      output_shape=output_shape)
  tensor_placement = tensor.type_signature.placement
  children = self._get_child_executors(tensor_placement)
  # There must be exactly one internal representation per child executor.
  py_typecheck.check_len(tensor.internal_representation, len(children))
  # Load the reshape function into every child executor concurrently.
  reshaper_fns = await asyncio.gather(*[
      ex.create_value(reshaper_proto, reshaper_type) for ex in children
  ])
  # Apply each child's reshape function to its share of the tensor.
  reshaped_tensors = await asyncio.gather(*[
      ex.create_call(fn, arg) for ex, fn, arg in zip(
          children, reshaper_fns, tensor.internal_representation)
  ])
  output_tensor_spec = tff.FederatedType(
      tff.TensorType(tensor_type.dtype, output_shape), tensor_placement,
      tensor.type_signature.all_equal)
  return federated_resolving_strategy.FederatedResolvingStrategyValue(
      reshaped_tensors, output_tensor_spec)
# Demo script: securely sums length-2 int32 vectors contributed by
# NUM_CLIENTS clients using the Paillier executor stack.
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from federated_aggregations import paillier

NUM_CLIENTS = 5

# Install a Paillier-backed execution context as the default, so the
# federated computation below runs against it.
paillier_factory = paillier.local_paillier_executor_factory(NUM_CLIENTS)
paillier_context = tff.framework.ExecutionContext(paillier_factory)
tff.framework.set_default_context(paillier_context)


@tff.federated_computation(
    tff.FederatedType(tff.TensorType(tf.int32, [2]), tff.CLIENTS),
    tff.TensorType(tf.int32))
def secure_paillier_addition(x, bitwidth):
  # Secure sum bounded by the caller-supplied bitwidth.
  return tff.federated_secure_sum(x, bitwidth)


base = np.array([1, 2], np.int32)
# Client i contributes base + i, for i in [0, NUM_CLIENTS).
x = [base + i for i in range(NUM_CLIENTS)]
result = secure_paillier_addition(x, 32)
print(result)
            dtype=np.float32),
        'y': np.array([source[1][i] for i in batch_samples],
                      dtype=np.int32)})
  # Add Gaussian noise to the pixels, scaled by `num` (ratio in 0x-0.2x).
  ratio = num * 0.05
  sum_agent = int(len(all_samples))
  # NOTE(review): `index` is assigned but never used below — confirm.
  index = 0
  for i in range(0, sum_agent):
    noiseHere = ratio * np.random.randn(28*28)
    output_sequence[int(i/BATCH_SIZE)]['x'][i % BATCH_SIZE] = checkRange(np.add(
        output_sequence[int(i/BATCH_SIZE)]['x'][i % BATCH_SIZE], noiseHere))
  return output_sequence


# A batch: flattened 28x28 pixels plus integer labels.
BATCH_TYPE = tff.NamedTupleType([
    ('x', tff.TensorType(tf.float32, [None, 784])),
    ('y', tff.TensorType(tf.int32, [None]))])

# A single-layer softmax model: dense weights plus bias.
MODEL_TYPE = tff.NamedTupleType([
    ('weights', tff.TensorType(tf.float32, [784, 10])),
    ('bias', tff.TensorType(tf.float32, [10]))])


@tff.tf_computation(MODEL_TYPE, BATCH_TYPE)
def batch_loss(model, batch):
  """Returns the mean softmax cross-entropy loss of `model` on one batch."""
  predicted_y = tf.nn.softmax(tf.matmul(batch.x, model.weights) + model.bias)
  return -tf.reduce_mean(tf.reduce_sum(
      tf.one_hot(batch.y, 10) * tf.log(predicted_y), axis=[1]))


@tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)
def run_one_round(server_state, federated_dataset, ids):
  """Orchestration logic for one round of computation.

  Broadcasts the model, runs client updates, reduces per-client losses and
  weights into dense `[total_clients, 1]` server-side tensors indexed by
  client id, selects clients by loss, and aggregates the selected deltas.

  Args:
    server_state: A `ServerState`.
    federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.
    ids: CLIENTS-placed client identifiers, passed to `client_update_fn` and
      used to pick each client's selection weight.

  Returns:
    A tuple of updated `ServerState` and the result of
    `tff.learning.Model.federated_output_computation`.
  """
  client_model = tff.federated_broadcast(server_state.model)
  client_round_num = tff.federated_broadcast(server_state.round_num)

  client_outputs = tff.federated_map(
      client_update_fn,
      (federated_dataset, client_model, client_round_num, ids))

  client_weight = client_outputs.client_weight
  # NOTE(review): `client_weight` and `client_id` are only read by the
  # commented-out variants below — confirm they are still needed.
  client_id = client_outputs.client_id

  # LOSS SELECTION — earlier federated_collect-based variant, kept for
  # reference:
  # losses_at_server = tff.federated_collect(client_outputs.model_output)
  # weights_at_server = tff.federated_collect(client_weight)

  @computations.tf_computation
  def zeros_fn():
    # Dense accumulator with one row per client, indexed by client id.
    return tf.zeros(shape=[total_clients, 1], dtype=tf.float32)

  zero = zeros_fn()
  at_server_type = tff.TensorType(shape=[total_clients, 1], dtype=tf.float32)
  # list_type = tff.SequenceType( tff.TensorType(dtype=tf.float32))
  client_output_type = client_update_fn.type_signature.result

  @computations.tf_computation(at_server_type, client_output_type)
  def accumulate_weight(u, t):
    # Scatter this client's weight into row `t.client_id` of the accumulator.
    value = t.client_weight
    index = t.client_id
    new_u = tf.tensor_scatter_nd_update(u, index, value)
    return new_u

  @computations.tf_computation(at_server_type, client_output_type)
  def accumulate_loss(u, t):
    # Scatter this client's summed loss into row `t.client_id`.
    value = tf.reshape(tf.math.reduce_sum(t.model_output['loss']),
                       shape=[1, 1])
    index = t.client_id
    new_u = tf.tensor_scatter_nd_update(u, index, value)
    return new_u

  # output_at_server= tff.federated_collect(client_outputs)
  weights_at_server = tff.federated_reduce(client_outputs, zero,
                                           accumulate_weight)
  losses_at_server = tff.federated_reduce(client_outputs, zero,
                                          accumulate_loss)
  # losses_at_server = tff.federated_aggregate(client_outputs.model_output,
  #     zero, accumulate, merge, report)

  # Select clients by loss. Presumably `zero_small_loss_clients` zeroes the
  # weights of low-loss clients, keeping `effective_num_clients` of them —
  # confirm against its definition.
  selected_clients_weights = tff.federated_map(
      zero_small_loss_clients,
      (losses_at_server, weights_at_server,
       server_state.effective_num_clients))

  # selected_clients_weights_at_client = tff.federated_broadcast(selected_clients_weights)
  selected_clients_weights_broadcast = tff.federated_broadcast(
      selected_clients_weights)
  selected_clients_weights_at_client = tff.federated_map(
      select_weight_fn, (selected_clients_weights_broadcast, ids))

  aggregation_output = aggregation_process.next(
      server_state.delta_aggregate_state, client_outputs.weights_delta,
      selected_clients_weights_at_client)

  # model_delta = tff.federated_mean(
  #     client_outputs.weights_delta, weight=client_weight)

  server_state = tff.federated_map(
      server_update_fn, (server_state, aggregation_output.result))

  aggregated_outputs = dummy_model.federated_output_computation(
      client_outputs.model_output)
  if aggregated_outputs.type_signature.is_struct():
    aggregated_outputs = tff.federated_zip(aggregated_outputs)

  return server_state, aggregated_outputs
def build_triehh_process(
    possible_prefix_extensions: List[str],
    num_sub_rounds: int,
    max_num_prefixes: int,
    threshold: int,
    max_user_contribution: int,
    default_terminator: str = triehh_tf.DEFAULT_TERMINATOR):
  """Builds the TFF computations for heavy hitters discovery with TrieHH.

  TrieHH works by interactively keeping track of popular prefixes. In each
  round, the server broadcasts the popular prefixes it has discovered so far
  and the list of `possible_prefix_extensions` to a small fraction of
  selected clients. The selected clients sample `max_user_contribution`
  words from their local datasets, and use them to vote on character
  extensions to the broadcasted popular prefixes. Client votes are
  accumulated across `num_sub_rounds` rounds, and then the top
  `max_num_prefixes` extensions that get at least `threshold` votes are used
  to extend the already discovered prefixes, and the extended prefixes are
  used in the next round. When an already discovered prefix is extended by
  `default_terminator` it is added to the list of discovered heavy hitters.

  Args:
    possible_prefix_extensions: A list containing all the possible extensions
      to learned prefixes. Each extension must be a single character string.
      This list should not contain the default_terminator.
    num_sub_rounds: The total number of sub rounds to be executed before
      decoding aggregated votes. Must be positive.
    max_num_prefixes: The maximum number of prefixes we can keep in the trie.
      Must be positive.
    threshold: The threshold for heavy hitters and discovered prefixes. Only
      those that get at least `threshold` votes are discovered. Must be
      positive.
    max_user_contribution: The maximum number of examples a user can
      contribute. Must be positive.
    default_terminator: The end of sequence symbol.

  Returns:
    A `tff.templates.IterativeProcess`.

  Raises:
    ValueError: If possible_prefix_extensions contains default_terminator.
  """
  if default_terminator in possible_prefix_extensions:
    raise ValueError(
        'default_terminator should not appear in possible_prefix_extensions')

  # Copy the caller's list before appending `default_terminator`: the
  # previous in-place `append` mutated the argument, so repeated calls with
  # the same list accumulated terminators. The terminator must be the last
  # item in the list.
  possible_prefix_extensions = list(possible_prefix_extensions)
  possible_prefix_extensions.append(default_terminator)

  @tff.tf_computation
  def server_init_tf():
    # Initial state: no heavy hitters yet, a single empty prefix, and a
    # zeroed vote accumulator.
    return ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        heavy_hitters_frequencies=tf.constant([], dtype=tf.float64),
        discovered_prefixes=tf.constant([''], dtype=tf.string),
        round_num=tf.constant(0, dtype=tf.int32),
        accumulated_votes=tf.zeros(
            dtype=tf.int32,
            shape=[max_num_prefixes, len(possible_prefix_extensions)]),
        accumulated_weights=tf.constant(0, dtype=tf.int32))

  # We cannot use server_init_tf.type_signature.result because the
  # discovered_* fields need to have [None] shapes, since they will grow over
  # time.
  server_state_type = (
      tff.to_type(
          ServerState(
              discovered_heavy_hitters=tff.TensorType(
                  dtype=tf.string, shape=[None]),
              heavy_hitters_frequencies=tff.TensorType(
                  dtype=tf.float64, shape=[None]),
              discovered_prefixes=tff.TensorType(
                  dtype=tf.string, shape=[None]),
              round_num=tff.TensorType(dtype=tf.int32, shape=[]),
              accumulated_votes=tff.TensorType(
                  dtype=tf.int32,
                  shape=[None, len(possible_prefix_extensions)]),
              accumulated_weights=tff.TensorType(dtype=tf.int32, shape=[]),
          )))

  sub_round_votes_type = tff.TensorType(
      dtype=tf.int32,
      shape=[max_num_prefixes, len(possible_prefix_extensions)])
  sub_round_weight_type = tff.TensorType(dtype=tf.int32, shape=[])

  @tff.tf_computation(server_state_type, sub_round_votes_type,
                      sub_round_weight_type)
  def server_update_fn(server_state, sub_round_votes, sub_round_weight):
    return server_update(
        server_state,
        tf.constant(possible_prefix_extensions),
        sub_round_votes,
        sub_round_weight,
        num_sub_rounds=tf.constant(num_sub_rounds),
        max_num_prefixes=tf.constant(max_num_prefixes),
        threshold=tf.constant(threshold))

  tf_dataset_type = tff.SequenceType(tf.string)
  discovered_prefixes_type = tff.TensorType(dtype=tf.string, shape=[None])
  round_num_type = tff.TensorType(dtype=tf.int32, shape=[])

  @tff.tf_computation(tf_dataset_type, discovered_prefixes_type,
                      round_num_type)
  def client_update_fn(tf_dataset, discovered_prefixes, round_num):
    return client_update(tf_dataset, discovered_prefixes,
                         tf.constant(possible_prefix_extensions), round_num,
                         num_sub_rounds, max_num_prefixes,
                         max_user_contribution)

  federated_server_state_type = tff.FederatedType(server_state_type,
                                                  tff.SERVER)
  federated_dataset_type = tff.FederatedType(
      tf_dataset_type, tff.CLIENTS, all_equal=False)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of TrieHH computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      An updated `ServerState`
    """
    discovered_prefixes = tff.federated_broadcast(
        server_state.discovered_prefixes)
    round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, discovered_prefixes, round_num))

    accumulated_votes = tff.federated_sum(client_outputs.client_votes)
    accumulated_weights = tff.federated_sum(client_outputs.client_weight)

    server_state = tff.federated_map(
        server_update_fn,
        (server_state, accumulated_votes, accumulated_weights))

    server_output = tff.federated_value([], tff.SERVER)

    return server_state, server_output

  return tff.templates.IterativeProcess(
      initialize_fn=tff.federated_computation(
          lambda: tff.federated_eval(server_init_tf, tff.SERVER)),
      next_fn=run_one_round)
def test_build_with_preprocess_funtion(self): test_dataset = tf.data.Dataset.range(5) client_datasets_type = tff.FederatedType( tff.SequenceType(test_dataset.element_spec), tff.CLIENTS) @tff.tf_computation(tff.SequenceType(test_dataset.element_spec)) def preprocess_dataset(ds): def to_batch(x): return collections.OrderedDict(x=[float(x) * 1.0], y=[float(x) * 3.0 + 1.0]) return ds.map(to_batch).repeat().batch(2).take(3) client_lr_callback = callbacks.create_reduce_lr_on_plateau( learning_rate=0.1, min_delta=0.5, window_size=2, decay_factor=1.0, cooldown=0) server_lr_callback = callbacks.create_reduce_lr_on_plateau( learning_rate=0.1, min_delta=0.5, window_size=2, decay_factor=1.0, cooldown=0) iterative_process = adaptive_fed_avg.build_fed_avg_process( _uncompiled_model_builder, client_lr_callback, server_lr_callback, client_optimizer_fn=tf.keras.optimizers.SGD, server_optimizer_fn=tf.keras.optimizers.SGD, dataset_preprocess_comp=preprocess_dataset) lr_callback_type = tff.framework.type_from_tensors(client_lr_callback) server_state_type = tff.FederatedType( adaptive_fed_avg.ServerState(model=tff.learning.ModelWeights( trainable=(tff.TensorType(tf.float32, [1, 1]), tff.TensorType(tf.float32, [1])), non_trainable=()), optimizer_state=[tf.int64], client_lr_callback=lr_callback_type, server_lr_callback=lr_callback_type), tff.SERVER) self.assertEqual( iterative_process.initialize.type_signature, tff.FunctionType(parameter=None, result=server_state_type)) metrics_type = tff.FederatedType( collections.OrderedDict(loss=tff.TensorType(tf.float32)), tff.SERVER) output_type = collections.OrderedDict(before_training=metrics_type, during_training=metrics_type) expected_result_type = (server_state_type, output_type) expected_type = tff.FunctionType(parameter=collections.OrderedDict( server_state=server_state_type, federated_dataset=client_datasets_type), result=expected_result_type) actual_type = iterative_process.next.type_signature self.assertEqual(actual_type, expected_type, 
msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
  def __init__(
      self,
      model_fn,
      m,
      n,
      j_max,
      importance_sampling,
      server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
      client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
  ):
    """Builds the TFF computations for optimization using federated averaging.

    After construction, `self.initialize` is a federated computation producing
    the initial `ServerState`, and `self.next` is a Python function running
    one round of importance-sampled federated averaging.

    Args:
      model_fn: A no-arg function that returns a
        `simple_fedavg_tf.KerasModelWrapper`.
      m: NOTE(review): appears to be the target number of participating
        clients per round (used as `m / n` and to scale sampling
        probabilities) — confirm with callers.
      n: NOTE(review): appears to be the total number of clients — confirm.
      j_max: Maximum number of iterations of the probability-rescaling inner
        loop (only relevant when `importance_sampling` is true).
      importance_sampling: If true, clients are sampled with probability
        proportional to their update norm (capped at 1); otherwise every
        client uses the uniform probability `m / n`.
      server_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer` for server update.
      client_optimizer_fn: A no-arg function that returns a
        `tf.keras.optimizers.Optimizer` for client update.
    """
    dummy_model = model_fn()

    @tff.tf_computation
    def server_init_tf():
      model = model_fn()
      server_optimizer = server_optimizer_fn()
      # Create optimizer variables eagerly so they appear in the state type.
      _initialize_optimizer_vars(model, server_optimizer)
      return ServerState(model_weights=model.weights,
                         optimizer_state=server_optimizer.variables(),
                         round_num=0)

    server_state_type = server_init_tf.type_signature.result
    model_weights_type = server_state_type.model_weights

    @tff.tf_computation(server_state_type, model_weights_type.trainable)
    def server_update_fn(server_state, model_delta):
      model = model_fn()
      server_optimizer = server_optimizer_fn()
      _initialize_optimizer_vars(model, server_optimizer)
      return server_update(model, server_optimizer, server_state, model_delta)

    @tff.tf_computation(server_state_type)
    def server_message_fn(server_state):
      return build_server_broadcast_message(server_state)

    server_message_type = server_message_fn.type_signature.result
    tf_dataset_type = tff.SequenceType(dummy_model.input_spec)

    @tff.tf_computation(tf_dataset_type, server_message_type)
    def client_update_fn(tf_dataset, server_message):
      model = model_fn()
      client_optimizer = client_optimizer_fn()
      return client_update(model, tf_dataset, server_message, client_optimizer)

    federated_server_state_type = tff.FederatedType(
        server_state_type, tff.SERVER)
    federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

    @tff.tf_computation(
        tf.float32,
        tf.float32,
    )
    def scale(update_norm, sum_update_norms):
      # Initial sampling probability: proportional to the client's update
      # norm (capped at 1) under importance sampling, uniform m/n otherwise.
      if importance_sampling:
        return tf.minimum(
            1., tf.divide(tf.multiply(update_norm, m), sum_update_norms))
      else:
        return tf.divide(m, n)

    @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS),
                               tff.FederatedType(tf.float32, tff.CLIENTS,
                                                 True))
    def scale_on_clients(update_norm, sum_update_norms):
      return tff.federated_map(scale, (update_norm, sum_update_norms))

    @tff.tf_computation(tf.float32)
    def create_prob_message(prob):
      # Clients whose probability is still < 1 report (prob, 1); saturated
      # clients report (0, 0) so they are excluded from rescaling.

      def f1():
        return tf.stack([prob, 1.])

      def f2():
        return tf.constant([0., 0.])

      prob_message = tf.cond(tf.less(prob, 1), f1, f2)
      return prob_message

    @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
    def create_prob_message_on_clients(prob):
      return tff.federated_map(create_prob_message, prob)

    @tff.tf_computation(tff.TensorType(tf.float32, (2, )))
    def compute_rescaling(prob_aggreg):
      # prob_aggreg = (sum of unsaturated probs, count of unsaturated
      # clients); redistribute the remaining budget m - n + count over them.
      rescaling_factor = (m - n + prob_aggreg[1]) / prob_aggreg[0]
      return rescaling_factor

    @tff.federated_computation(
        tff.FederatedType(tff.TensorType(tf.float32, (2, )), tff.SERVER))
    def compute_rescaling_on_master(prob_aggreg):
      return tff.federated_map(compute_rescaling, prob_aggreg)

    @tff.tf_computation(tf.float32, tf.float32)
    def rescale_prob(prob, rescaling_factor):
      return tf.minimum(1., tf.multiply(prob, rescaling_factor))

    @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS),
                               tff.FederatedType(tf.float32, tff.CLIENTS,
                                                 True))
    def rescale_prob_on_clients(rob, rescaling_factor):
      # NOTE(review): parameter name `rob` is presumably a typo for `prob`.
      return tff.federated_map(rescale_prob, (rob, rescaling_factor))

    @tff.tf_computation(tf.float32)
    def compute_weights_is_fn(prob):
      # Bernoulli sampling with inverse-probability (importance) weighting:
      # selected clients get weight 1/prob, others 0.

      def f1():
        return 1. / prob

      def f2():
        return 0.

      weight = tf.cond(tf.less(tf.random.uniform(()), prob), f1, f2)
      return weight

    @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
    def compute_weights_is(prob):
      return tff.federated_map(compute_weights_is_fn, prob)

    @tff.federated_computation(
        tff.FederatedType(model_weights_type.trainable, tff.CLIENTS),
        tff.FederatedType(tf.float32, tff.CLIENTS))
    def compute_round_model_delta(weights_delta, weights_denom):
      return tff.federated_mean(weights_delta, weight=weights_denom)

    @tff.federated_computation(federated_server_state_type,
                               tff.FederatedType(model_weights_type.trainable,
                                                 tff.SERVER))
    def update_server_state(server_state, round_model_delta):
      return tff.federated_map(server_update_fn,
                               (server_state, round_model_delta))

    @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS),
                               tff.FederatedType(tf.float32, tff.CLIENTS))
    def compute_loss_metric(model_output, weight_denom):
      return tff.federated_mean(model_output, weight=weight_denom)

    @tff.tf_computation(model_weights_type.trainable, tf.float32)
    def rescale_and_remove_fn(weights_delta, weights_is):
      # Multiplying by the importance weight also zeroes out deltas from
      # clients that were not selected (weight 0).
      return [
          tf.math.scalar_mul(weights_is, weights_layer_delta)
          for weights_layer_delta in weights_delta
      ]

    @tff.federated_computation(
        tff.FederatedType(model_weights_type.trainable, tff.CLIENTS),
        tff.FederatedType(tf.float32, tff.CLIENTS))
    def rescale_and_remove(weights_delta, weights_is):
      return tff.federated_map(rescale_and_remove_fn,
                               (weights_delta, weights_is))

    @tff.federated_computation(federated_server_state_type,
                               federated_dataset_type)
    def run_gradient_computation_round(server_state, federated_dataset):
      """Orchestration logic for one round of gradient computation.

      Args:
        server_state: A `ServerState`.
        federated_dataset: A federated `tf.data.Dataset` with placement
          `tff.CLIENTS`.

      Returns:
        A tuple of updated `tf.Tensor` of clients initial probability and
        `ClientOutput`.
      """
      server_message = tff.federated_map(server_message_fn, server_state)
      server_message_at_client = tff.federated_broadcast(server_message)

      client_outputs = tff.federated_map(
          client_update_fn, (federated_dataset, server_message_at_client))

      # Total (weighted) update norm, broadcast back so each client can
      # compute its own sampling probability.
      update_norm_sum_weighted = tff.federated_sum(
          client_outputs.update_norm_weighted)
      norm_sum_clients_weighted = tff.federated_broadcast(
          update_norm_sum_weighted)
      prob_init = scale_on_clients(client_outputs.update_norm_weighted,
                                   norm_sum_clients_weighted)

      return prob_init, client_outputs

    @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
    def run_one_inner_loop_weights_computation(prob):
      """Orchestration logic for one round of computation.

      Args:
        prob: Probability of each client to communicate update.

      Returns:
        A tuple of updated `Probabilities` and `tf.float32` of rescaling
        factor.
      """
      prob_message = create_prob_message_on_clients(prob)
      prob_aggreg = tff.federated_sum(prob_message)
      rescaling_factor_master = compute_rescaling_on_master(prob_aggreg)
      rescaling_factor_clients = tff.federated_broadcast(
          rescaling_factor_master)
      prob = rescale_prob_on_clients(prob, rescaling_factor_clients)
      return prob, rescaling_factor_master

    @tff.federated_computation
    def server_init_tff():
      """Orchestration logic for server model initialization."""
      return tff.federated_value(server_init_tf(), tff.SERVER)

    def run_one_round(server_state, federated_dataset):
      """Orchestration logic for one round of computation.

      Args:
        server_state: A `ServerState`.
        federated_dataset: A federated `tf.data.Dataset` with placement
          `tff.CLIENTS`.

      Returns:
        A tuple of updated `ServerState`, the weighted average loss, and a
        list of the final per-client sampling probabilities as numpy values.
      """
      prob, client_outputs = run_gradient_computation_round(
          server_state, federated_dataset)

      if importance_sampling:
        # Iteratively rescale probabilities until the budget is respected or
        # `j_max` iterations elapse. This is Python-level (eager) control
        # flow over invoked federated computations.
        for j in range(j_max):
          prob, rescaling_factor = run_one_inner_loop_weights_computation(
              prob)
          if rescaling_factor <= 1:
            break

      weight_denom = [
          client_output.client_weight for client_output in client_outputs
      ]
      weights_delta = [
          client_output.weights_delta for client_output in client_outputs
      ]

      # rescale weights based on sampling procedure
      weights_is = compute_weights_is(prob)
      weights_delta = rescale_and_remove(weights_delta, weights_is)

      round_model_delta = compute_round_model_delta(weights_delta,
                                                    weight_denom)
      server_state = update_server_state(server_state, round_model_delta)

      model_output = [
          client_output.model_output for client_output in client_outputs
      ]
      round_loss_metric = compute_loss_metric(model_output, weight_denom)

      prob_numpy = []
      for p in prob:
        prob_numpy.append(p.numpy())

      return server_state, round_loss_metric, prob_numpy

    self.next = run_one_round
    self.initialize = server_init_tff
def create_train_dataset_preprocess_fn(vocab: List[str],
                                       client_batch_size: int,
                                       client_epochs_per_round: int,
                                       max_seq_len: int,
                                       max_training_elements_per_user: int,
                                       max_batches_per_user: int = -1,
                                       max_shuffle_buffer_size: int = 10000):
  """Creates a preprocessing function for stackoverflow training data.

  This function returns a `tff.tf_computation` which takes a dataset and
  returns a dataset, generally for mapping over a set of unprocessed client
  datasets during training.

  Args:
    vocab: Vocabulary which defines the embedding.
    client_batch_size: Integer representing batch size to use on the clients.
    client_epochs_per_round: Number of epochs for which to repeat train client
      dataset. If -1, `max_batches_per_user` must be a positive integer.
    max_seq_len: Integer determining shape of padded batches. Sequences will
      be padded up to this length, and sentences longer than `max_seq_len`
      will be truncated to this length.
    max_training_elements_per_user: Integer controlling the maximum number of
      elements to take per user. If -1, takes all elements for each user.
    max_batches_per_user: If set to a positive integer, the maximum number of
      batches in each client's dataset.
    max_shuffle_buffer_size: Maximum shuffle buffer size.

  Returns:
    A single `tff.tf_computation` `preprocess_train` that maps a client's raw
    example dataset to a shuffled, repeated, tokenized and batched dataset.

  Raises:
    ValueError: If any of the argument validity checks below fails.
  """
  if client_batch_size <= 0:
    raise ValueError(
        'client_batch_size must be a positive integer; you have '
        'passed {}'.format(client_batch_size))
  elif client_epochs_per_round == -1 and max_batches_per_user == -1:
    raise ValueError(
        'Argument client_epochs_per_round is set to -1. If this is'
        ' intended, then max_batches_per_user must be set to '
        'some positive integer.')
  elif max_seq_len <= 0:
    raise ValueError('max_seq_len must be a positive integer; you have '
                     'passed {}'.format(max_seq_len))
  elif max_training_elements_per_user < -1:
    raise ValueError(
        'max_training_elements_per_user must be an integer at '
        'least -1; you have passed {}'.format(
            max_training_elements_per_user))

  # The shuffle buffer never needs to exceed the number of elements actually
  # taken per user.
  if (max_training_elements_per_user == -1 or
      max_training_elements_per_user > max_shuffle_buffer_size):
    shuffle_buffer_size = max_shuffle_buffer_size
  else:
    shuffle_buffer_size = max_training_elements_per_user

  # Schema of the raw stackoverflow examples (all scalar features).
  feature_dtypes = [
      ('creation_date', tf.string),
      ('title', tf.string),
      ('score', tf.int64),
      ('tags', tf.string),
      ('tokens', tf.string),
      ('type', tf.string),
  ]

  @tff.tf_computation(
      tff.SequenceType(
          tff.NamedTupleType([(name, tff.TensorType(dtype=dtype, shape=()))
                              for name, dtype in feature_dtypes])))
  def preprocess_train(dataset):
    to_ids = build_to_ids_fn(vocab, max_seq_len)
    dataset = dataset.take(max_training_elements_per_user)
    if shuffle_buffer_size > 0:
      logging.info('Adding shuffle with buffer size: %d', shuffle_buffer_size)
      dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.repeat(client_epochs_per_round)
    dataset = dataset.map(to_ids,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = batch_and_split(dataset, max_seq_len, client_batch_size)
    # take(-1) is a no-op, so this only truncates when max_batches_per_user
    # is a positive integer.
    return dataset.take(max_batches_per_user)

  return preprocess_train
# all_samples = [i for i in range(int(num*(len(source[1])/NUM_AGENT)), int((num+1)*(len(source[1])/NUM_AGENT)))] for i in range(0, len(all_samples), BATCH_SIZE): batch_samples = all_samples[i:i + BATCH_SIZE] output_sequence.append({ 'x': np.array([source[0][i].flatten() / 255.0 for i in batch_samples], dtype=np.float32), 'y': np.array([source[1][i] for i in batch_samples], dtype=np.int32) }) return output_sequence BATCH_TYPE = tff.NamedTupleType([('x', tff.TensorType(tf.float32, [None, 784])), ('y', tff.TensorType(tf.int32, [None]))]) MODEL_TYPE = tff.NamedTupleType([('weights', tff.TensorType(tf.float32, [784, 10])), ('bias', tff.TensorType(tf.float32, [10]))]) @tff.tf_computation(MODEL_TYPE, BATCH_TYPE) def batch_loss(model, batch): predicted_y = tf.nn.softmax(tf.matmul(batch.x, model.weights) + model.bias) return -tf.reduce_mean( tf.reduce_sum(tf.one_hot(batch.y, 10) * tf.log(predicted_y), axis=[1])) @tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)
def build_triehh_process(possible_prefix_extensions: List[str],
                         num_sub_rounds: int,
                         max_num_heavy_hitters: int,
                         max_user_contribution: int,
                         default_terminator: str = '$'):
  """Builds the TFF computations for heavy hitters discovery with TrieHH.

  TrieHH works by interactively keeping track of popular prefixes. In each
  round, the server broadcasts the popular prefixes it has discovered so far
  and the list of `possible_prefix_extensions` to a small fraction of
  selected clients. The select clients sample `max_user_contributions` words
  from their local datasets, and use them to vote on character extensions to
  the broadcasted popular prefixes. Client votes are accumulated across
  `num_sub_rounds` rounds, and then the top `max_num_heavy_hitters`
  extensions are used to extend the already discovered prefixes, and the
  extended prefixes are used in the next round. When an already discovered
  prefix is extended by `default_terminator` it is added to the list of
  discovered heavy hitters.

  Args:
    possible_prefix_extensions: A list containing all the possible extensions
      to learned prefixes. Each extension must be a single-character string.
    num_sub_rounds: The total number of sub rounds to be executed before
      decoding aggregated votes. Must be positive.
    max_num_heavy_hitters: The maximum number of discoverable heavy hitters.
      Must be positive.
    max_user_contribution: The maximum number of examples a user can
      contribute. Must be positive.
    default_terminator: The end of sequence symbol.

  Returns:
    A `tff.utils.IterativeProcess`.
  """

  @tff.tf_computation
  def server_init_tf():
    # Initial state: nothing discovered, one empty-string prefix to extend,
    # the fixed extension alphabet, and a zeroed vote accumulator.
    return ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        discovered_prefixes=tf.constant([''], dtype=tf.string),
        possible_prefix_extensions=tf.constant(
            possible_prefix_extensions, dtype=tf.string),
        round_num=tf.constant(0, dtype=tf.int32),
        accumulated_votes=tf.zeros(
            dtype=tf.int32,
            shape=[max_num_heavy_hitters, len(possible_prefix_extensions)]))

  # We cannot use server_init_tf.type_signature.result because the
  # discovered_* fields need to have [None] shapes, since they will grow over
  # time.
  server_state_type = (
      tff.to_type(
          ServerState(
              discovered_heavy_hitters=tff.TensorType(
                  dtype=tf.string, shape=[None]),
              discovered_prefixes=tff.TensorType(dtype=tf.string,
                                                 shape=[None]),
              possible_prefix_extensions=tff.TensorType(
                  dtype=tf.string, shape=[len(possible_prefix_extensions)]),
              round_num=tff.TensorType(dtype=tf.int32, shape=[]),
              accumulated_votes=tff.TensorType(
                  dtype=tf.int32,
                  shape=[None, len(possible_prefix_extensions)]),
          )))

  # Per-sub-round vote matrix: one row per candidate, one column per
  # extension character.
  sub_round_votes_type = tff.TensorType(
      dtype=tf.int32,
      shape=[max_num_heavy_hitters, len(possible_prefix_extensions)])

  @tff.tf_computation(server_state_type, sub_round_votes_type)
  @tf.function
  def server_update_fn(server_state, sub_round_votes):
    server_state = server_update(
        server_state,
        sub_round_votes,
        num_sub_rounds=tf.constant(num_sub_rounds),
        max_num_heavy_hitters=tf.constant(max_num_heavy_hitters),
        default_terminator=tf.constant(default_terminator, dtype=tf.string))
    return server_state

  tf_dataset_type = tff.SequenceType(tf.string)
  discovered_prefixes_type = tff.TensorType(dtype=tf.string, shape=[None])
  round_num_type = tff.TensorType(dtype=tf.int32, shape=[])

  @tff.tf_computation(tf_dataset_type, discovered_prefixes_type,
                      round_num_type)
  @tf.function
  def client_update_fn(tf_dataset, discovered_prefixes, round_num):
    result = client_update(tf_dataset, discovered_prefixes,
                           tf.constant(possible_prefix_extensions), round_num,
                           num_sub_rounds, max_num_heavy_hitters,
                           max_user_contribution)
    return result

  federated_server_state_type = tff.FederatedType(server_state_type,
                                                  tff.SERVER)
  federated_dataset_type = tff.FederatedType(
      tf_dataset_type, tff.CLIENTS, all_equal=False)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type)
  def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of TrieHH computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement
        `tff.CLIENTS`.

    Returns:
      An updated `ServerState`
    """
    discovered_prefixes = tff.federated_broadcast(
        server_state.discovered_prefixes)
    round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn,
        tff.federated_zip([federated_dataset, discovered_prefixes,
                           round_num]))

    accumulated_votes = tff.federated_sum(client_outputs.client_votes)

    server_state = tff.federated_map(server_update_fn,
                                     (server_state, accumulated_votes))

    # No per-round metrics are produced; emit an empty server output.
    server_output = tff.federated_value([], tff.SERVER)

    return server_state, server_output

  return tff.utils.IterativeProcess(
      initialize_fn=tff.federated_computation(
          lambda: tff.federated_value(server_init_tf(), tff.SERVER)),
      next_fn=run_one_round)
def build_federated_averaging_process_attacked(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    model_update_aggregation_factory=None,
    # NOTE(review): a mutable default instance is shared across calls —
    # confirm `ClientExplicitBoosting` is stateless across rounds.
    client_update_tf=ClientExplicitBoosting(boost_factor=1.0)):
  """Builds the TFF computations for optimization using federated averaging
  with potentially malicious clients.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`, used during local client training.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`, used to apply updates to the global
      model.
    model_update_aggregation_factory: An optional
      `tff.aggregators.AggregationFactory` that constructs
      `tff.templates.AggregationProcess` for aggregating the client model
      updates on the server. If `None`, uses a default constructed
      `tff.aggregators.MeanFactory`, creating a stateless mean aggregation.
    client_update_tf: a 'tf.function' that computes the ClientOutput.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  # Build the throwaway model in its own graph so its variables do not leak
  # into the default graph.
  with tf.Graph().as_default():
    dummy_model_for_metadata = model_fn()
    weights_type = tff.learning.framework.weights_type_from_model(
        dummy_model_for_metadata)

  if model_update_aggregation_factory is None:
    model_update_aggregation_factory = tff.aggregators.MeanFactory()

  # Weighted factories receive a float32 weight type; unweighted factories
  # take only the value type.
  if isinstance(model_update_aggregation_factory,
                tff.aggregators.WeightedAggregationFactory):
    aggregation_process = model_update_aggregation_factory.create(
        weights_type.trainable, tff.TensorType(tf.float32))
  else:
    aggregation_process = model_update_aggregation_factory.create(
        weights_type.trainable)

  server_init = build_server_init_fn(model_fn, server_optimizer_fn,
                                     aggregation_process.initialize)
  server_state_type = server_init.type_signature.result.member
  server_update_fn = build_server_update_fn(model_fn, server_optimizer_fn,
                                            server_state_type,
                                            server_state_type.model)

  tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)
  client_update_fn = build_client_update_fn(model_fn, client_optimizer_fn,
                                            client_update_tf,
                                            tf_dataset_type,
                                            server_state_type.model)

  federated_server_state_type = tff.type_at_server(server_state_type)
  federated_dataset_type = tff.type_at_clients(tf_dataset_type)
  run_one_round_tff = build_run_one_round_fn_attacked(
      server_update_fn, client_update_fn, aggregation_process,
      dummy_model_for_metadata, federated_server_state_type,
      federated_dataset_type)

  return tff.templates.IterativeProcess(initialize_fn=server_init,
                                        next_fn=run_one_round_tff)
beingm, this test is here to ensure that we don't break JAX support as we evolve it to be more complete and feature-full. """ import collections import itertools from absl.testing import absltest import jax import numpy as np import tensorflow_federated as tff from tensorflow_federated.python.tests import jax_components BATCH_TYPE = collections.OrderedDict([ ('pixels', tff.TensorType(np.float32, (50, 784))), ('labels', tff.TensorType(np.int32, (50, ))) ]) MODEL_TYPE = collections.OrderedDict([ ('weights', tff.TensorType(np.float32, (784, 10))), ('bias', tff.TensorType(np.float32, (10, ))) ]) def loss(model, batch): y = jax.nn.softmax( jax.numpy.add(jax.numpy.matmul(batch['pixels'], model['weights']), model['bias'])) targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10) return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))
def build_fed_avg_process(
    total_clients: int,
    effective_num_clients: int,
    model_fn: ModelBuilder,
    client_optimizer_fn: OptimizerBuilder,
    client_lr: Union[float, LRScheduleFn] = 0.1,
    server_optimizer_fn: OptimizerBuilder = tf.keras.optimizers.SGD,
    server_lr: Union[float, LRScheduleFn] = 1.0,
    client_weight_fn: Optional[ClientWeightFn] = None,
    aggregation_process: Optional[measured_process.MeasuredProcess] = None,
) -> tff.templates.IterativeProcess:
  """Builds the TFF computations for optimization using federated averaging.

  Args:
    total_clients: Total number of clients; fixes the size of the per-client
      loss/weight tensors accumulated at the server.
    effective_num_clients: Number of clients whose updates are kept each
      round (passed to `redefine_client_weight` to zero out the rest).
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A function that accepts a `learning_rate` keyword
      argument and returns a `tf.keras.optimizers.Optimizer` instance.
    client_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    server_optimizer_fn: A function that accepts a `learning_rate` argument
      and returns a `tf.keras.optimizers.Optimizer` instance.
    server_lr: A scalar learning rate or a function that accepts a float
      `round_num` argument and returns a learning rate.
    client_weight_fn: Optional function that takes the output of
      `model.report_local_outputs` and returns a tensor that provides the
      weight in the federated average of model deltas. If not provided, the
      default is the total number of examples processed on device.
    aggregation_process: Optional `MeasuredProcess` used to aggregate client
      model deltas; defaults to a stateless weighted mean.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
  # Normalize scalar learning rates into schedule functions of round_num.
  client_lr_schedule = client_lr
  if not callable(client_lr_schedule):
    client_lr_schedule = lambda round_num: client_lr

  server_lr_schedule = server_lr
  if not callable(server_lr_schedule):
    server_lr_schedule = lambda round_num: server_lr

  # Build a throwaway model/optimizer in a separate graph purely to extract
  # type metadata.
  with tf.Graph().as_default():
    dummy_model = model_fn()
    model_weights_type = model_utils.weights_type_from_model(dummy_model)
    dummy_optimizer = server_optimizer_fn()
    _initialize_optimizer_vars(dummy_model, dummy_optimizer)
    optimizer_variable_type = type_conversions.type_from_tensors(
        dummy_optimizer.variables())

  if aggregation_process is None:
    aggregation_process = build_stateless_mean(
        model_delta_type=model_weights_type.trainable)
  if not _is_valid_aggregation_process(aggregation_process):
    raise ProcessTypeError(
        'aggregation_process type signature does not conform to expected '
        'signature (<state@S, input@C> -> <state@S, result@S, measurements@S>).'
        ' Got: {t}'.format(t=aggregation_process.next.type_signature))

  initialize_computation = build_server_init_fn(
      model_fn=model_fn,
      effective_num_clients=effective_num_clients,
      # Initialize with the learning rate for round zero.
      server_optimizer_fn=lambda: server_optimizer_fn(server_lr_schedule(0)),
      aggregation_process=aggregation_process)

  # server_state_type = initialize_computation.type_signature.result
  # model_weights_type = server_state_type.model

  round_num_type = tf.float32
  tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
  model_input_type = tff.SequenceType(dummy_model.input_spec)

  # Dense per-client accumulators indexed by client id (hence the fixed
  # [total_clients, 1] shape).
  client_losses_at_server_type = tff.TensorType(dtype=tf.float32,
                                                shape=[total_clients, 1])
  clients_weights_at_server_type = tff.TensorType(dtype=tf.float32,
                                                  shape=[total_clients, 1])

  aggregation_state = (
      aggregation_process.initialize.type_signature.result.member)

  server_state_type = ServerState(
      model=model_weights_type,
      optimizer_state=optimizer_variable_type,
      round_num=round_num_type,
      effective_num_clients=tf.int32,
      delta_aggregate_state=aggregation_state,
  )

  # @computations.tf_computation(clients_weights_type)
  # def get_zero_weights_all_clients(weights):
  #   return tf.zeros_like(weights, dtype=tf.float32)

  ######################################################
  # def federated_output(local_outputs):
  #   return federated_aggregate_keras_metric(self.get_metrics(), local_outputs)
  # federated_output_computation = computations.federated_computation(
  #   federated_output, federated_local_outputs_type)

  single_id_type = tff.TensorType(dtype=tf.int32, shape=[1, 1])

  @tff.tf_computation(model_input_type, model_weights_type, round_num_type,
                      single_id_type)
  def client_update_fn(tf_dataset, initial_model_weights, round_num,
                       client_id):
    client_lr = client_lr_schedule(round_num)
    client_optimizer = client_optimizer_fn(client_lr)
    client_update = create_client_update_fn()
    return client_update(model_fn(), tf_dataset, initial_model_weights,
                         client_optimizer, client_id, client_weight_fn)

  @tff.tf_computation(server_state_type, model_weights_type.trainable)
  def server_update_fn(server_state, model_delta):
    model = model_fn()
    server_lr = server_lr_schedule(server_state.round_num)
    server_optimizer = server_optimizer_fn(server_lr)
    # We initialize the server optimizer variables to avoid creating them
    # within the scope of the tf.function server_update.
    _initialize_optimizer_vars(model, server_optimizer)
    return server_update(model, server_optimizer, server_state, model_delta)

  id_type = tff.TensorType(shape=[1, 1], dtype=tf.int32)

  @tff.tf_computation(clients_weights_at_server_type, id_type)
  def select_weight_fn(clients_weights, local_id):
    return select_weight(clients_weights, local_id)

  @tff.tf_computation(client_losses_at_server_type,
                      clients_weights_at_server_type, tf.int32)
  def zero_small_loss_clients(losses_at_server, weights_at_server,
                              effective_num_clients):
    """Zeroes out the weights of all but the highest-loss clients.

    Args:
      losses_at_server: A `[total_clients, 1]` tensor of per-client losses.
      weights_at_server: A `[total_clients, 1]` tensor of per-client weights.
      effective_num_clients: Number of clients whose weights are kept.

    Returns:
      The per-client weights with non-selected clients zeroed out (via
      `redefine_client_weight`).
    """
    return redefine_client_weight(losses_at_server, weights_at_server,
                                  effective_num_clients)

  # @tff.tf_computation(client_losses_type)
  # def dataset_to_tensor_fn(dataset):
  #   return dataset_to_tensor(dataset)

  @tff.federated_computation(tff.FederatedType(server_state_type, tff.SERVER),
                             tff.FederatedType(tf_dataset_type, tff.CLIENTS),
                             tff.FederatedType(id_type, tff.CLIENTS))
  def run_one_round(server_state, federated_dataset, ids):
    """Orchestration logic for one round of computation.

    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement
        `tff.CLIENTS`.
      ids: A federated `[1, 1]` int32 client identifier with placement
        `tff.CLIENTS`, used to index the dense server-side accumulators.

    Returns:
      A tuple of updated `ServerState` and the result of
      `tff.learning.Model.federated_output_computation`.
    """
    client_model = tff.federated_broadcast(server_state.model)
    client_round_num = tff.federated_broadcast(server_state.round_num)

    client_outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, client_model, client_round_num, ids))

    client_weight = client_outputs.client_weight
    client_id = client_outputs.client_id

    #LOSS SELECTION:
    # losses_at_server = tff.federated_collect(client_outputs.model_output)
    # weights_at_server = tff.federated_collect(client_weight)

    @computations.tf_computation
    def zeros_fn():
      return tf.zeros(shape=[total_clients, 1], dtype=tf.float32)

    zero = zeros_fn()
    at_server_type = tff.TensorType(shape=[total_clients, 1],
                                    dtype=tf.float32)
    # list_type = tff.SequenceType( tff.TensorType(dtype=tf.float32))
    client_output_type = client_update_fn.type_signature.result

    # Reduction: scatter each client's weight into its id's row.
    @computations.tf_computation(at_server_type, client_output_type)
    def accumulate_weight(u, t):
      value = t.client_weight
      index = t.client_id
      new_u = tf.tensor_scatter_nd_update(u, index, value)
      return new_u

    # Reduction: scatter each client's summed loss into its id's row.
    @computations.tf_computation(at_server_type, client_output_type)
    def accumulate_loss(u, t):
      value = tf.reshape(tf.math.reduce_sum(t.model_output['loss']),
                         shape=[1, 1])
      index = t.client_id
      new_u = tf.tensor_scatter_nd_update(u, index, value)
      return new_u

    # output_at_server= tff.federated_collect(client_outputs)
    # NOTE(review): `tff.federated_reduce` is a deprecated API in newer TFF
    # releases (`tff.federated_aggregate` is the replacement) — confirm the
    # pinned TFF version still provides it.
    weights_at_server = tff.federated_reduce(client_outputs, zero,
                                             accumulate_weight)
    losses_at_server = tff.federated_reduce(client_outputs, zero,
                                            accumulate_loss)
    #losses_at_server = tff.federated_aggregate(client_outputs.model_output, zero, accumulate, merge, report)

    selected_clients_weights = tff.federated_map(
        zero_small_loss_clients,
        (losses_at_server, weights_at_server,
         server_state.effective_num_clients))

    # selected_clients_weights_at_client = tff.federated_broadcast(selected_clients_weights)
    selected_clients_weights_broadcast = tff.federated_broadcast(
        selected_clients_weights)
    # Each client picks its own (possibly zeroed) weight out of the dense
    # broadcast tensor.
    selected_clients_weights_at_client = tff.federated_map(
        select_weight_fn, (selected_clients_weights_broadcast, ids))

    aggregation_output = aggregation_process.next(
        server_state.delta_aggregate_state, client_outputs.weights_delta,
        selected_clients_weights_at_client)

    # model_delta = tff.federated_mean(
    #     client_outputs.weights_delta, weight=client_weight)

    server_state = tff.federated_map(
        server_update_fn, (server_state, aggregation_output.result))

    aggregated_outputs = dummy_model.federated_output_computation(
        client_outputs.model_output)
    if aggregated_outputs.type_signature.is_struct():
      aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs

  # @tff.federated_computation
  # def initialize_fn():
  #   return tff.federated_value(server_init_tf(), tff.SERVER)

  return tff.templates.IterativeProcess(initialize_fn=initialize_computation,
                                        next_fn=run_one_round)