def test_clip_type_properties_simple(self, value_type):
  """Checks the type signatures of the process built by `_clipped_sum()`.

  Args:
    value_type: Anything accepted by `tff.to_type`; the aggregated value type.
  """
  clip_factory = _clipped_sum()
  value_type = tff.to_type(value_type)
  process = clip_factory.create(value_type)
  self.assertIsInstance(process, tff.templates.AggregationProcess)

  # Inner SumFactory has no state, so the server state is an empty struct.
  state_at_server = tff.type_at_server(())

  init_signature = tff.FunctionType(parameter=None, result=state_at_server)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_signature))

  measurements_at_server = tff.type_at_server(
      collections.OrderedDict(agg_process=()))
  next_signature = tff.FunctionType(
      parameter=collections.OrderedDict(
          state=state_at_server, value=tff.type_at_clients(value_type)),
      result=tff.templates.MeasuredProcessOutput(
          state=state_at_server,
          result=tff.type_at_server(value_type),
          measurements=measurements_at_server))
  self.assertTrue(
      process.next.type_signature.is_equivalent_to(next_signature))
def test_iterative_process_type_signature(self):
  """Checks `initialize` and `next` signatures of adaptive FedAvg."""

  def make_lr_callback():
    # Client and server callbacks share identical settings.
    return callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.1,
        min_delta=0.5,
        window_size=2,
        decay_factor=1.0,
        cooldown=0)

  client_lr_callback = make_lr_callback()
  server_lr_callback = make_lr_callback()
  iterative_process = adaptive_fed_avg.build_fed_avg_process(
      _uncompiled_model_builder,
      client_lr_callback,
      server_lr_callback,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)

  callback_type = tff.framework.type_from_tensors(client_lr_callback)
  weights_type = tff.learning.ModelWeights(
      trainable=(tff.TensorType(tf.float32, [1, 1]),
                 tff.TensorType(tf.float32, [1])),
      non_trainable=())
  state_type = tff.FederatedType(
      adaptive_fed_avg.ServerState(
          model=weights_type,
          optimizer_state=[tf.int64],
          client_lr_callback=callback_type,
          server_lr_callback=callback_type), tff.SERVER)
  self.assertEqual(
      iterative_process.initialize.type_signature,
      tff.FunctionType(parameter=None, result=state_type))

  client_data_type = tff.FederatedType(
      tff.SequenceType(
          collections.OrderedDict(
              x=tff.TensorType(tf.float32, [None, 1]),
              y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)
  loss_metrics_type = tff.FederatedType(
      collections.OrderedDict(loss=tff.TensorType(tf.float32)), tff.SERVER)
  round_outputs_type = collections.OrderedDict(
      before_training=loss_metrics_type, during_training=loss_metrics_type)
  expected_type = tff.FunctionType(
      parameter=collections.OrderedDict(
          server_state=state_type, federated_dataset=client_data_type),
      result=(state_type, round_outputs_type))

  actual_type = iterative_process.next.type_signature
  self.assertEqual(
      actual_type,
      expected_type,
      msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
def test_iterative_process_type_signature(self):
  """Checks the signatures of the process built by `from_flags`.

  The expected learning-rate callback type is derived from a callback built
  from the same FLAGS values. That type is used for both the client and the
  server callback in the expected server state.
  """
  iterative_process = decay_iterative_process_builder.from_flags(
      input_spec=get_input_spec(),
      model_builder=model_builder,
      loss_builder=loss_builder,
      metrics_builder=metrics_builder)

  # Built only so its TFF type can be derived below.
  dummy_lr_callback = callbacks.create_reduce_lr_on_plateau(
      learning_rate=FLAGS.client_learning_rate,
      decay_factor=FLAGS.client_decay_factor,
      min_delta=FLAGS.min_delta,
      min_lr=FLAGS.min_lr,
      window_size=FLAGS.window_size,
      patience=FLAGS.patience)
  lr_callback_type = tff.framework.type_from_tensors(dummy_lr_callback)

  server_state_type = tff.FederatedType(
      adaptive_fed_avg.ServerState(
          model=tff.learning.ModelWeights(
              trainable=(tff.TensorType(tf.float32, [1, 1]),
                         tff.TensorType(tf.float32, [1])),
              non_trainable=()),
          optimizer_state=[tf.int64],
          client_lr_callback=lr_callback_type,
          server_lr_callback=lr_callback_type), tff.SERVER)
  self.assertEqual(
      iterative_process.initialize.type_signature,
      tff.FunctionType(parameter=None, result=server_state_type))

  dataset_type = tff.FederatedType(
      tff.SequenceType(
          collections.OrderedDict(
              x=tff.TensorType(tf.float32, [None, 1]),
              y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)
  metrics_type = tff.FederatedType(
      collections.OrderedDict(
          mean_squared_error=tff.TensorType(tf.float32),
          loss=tff.TensorType(tf.float32)), tff.SERVER)
  output_type = collections.OrderedDict(
      before_training=metrics_type, during_training=metrics_type)
  expected_result_type = (server_state_type, output_type)
  # The expected `next` parameter is an unnamed (server state, dataset) pair.
  expected_type = tff.FunctionType(
      parameter=(server_state_type, dataset_type),
      result=expected_result_type)

  actual_type = iterative_process.next.type_signature
  # Report both signatures on failure so a mismatch is diagnosable.
  self.assertTrue(
      actual_type.is_equivalent_to(expected_type),
      msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
def test_eval_fn_has_correct_type_signature(self):
  """Checks the signature of the centralized evaluation computation.

  The expected signature takes server-placed model weights and a
  server-placed dataset, and returns server-placed metrics.
  """
  eval_fn = evaluation.build_centralized_evaluation(
      tff_model_fn, lambda: [tf.keras.metrics.MeanSquaredError()])
  actual_signature = eval_fn.type_signature

  weights_at_server = tff.FederatedType(
      tff.learning.ModelWeights(
          trainable=(tff.TensorType(tf.float32, [1, 1]),
                     tff.TensorType(tf.float32, [1])),
          non_trainable=()), tff.SERVER)
  element_type = collections.OrderedDict(
      x=tff.TensorType(tf.float32, [None, 1]),
      y=tff.TensorType(tf.float32, [None, 1]))
  dataset_at_server = tff.FederatedType(
      tff.SequenceType(element_type), tff.SERVER)
  metrics_at_server = tff.FederatedType(
      collections.OrderedDict(
          mean_squared_error=tff.TensorType(tf.float32),
          num_examples=tff.TensorType(tf.float32)), tff.SERVER)
  expected_signature = tff.FunctionType(
      parameter=collections.OrderedDict(
          model_weights=weights_at_server,
          centralized_dataset=dataset_at_server),
      result=metrics_at_server)

  # check_assignable_from raises (failing the test) when the expected
  # signature is not assignable to the actual one.
  actual_signature.check_assignable_from(expected_signature)
def test_federated_mean_masked(self):
  """Checks the signature and zero-masking of the masked federated mean."""
  value_type = tff.StructType([
      ('a', tff.TensorType(tf.float32, shape=[3])),
      ('b', tff.TensorType(tf.float32, shape=[2, 1])),
  ])
  weight_type = tff.TensorType(tf.float32)
  mean_fn = fed_pa_schedule.build_federated_mean_masked(
      value_type, weight_type)

  # Signature: client values and weights in, a server value out.
  signature = mean_fn.type_signature
  expected_signature = tff.FunctionType(
      parameter=collections.OrderedDict(
          value=tff.FederatedType(value_type, tff.CLIENTS),
          weight=tff.FederatedType(weight_type, tff.CLIENTS)),
      result=tff.FederatedType(value_type, tff.SERVER))
  self.assertTrue(
      signature.is_equivalent_to(expected_signature),
      msg='{s}\n!={t}'.format(s=signature, t=expected_signature))

  # Zero entries are left out of the weighted mean. For example, a[1] is
  # (1*1 + 3*3) / (1 + 3) = 2.5, while a[0] only counts the second client's 1.0.
  first_client = collections.OrderedDict(
      a=tf.constant([0.0, 1.0, 1.0]), b=tf.constant([[1.0], [2.0]]))
  second_client = collections.OrderedDict(
      a=tf.constant([1.0, 3.0, 0.0]), b=tf.constant([[3.0], [0.0]]))
  client_weights = [tf.constant(1.0), tf.constant(3.0)]
  actual_mean = mean_fn([first_client, second_client], client_weights)

  expected_mean = collections.OrderedDict(
      a=tf.constant([1.0, 2.5, 1.0]), b=tf.constant([[2.5], [2.0]]))
  self.assertAllClose(actual_mean, expected_mean)
def test_types(self):
  """Compares the trainer's computation signatures by their string form."""
  trainer = create_trainer(batch_size=100, step_size=0.01)
  model_type = trainer.create_initial_model.type_signature.result

  # The batch type is taken from a JAX computation returning a real batch.
  example_batch = next(trainer.generate_random_batches(1))
  batch_fn = tff.experimental.jax_computation(lambda: example_batch)
  batch_type = batch_fn.type_signature.result

  def expected_signature_str(result_type):
    params = collections.OrderedDict([('model', model_type),
                                      ('batch', batch_type)])
    return str(tff.FunctionType(params, result_type))

  self.assertEqual(
      str(trainer.train_on_one_batch.type_signature),
      expected_signature_str(model_type))
  self.assertEqual(
      str(trainer.compute_loss_on_one_batch.type_signature),
      expected_signature_str(np.float32))
def test_build_with_preprocess_function(self):
  """Composes a preprocessing comp with FedPA and checks `next`'s signature."""
  source_dataset = tf.data.Dataset.range(5)
  client_datasets_type = tff.FederatedType(
      tff.SequenceType(source_dataset.element_spec), tff.CLIENTS)

  @tff.tf_computation(tff.SequenceType(source_dataset.element_spec))
  def preprocess_dataset(ds):

    def to_batch(x):
      # A length-784 vector filled with 2*x, plus x+1 cast to int64 with shape [1].
      filled = tf.fill(dims=(784,), value=float(x) * 2.0)
      shifted = tf.expand_dims(tf.cast(x + 1, dtype=tf.int64), axis=0)
      return _Batch(filled, shifted)

    return ds.map(to_batch).batch(2)

  base_process = _build_simple_fed_pa_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD)
  iterproc = tff.simulation.compose_dataset_computation_with_iterative_process(
      preprocess_dataset, base_process)

  with tf.Graph().as_default():
    model_for_types = _uncompiled_model_builder()

  model_weights = tff.learning.ModelWeights(
      model_for_types.trainable_variables,
      model_for_types.non_trainable_variables)
  server_state_type = tff.FederatedType(
      fed_pa_schedule.ServerState(
          model=tff.framework.type_from_tensors(model_weights),
          optimizer_state=(tf.int64,),
          round_num=tf.float32), tff.SERVER)
  metrics_type = tff.FederatedType(
      tff.StructType([
          ('loss', tf.float32),
          ('model_delta_zeros_percent', tf.float32),
          ('model_delta_correction_l2_norm', tf.float32),
      ]), tff.SERVER)
  expected_type = tff.FunctionType(
      parameter=collections.OrderedDict(
          server_state=server_state_type,
          federated_dataset=client_datasets_type,
      ),
      result=(server_state_type, metrics_type))

  actual_type = iterproc.next.type_signature
  self.assertTrue(
      actual_type.is_equivalent_to(expected_type),
      msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
def test_build_with_preprocess_function(self):
  """Composes a preprocessing comp with FedAvg and checks `next`'s signature."""
  source_dataset = tf.data.Dataset.range(5)
  client_datasets_type = tff.type_at_clients(
      tff.SequenceType(source_dataset.element_spec))

  @tff.tf_computation(tff.SequenceType(source_dataset.element_spec))
  def preprocess_dataset(ds):

    def create_batch(x):
      as_float = tf.cast(x, dtype=tf.float32)
      return collections.OrderedDict(x=[as_float], y=[2.0])

    return ds.map(create_batch).batch(2)

  base_process = fed_avg_schedule.build_fed_avg_process(
      model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      client_lr=0.01,
      server_optimizer_fn=tf.keras.optimizers.SGD)
  iterproc = tff.simulation.compose_dataset_computation_with_iterative_process(
      preprocess_dataset, base_process)

  with tf.Graph().as_default():
    model_for_types = model_builder()

  weights_type = tff.framework.type_from_tensors(
      tff.learning.ModelWeights(model_for_types.trainable_variables,
                                model_for_types.non_trainable_variables))
  server_state_type = tff.FederatedType(
      fed_avg_schedule.ServerState(
          model=weights_type, optimizer_state=(tf.int64,),
          round_num=tf.float32), tff.SERVER)
  # The model's own output computation determines the metrics type.
  output_metrics_type = (
      model_for_types.federated_output_computation.type_signature.result)
  expected_type = tff.FunctionType(
      parameter=collections.OrderedDict(
          server_state=server_state_type,
          federated_dataset=client_datasets_type,
      ),
      result=(server_state_type, output_metrics_type))

  actual_type = iterproc.next.type_signature
  self.assertTrue(
      actual_type.is_equivalent_to(expected_type),
      msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
def test_clip_by_global_norm(self):
  """Checks signature, clipped mean and metrics of fixed-norm clipping."""
  clip_norm = 20.0
  test_deltas = [create_weights_delta(), create_weights_delta(constant=10)]
  update_type = tff.framework.type_from_tensors(test_deltas[0])
  aggregate_fn = aggregate_fns.build_fixed_clip_norm_mean_process(
      clip_norm=clip_norm, model_update_type=update_type)

  # The process carries an empty server state.
  empty_state_type = tff.FederatedType((), tff.SERVER)
  metrics_at_server = tff.FederatedType(
      aggregate_fns.NormClippedAggregationMetrics(
          max_global_norm=tf.float32, num_clipped=tf.int32), tff.SERVER)
  expected_next_type = tff.FunctionType(
      parameter=collections.OrderedDict(
          state=empty_state_type,
          deltas=tff.FederatedType(update_type, tff.CLIENTS),
          weights=tff.FederatedType(tf.float32, tff.CLIENTS),
      ),
      result=collections.OrderedDict(
          state=empty_state_type,
          result=tff.FederatedType(update_type, tff.SERVER),
          measurements=metrics_at_server),
  )
  self.assertEqual(aggregate_fn.next.type_signature, expected_next_type)

  initial_state = aggregate_fn.initialize()
  output = aggregate_fn.next(initial_state, test_deltas, [1., 1.])

  def clip_delta(delta):
    clipped_leaves, _ = tf.clip_by_global_norm(
        tf.nest.flatten(delta), clip_norm)
    return tf.nest.pack_sequence_as(delta, clipped_leaves)

  clipped_deltas = [clip_delta(delta) for delta in test_deltas]
  expected_mean = tf.nest.map_structure(lambda a, b: (a + b) / 2,
                                        *clipped_deltas)
  self.assertAllClose(expected_mean, output['result'])

  # Global l2 norms are [17.74824, 53.99074], so only the second exceeds
  # clip_norm.
  measurements = output['measurements']
  self.assertAlmostEqual(measurements.max_global_norm, 53.99074, places=5)
  self.assertEqual(measurements.num_clipped, 1)
def test_build_with_preprocess_function(self):
  """Checks `next`'s signature when FedAvg gets `dataset_preprocess_comp`."""
  source_dataset = tf.data.Dataset.range(5)
  client_datasets_type = tff.FederatedType(
      tff.SequenceType(source_dataset.element_spec), tff.CLIENTS)

  @tff.tf_computation(tff.SequenceType(source_dataset.element_spec))
  def preprocess_dataset(ds):

    def to_batch(x):
      # A length-784 vector filled with 2*x, plus x+1 cast to int64 with shape [1].
      filled = tf.fill(dims=(784,), value=float(x) * 2.0)
      shifted = tf.expand_dims(tf.cast(x + 1, dtype=tf.int64), axis=0)
      return _Batch(filled, shifted)

    return ds.map(to_batch).batch(2)

  iterproc_adapter = fed_avg_schedule.build_fed_avg_process(
      _uncompiled_model_builder,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      dataset_preprocess_comp=preprocess_dataset)

  with tf.Graph().as_default():
    model_for_types = _uncompiled_model_builder()

  # Inspect the wrapped process directly (private attribute of the adapter).
  iterproc = iterproc_adapter._iterative_process
  weights_type = tff.framework.type_from_tensors(
      fed_avg_schedule.ModelWeights(model_for_types.trainable_variables,
                                    model_for_types.non_trainable_variables))
  server_state_type = tff.FederatedType(
      fed_avg_schedule.ServerState(
          model=weights_type, optimizer_state=(tf.int64,),
          round_num=tf.float32), tff.SERVER)
  output_metrics_type = (
      model_for_types.federated_output_computation.type_signature.result)
  expected_type = tff.FunctionType(
      parameter=(server_state_type, client_datasets_type),
      result=(server_state_type, output_metrics_type))

  actual_type = iterproc.next.type_signature
  self.assertEqual(
      actual_type,
      expected_type,
      msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
def test_build_with_preprocess_funtion(self):
  """Checks adaptive FedAvg signatures when a preprocess comp is supplied."""
  # NOTE(review): "funtion" in the method name is a typo for "function"; the
  # name is kept unchanged so existing test selectors keep working.
  source_dataset = tf.data.Dataset.range(5)
  client_datasets_type = tff.FederatedType(
      tff.SequenceType(source_dataset.element_spec), tff.CLIENTS)

  @tff.tf_computation(tff.SequenceType(source_dataset.element_spec))
  def preprocess_dataset(ds):

    def to_batch(x):
      x_value = [float(x) * 1.0]
      y_value = [float(x) * 3.0 + 1.0]
      return collections.OrderedDict(x=x_value, y=y_value)

    return ds.map(to_batch).repeat().batch(2).take(3)

  def make_lr_callback():
    # Client and server callbacks share identical settings.
    return callbacks.create_reduce_lr_on_plateau(
        learning_rate=0.1,
        min_delta=0.5,
        window_size=2,
        decay_factor=1.0,
        cooldown=0)

  client_lr_callback = make_lr_callback()
  server_lr_callback = make_lr_callback()
  iterative_process = adaptive_fed_avg.build_fed_avg_process(
      _uncompiled_model_builder,
      client_lr_callback,
      server_lr_callback,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      dataset_preprocess_comp=preprocess_dataset)

  callback_type = tff.framework.type_from_tensors(client_lr_callback)
  weights_type = tff.learning.ModelWeights(
      trainable=(tff.TensorType(tf.float32, [1, 1]),
                 tff.TensorType(tf.float32, [1])),
      non_trainable=())
  state_type = tff.FederatedType(
      adaptive_fed_avg.ServerState(
          model=weights_type,
          optimizer_state=[tf.int64],
          client_lr_callback=callback_type,
          server_lr_callback=callback_type), tff.SERVER)
  self.assertEqual(
      iterative_process.initialize.type_signature,
      tff.FunctionType(parameter=None, result=state_type))

  loss_metrics_type = tff.FederatedType(
      collections.OrderedDict(loss=tff.TensorType(tf.float32)), tff.SERVER)
  round_outputs_type = collections.OrderedDict(
      before_training=loss_metrics_type, during_training=loss_metrics_type)
  expected_type = tff.FunctionType(
      parameter=collections.OrderedDict(
          server_state=state_type, federated_dataset=client_datasets_type),
      result=(state_type, round_outputs_type))

  actual_type = iterative_process.next.type_signature
  self.assertEqual(
      actual_type,
      expected_type,
      msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))