def _build_test_map_reduce_form_with_computations(initialize=None,
                                                  prepare=None,
                                                  work=None,
                                                  zero=None,
                                                  accumulate=None,
                                                  merge=None,
                                                  report=None,
                                                  bitwidth=None,
                                                  max_input=None,
                                                  modulus=None,
                                                  update=None):
  """Builds a test `forms.MapReduceForm`, overriding selected computations.

  Each argument left as `None` is replaced by the same-positioned default
  returned from `_test_map_reduce_form_computations()`.

  Args:
    initialize: Optional override for the default `initialize` computation.
    prepare: Optional override for the default `prepare` computation.
    work: Optional override for the default `work` computation.
    zero: Optional override for the default `zero` computation.
    accumulate: Optional override for the default `accumulate` computation.
    merge: Optional override for the default `merge` computation.
    report: Optional override for the default `report` computation.
    bitwidth: Optional override for the default `bitwidth` computation.
    max_input: Optional override for the default `max_input` computation.
    modulus: Optional override for the default `modulus` computation.
    update: Optional override for the default `update` computation.

  Returns:
    The `forms.MapReduceForm` built from the eleven selected computations.
  """
  (test_initialize, test_prepare, test_work, test_zero, test_accumulate,
   test_merge, test_report, test_bitwidth, test_max_input, test_modulus,
   test_update) = _test_map_reduce_form_computations()

  def _pick(override, default):
    # Compare against `None` explicitly: a truthiness test would silently
    # discard an override that happens to evaluate as falsy.
    return default if override is None else override

  return forms.MapReduceForm(
      _pick(initialize, test_initialize),
      _pick(prepare, test_prepare),
      _pick(work, test_work),
      _pick(zero, test_zero),
      _pick(accumulate, test_accumulate),
      _pick(merge, test_merge),
      _pick(report, test_report),
      _pick(bitwidth, test_bitwidth),
      _pick(max_input, test_max_input),
      _pick(modulus, test_modulus),
      _pick(update, test_update),
  )
def get_map_reduce_form_for_iterative_process(
    ip: iterative_process.IterativeProcess,
    grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG
) -> forms.MapReduceForm:
  """Converts an iterative process into a `MapReduceForm`.

  Args:
    ip: A `tff.templates.IterativeProcess` compatible with MapReduce form.
      Compatibility requires all of the following: `initialize_fn` returns a
      single federated value placed at `SERVER`; `next` takes exactly two
      arguments, the first being the state value placed at `SERVER`; and
      `next` returns exactly two values.
    grappler_config: Optional `tf.compat.v1.ConfigProto` controlling Grappler
      optimization of the TensorFlow graphs backing the resulting form. The
      options are combined with a set of defaults that aggressively configure
      Grappler. When
      `graph_options.rewrite_options.disable_meta_optimizer=True` is set on
      the input, Grappler is bypassed.

  Returns:
    A `tff.backends.mapreduce.MapReduceForm` equivalent to the given
    `tff.templates.IterativeProcess`.

  Raises:
    TypeError: If the arguments are of the wrong types.
    transformations.MapReduceFormCompilationError: If the compilation process
      fails.
  """
  # Validate the inputs and obtain the building blocks to compile.
  py_typecheck.check_type(ip, iterative_process.IterativeProcess)
  compatible_bbs = check_iterative_process_compatible_with_map_reduce_form(ip)
  initialize_bb, next_bb = compatible_bbs
  py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto)
  grappler_config = _merge_grappler_config_with_default(grappler_config)

  # Split `next` around its broadcast and then around its aggregate.
  next_bb, _ = tree_transformations.uniquify_reference_names(next_bb)
  before_broadcast, after_broadcast = _split_ast_on_broadcast(next_bb)
  before_aggregate, after_aggregate = _split_ast_on_aggregate(after_broadcast)

  # Extract each piece of local processing, in MapReduceForm argument order.
  initialize = transformations.consolidate_and_extract_local_processing(
      initialize_bb, grappler_config)
  prepare = _extract_prepare(before_broadcast, grappler_config)
  work = _extract_work(before_aggregate, grappler_config)
  aggregate_fns = _extract_federated_aggregate_functions(
      before_aggregate, grappler_config)
  zero, accumulate, merge, report = aggregate_fns
  bitwidth = _extract_federated_secure_sum_bitwidth_functions(
      before_aggregate, grappler_config)
  update = _extract_update(after_aggregate, grappler_config)

  # The labels are taken from the parameter names of the original `next`.
  server_state_label, client_data_label = structure.name_list_with_nones(
      ip.next.type_signature.parameter)

  extracted_bbs = [
      initialize, prepare, work, zero, accumulate, merge, report, bitwidth,
      update
  ]
  computations_for_form = [
      computation_wrapper_instances.building_block_to_computation(bb)
      for bb in extracted_bbs
  ]
  return forms.MapReduceForm(
      *computations_for_form,
      server_state_label=server_state_label,
      client_data_label=client_data_label)
def test_init_does_not_raise_type_error(self):
  """Constructing a form from the default test computations must not raise."""
  (initialize_comp, prepare_comp, work_comp, zero_comp, accumulate_comp,
   merge_comp, report_comp, bitwidth_comp,
   update_comp) = _test_map_reduce_form_computations()
  form_args = (initialize_comp, prepare_comp, work_comp, zero_comp,
               accumulate_comp, merge_comp, report_comp, bitwidth_comp,
               update_comp)
  try:
    forms.MapReduceForm(*form_args)
  except TypeError:
    self.fail('Raised TypeError unexpectedly.')
def test_init_raises_type_error_with_bad_zero_result_type(self):
  """Swaps in a `zero` returning a float/int pair and expects `TypeError`."""
  (initialize_comp, prepare_comp, work_comp, _, accumulate_comp, merge_comp,
   report_comp, bitwidth_comp,
   update_comp) = _test_map_reduce_form_computations()

  @computations.tf_computation
  def zero():
    return tf.constant(0.0), tf.constant(0)

  self.assertRaises(TypeError, forms.MapReduceForm, initialize_comp,
                    prepare_comp, work_comp, zero, accumulate_comp,
                    merge_comp, report_comp, bitwidth_comp, update_comp)
def test_init_raises_type_error_with_bad_report_result_type(self):
  """Swaps in a replacement `report` and expects `TypeError` on construction."""
  (initialize_comp, prepare_comp, work_comp, zero_comp, accumulate_comp,
   merge_comp, _, bitwidth_comp,
   update_comp) = _test_map_reduce_form_computations()

  @computations.tf_computation(tf.int32, tf.int32)
  def report(accumulator):
    del accumulator  # Unused
    return tf.constant(1)

  self.assertRaises(TypeError, forms.MapReduceForm, initialize_comp,
                    prepare_comp, work_comp, zero_comp, accumulate_comp,
                    merge_comp, report, bitwidth_comp, update_comp)
def test_init_raises_type_error_with_bad_prepare_parameter_type(self):
  """Swaps in a `prepare` taking `tf.float32` and expects `TypeError`."""
  (initialize_comp, _, work_comp, zero_comp, accumulate_comp, merge_comp,
   report_comp, bitwidth_comp,
   update_comp) = _test_map_reduce_form_computations()

  @computations.tf_computation(tf.float32)
  def prepare(server_state):
    del server_state  # Unused
    return tf.constant(1.0)

  self.assertRaises(TypeError, forms.MapReduceForm, initialize_comp, prepare,
                    work_comp, zero_comp, accumulate_comp, merge_comp,
                    report_comp, bitwidth_comp, update_comp)
def test_init_raises_type_error_with_bad_accumulate_result_type(self):
  """Swaps in an `accumulate` returning a float/int pair; expects `TypeError`."""
  (initialize_comp, prepare_comp, work_comp, zero_comp, _, merge_comp,
   report_comp, bitwidth_comp,
   update_comp) = _test_map_reduce_form_computations()

  @computations.tf_computation((tf.float32, tf.float32), tf.bool)
  def accumulate(accumulator, client_update):
    del accumulator, client_update  # Unused
    return tf.constant(1.0), tf.constant(1)

  self.assertRaises(TypeError, forms.MapReduceForm, initialize_comp,
                    prepare_comp, work_comp, zero_comp, accumulate,
                    merge_comp, report_comp, bitwidth_comp, update_comp)
def test_init_raises_type_error_with_bad_update_result_type(self):
  """Swaps in an `update` returning a float as state; expects `TypeError`."""
  (initialize_comp, prepare_comp, work_comp, zero_comp, accumulate_comp,
   merge_comp, report_comp, bitwidth_comp,
   _) = _test_map_reduce_form_computations()

  @computations.tf_computation(
      tf.int32, (tf.float32, computation_types.StructType([])))
  def update(server_state, global_update):
    del server_state, global_update  # Unused
    return tf.constant(1.0), []

  self.assertRaises(TypeError, forms.MapReduceForm, initialize_comp,
                    prepare_comp, work_comp, zero_comp, accumulate_comp,
                    merge_comp, report_comp, bitwidth_comp, update)
def test_init_raises_type_error_with_bad_merge_second_parameter_type(self):
  """Swaps in a `merge` with mismatched parameter types; expects `TypeError`."""
  (initialize_comp, prepare_comp, work_comp, zero_comp, accumulate_comp, _,
   report_comp, bitwidth_comp,
   update_comp) = _test_map_reduce_form_computations()

  @computations.tf_computation((tf.int32, tf.int32), (tf.float32, tf.int32))
  def merge(accumulator1, accumulator2):
    del accumulator1, accumulator2  # Unused
    return tf.constant(1), tf.constant(1)

  self.assertRaises(TypeError, forms.MapReduceForm, initialize_comp,
                    prepare_comp, work_comp, zero_comp, accumulate_comp,
                    merge, report_comp, bitwidth_comp, update_comp)
def test_init_raises_type_error_with_bad_work_result_type(self):
  """Swaps in a `work` returning a string constant; expects `TypeError`."""
  (initialize_comp, prepare_comp, _, zero_comp, accumulate_comp, merge_comp,
   report_comp, bitwidth_comp,
   update_comp) = _test_map_reduce_form_computations()

  @computations.tf_computation(
      computation_types.SequenceType(tf.float32), tf.float32)
  def work(client_data, client_input):
    del client_data, client_input  # Unused
    return tf.constant('abc'), []

  self.assertRaises(TypeError, forms.MapReduceForm, initialize_comp,
                    prepare_comp, work, zero_comp, accumulate_comp,
                    merge_comp, report_comp, bitwidth_comp, update_comp)
def test_init_does_not_raise_type_error_with_unknown_dimensions(self):
  """Builds a form whose tensor types use unknown (`None`) dimensions.

  The computations below return statically shaped tensors where the declared
  parameter types have a `None` dimension; construction must not raise
  `TypeError`.
  """
  int_vector_type = computation_types.TensorType(shape=[None], dtype=tf.int32)
  string_vector_type = computation_types.TensorType(
      shape=[None], dtype=tf.string)
  server_state_type = int_vector_type

  @tensorflow_computation.tf_computation
  def initialize():
    # The returned type is assignable to `server_state_type` without being
    # equal to it.
    return tf.constant([1, 2, 3])

  @tensorflow_computation.tf_computation(server_state_type)
  def prepare(server_state):
    del server_state  # Unused
    return tf.constant(1.0)

  @tensorflow_computation.tf_computation(
      computation_types.SequenceType(tf.float32), tf.float32)
  def work(client_data, client_input):
    del client_data, client_input  # Unused
    return True, [], [], []

  @tensorflow_computation.tf_computation
  def zero():
    return tf.constant([], dtype=tf.string)

  @tensorflow_computation.tf_computation(
      computation_types.TensorType(shape=[None], dtype=tf.string), tf.bool)
  def accumulate(accumulator, client_update):
    del accumulator, client_update  # Unused
    return tf.constant(['abc'])

  @tensorflow_computation.tf_computation(
      computation_types.TensorType(shape=[None], dtype=tf.string),
      computation_types.TensorType(shape=[None], dtype=tf.string))
  def merge(accumulator1, accumulator2):
    del accumulator1, accumulator2  # Unused
    return tf.constant(['abc'])

  @tensorflow_computation.tf_computation(string_vector_type)
  def report(accumulator):
    del accumulator  # Unused
    return tf.constant(1.0)

  # One empty-result computation serves as bitwidth, max_input and modulus.
  unit_comp = tensorflow_computation.tf_computation(lambda: [])
  unit_type = computation_types.to_type([])

  @tensorflow_computation.tf_computation(
      server_state_type, (tf.float32, unit_type, unit_type, unit_type))
  def update(server_state, global_update):
    del server_state, global_update  # Unused
    # The new server state is again assignable to, but not equal to,
    # `server_state_type`, and also differs from what `initialize` returns.
    return tf.constant([1]), []

  form_args = (initialize, prepare, work, zero, accumulate, merge, report,
               unit_comp, unit_comp, unit_comp, update)
  try:
    forms.MapReduceForm(*form_args)
  except TypeError:
    self.fail('Raised TypeError unexpectedly.')
def get_temperature_sensor_example():
  """Builds the temperature-sensor `forms.MapReduceForm` example.

  The example computes the fraction of sensors that report temperatures over
  the threshold.

  Returns:
    An instance of `forms.MapReduceForm`.
  """
  # Server state: just the integer round counter `num_rounds`.
  server_state_type = collections.OrderedDict(num_rounds=tf.int32)
  # Client state: the float threshold `max_temperature` sent by the server.
  client_state_type = collections.OrderedDict(max_temperature=tf.float32)
  # Client data: a sequence of floats.
  client_data_type = computation_types.SequenceType(tf.float32)
  # Client update: a single Boolean `is_over`.
  client_update_type = collections.OrderedDict(is_over=tf.bool)
  # Accumulator: how many updates were seen and how many were over threshold.
  accumulator_type = collections.OrderedDict(
      num_total=tf.int32, num_over=tf.int32)

  @tensorflow_computation.tf_computation
  def initialize():
    return collections.OrderedDict(num_rounds=tf.constant(0))

  @tensorflow_computation.tf_computation(server_state_type)
  def prepare(state):
    rounds_as_float = tf.cast(state['num_rounds'], tf.float32)
    return collections.OrderedDict(max_temperature=32.0 + rounds_as_float)

  @tensorflow_computation.tf_computation(client_data_type, client_state_type)
  def work(data, state):
    """See the `forms.MapReduceForm` definition of `work`."""

    def fn(s, x):
      return {
          'num': s['num'] + 1,
          'max': tf.maximum(s['max'], x),
      }

    initial_stats = {
        'num': np.int32(0),
        'max': np.float32(-459.67)
    }
    stats = data.reduce(initial_stats, fn)
    exceeded = stats['max'] > state['max_temperature']
    return collections.OrderedDict(is_over=exceeded), [], [], []

  @tensorflow_computation.tf_computation
  def zero():
    return collections.OrderedDict(
        num_total=tf.constant(0), num_over=tf.constant(0))

  @tensorflow_computation.tf_computation(accumulator_type, client_update_type)
  def accumulate(accumulator, update):
    over_increment = tf.cast(update['is_over'], tf.int32)
    return collections.OrderedDict(
        num_total=accumulator['num_total'] + 1,
        num_over=accumulator['num_over'] + over_increment)

  @tensorflow_computation.tf_computation(accumulator_type, accumulator_type)
  def merge(accumulator1, accumulator2):
    return collections.OrderedDict(
        num_total=accumulator1['num_total'] + accumulator2['num_total'],
        num_over=accumulator1['num_over'] + accumulator2['num_over'])

  @tensorflow_computation.tf_computation(merge.type_signature.result)
  def report(accumulator):
    over = tf.cast(accumulator['num_over'], tf.float32)
    total = tf.cast(accumulator['num_total'], tf.float32)
    return collections.OrderedDict(ratio_over_threshold=(over / total))

  # One empty-result computation serves as bitwidth, max_input and modulus.
  unit_comp = tensorflow_computation.tf_computation(lambda: [])

  update_type = (
      collections.OrderedDict(ratio_over_threshold=tf.float32),
      (),
      (),
      (),
  )

  @tensorflow_computation.tf_computation(server_state_type, update_type)
  def update(state, update):
    next_state = collections.OrderedDict(num_rounds=state['num_rounds'] + 1)
    return (next_state, update[0])

  return forms.MapReduceForm(initialize, prepare, work, zero, accumulate,
                             merge, report, unit_comp, unit_comp, unit_comp,
                             update)
def get_mnist_training_example():
  """Builds the MNIST-training `forms.MapReduceForm` example.

  Returns:
    An instance of `forms.MapReduceForm`.
  """
  # Python containers used by the computations below.
  model_nt = collections.namedtuple('Model', 'weights bias')
  server_state_nt = collections.namedtuple('ServerState', 'model num_rounds')
  client_state_nt = collections.namedtuple('ClientState',
                                           'model learning_rate')
  batch_nt = collections.namedtuple('Batch', 'x y')
  loop_state_nt = collections.namedtuple('LoopState',
                                         'num_examples total_loss')
  update_nt = collections.namedtuple('Update', 'model num_examples loss')
  metrics_nt = collections.namedtuple('Metrics',
                                      'num_rounds num_examples loss')

  # Type specifications built from those containers.
  model_tff_type = model_nt(
      weights=(tf.float32, [784, 10]), bias=(tf.float32, [10]))
  server_state_tff_type = server_state_nt(
      model=model_nt(weights=(tf.float32, [784, 10]),
                     bias=(tf.float32, [10])),
      num_rounds=tf.int32)
  client_state_tff_type = client_state_nt(
      model=model_tff_type, learning_rate=tf.float32)
  batch_tff_type = batch_nt(x=(tf.float32, [None, 784]), y=(tf.int32, [None]))
  dataset_tff_type = computation_types.SequenceType(batch_tff_type)
  update_tff_type = update_nt(
      model=model_tff_type, num_examples=tf.int32, loss=tf.float32)
  # Accumulators share both the container and the type of a client update.
  accumulator_tff_type = update_tff_type

  # The initial server state is an all-zeros model and a round counter of zero.
  @tensorflow_computation.tf_computation
  def initialize():
    zero_model = model_nt(weights=tf.zeros([784, 10]), bias=tf.zeros([10]))
    return server_state_nt(model=zero_model, num_rounds=tf.constant(0))

  # Clients receive the model plus a learning rate that starts at 0.1 and
  # decays exponentially by a factor of 0.9 per round.
  @tensorflow_computation.tf_computation(server_state_tff_type)
  def prepare(state):
    rounds_as_float = tf.cast(state.num_rounds, tf.float32)
    learning_rate = 0.1 * tf.pow(0.9, rounds_as_float)
    return client_state_nt(model=state.model, learning_rate=learning_rate)

  # Train the model locally, then emit the locally-trained model, the number
  # of examples and the average loss as the client update.
  @tensorflow_computation.tf_computation(dataset_tff_type,
                                         client_state_tff_type)
  def work(data, state):
    """Runs SGD over `data`, starting from the model in `state`."""
    model_vars = model_nt(
        weights=tf.Variable(
            initial_value=state.model.weights, name='weights'),
        bias=tf.Variable(initial_value=state.model.bias, name='bias'))
    init_model = tf.compat.v1.global_variables_initializer()

    optimizer = tf.keras.optimizers.SGD(state.learning_rate)

    @tf.function
    def reduce_fn(loop_state, batch):
      """Compute a single gradient step on an given batch of examples."""
      with tf.GradientTape() as tape:
        pred_y = tf.nn.softmax(
            tf.matmul(batch.x, model_vars.weights) + model_vars.bias)
        loss = -tf.reduce_mean(
            tf.reduce_sum(
                tf.one_hot(batch.y, 10) * tf.math.log(pred_y), axis=[1]))
      grads = tape.gradient(loss, model_vars)
      optimizer.apply_gradients(
          zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))
      return loop_state_nt(
          num_examples=loop_state.num_examples + 1,
          total_loss=loop_state.total_loss + loss)

    # The variables must be initialized before the training loop runs.
    with tf.control_dependencies([init_model]):
      loop_state = data.reduce(
          loop_state_nt(num_examples=0, total_loss=np.float32(0.0)),
          reduce_fn)
      num_examples = loop_state.num_examples
      total_loss = loop_state.total_loss
      with tf.control_dependencies([num_examples, total_loss]):
        loss = total_loss / tf.cast(num_examples, tf.float32)

    client_update = update_nt(
        model=model_vars, num_examples=num_examples, loss=loss)
    return client_update, [], [], []

  # Aggregation starts from a zero model, zero examples and zero loss.
  @tensorflow_computation.tf_computation
  def zero():
    zero_model = model_nt(weights=tf.zeros([784, 10]), bias=tf.zeros([10]))
    return update_nt(
        model=zero_model,
        num_examples=tf.constant(0),
        loss=tf.constant(0.0, dtype=tf.float32))

  # An update is added to the accumulator with its model multiplied by its
  # number of examples, so that a weighted average can be computed at the end.
  @tensorflow_computation.tf_computation(accumulator_tff_type,
                                         update_tff_type)
  def accumulate(accumulator, update):
    scaling_factor = tf.cast(update.num_examples, tf.float32)
    scaled_model = tf.nest.map_structure(lambda x: x * scaling_factor,
                                         update.model)
    summed_model = tf.nest.map_structure(tf.add, accumulator.model,
                                         scaled_model)
    return update_nt(
        model=summed_model,
        num_examples=accumulator.num_examples + update.num_examples,
        loss=accumulator.loss + update.loss * scaling_factor)

  # Merging accumulators does not involve scaling.
  @tensorflow_computation.tf_computation(accumulator_tff_type,
                                         accumulator_tff_type)
  def merge(accumulator1, accumulator2):
    summed_model = tf.nest.map_structure(tf.add, accumulator1.model,
                                         accumulator2.model)
    return update_nt(
        model=summed_model,
        num_examples=accumulator1.num_examples + accumulator2.num_examples,
        loss=accumulator1.loss + accumulator2.loss)

  # The aggregate is the accumulated model divided by the total number of
  # examples; the loss is divided the same way.
  @tensorflow_computation.tf_computation(accumulator_tff_type)
  def report(accumulator):
    scaling_factor = 1.0 / tf.cast(accumulator.num_examples, tf.float32)
    scaled_model = model_nt(
        weights=accumulator.model.weights * scaling_factor,
        bias=accumulator.model.bias * scaling_factor)
    return update_nt(
        model=scaled_model,
        num_examples=accumulator.num_examples,
        loss=accumulator.loss * scaling_factor)

  # One empty-result computation serves as all three secure-sum arguments.
  unit_computation = tensorflow_computation.tf_computation(lambda: [])

  update_type = (accumulator_tff_type, (), (), ())

  # Hand the newly averaged model and an incremented round counter to the
  # next round, and output the counters and loss as server metrics.
  @tensorflow_computation.tf_computation(server_state_tff_type, update_type)
  def update(state, update):
    averaged = update[0]
    num_rounds = state.num_rounds + 1
    next_state = server_state_nt(model=averaged.model, num_rounds=num_rounds)
    metrics = metrics_nt(
        num_rounds=num_rounds,
        num_examples=averaged.num_examples,
        loss=averaged.loss)
    return (next_state, metrics)

  return forms.MapReduceForm(initialize, prepare, work, zero, accumulate,
                             merge, report, unit_computation,
                             unit_computation, unit_computation, update)
def get_federated_sum_example(
    *, secure_sum: bool = False) -> forms.MapReduceForm:
  """Builds a `forms.MapReduceForm` that sums the clients' integer data.

  Args:
    secure_sum: Whether to use `federated_secure_sum_bitwidth`. Defaults to
      `federated_sum`.

  Returns:
    An instance of `forms.MapReduceForm`.
  """

  @tensorflow_computation.tf_computation
  def initialize():
    return ()

  server_state_type = initialize.type_signature.result

  @tensorflow_computation.tf_computation(server_state_type)
  def prepare(state):
    return state

  @tensorflow_computation.tf_computation(
      computation_types.SequenceType(tf.int32), prepare.type_signature.result)
  def work(data, _):
    client_sum = data.reduce(initial_state=0, reduce_func=tf.add)
    # The sum goes in the second result slot when secure, else in the first.
    if secure_sum:
      return [], client_sum, [], []
    return client_sum, [], [], []

  @tensorflow_computation.tf_computation
  def zero():
    return () if secure_sum else 0

  work_result_type = work.type_signature.result
  client_update_type = work_result_type[0]
  accumulator_type = zero.type_signature.result

  @tensorflow_computation.tf_computation(accumulator_type, client_update_type)
  def accumulate(accumulator, update):
    if secure_sum:
      return ()
    return accumulator + update

  @tensorflow_computation.tf_computation(accumulator_type, accumulator_type)
  def merge(accumulator1, accumulator2):
    if secure_sum:
      return ()
    return accumulator1 + accumulator2

  @tensorflow_computation.tf_computation(merge.type_signature.result)
  def report(accumulator):
    return accumulator

  bitwidth = tensorflow_computation.tf_computation(lambda: 32)
  max_input = tensorflow_computation.tf_computation(lambda: 0)
  modulus = tensorflow_computation.tf_computation(lambda: 0)

  update_type = (
      merge.type_signature.result,
      work_result_type[1],
      work_result_type[2],
      work_result_type[3],
  )

  @tensorflow_computation.tf_computation(server_state_type, update_type)
  def update(state, update):
    return state, (update[1] if secure_sum else update[0])

  return forms.MapReduceForm(initialize, prepare, work, zero, accumulate,
                             merge, report, bitwidth, max_input, modulus,
                             update)