def test_zero_type_properties_with_zeroed_count_agg_factory(
    self, value_type):
  """Checks signatures of zeroing with a custom zeroed-count sum factory.

  The zeroed-count aggregation is `SumPlusOneFactory`, so the expected
  server state carries an int32 under `zeroed_count_agg`, while the zeroing
  norm and the inner `SumFactory` contribute empty state.
  """
  zeroing_agg_factory = robust.zeroing_factory(
      zeroing_norm=1.0,
      inner_agg_factory=sum_factory.SumFactory(),
      norm_order=2.0,
      zeroed_count_sum_factory=aggregator_test_utils.SumPlusOneFactory())
  tff_value_type = computation_types.to_type(value_type)
  process = zeroing_agg_factory.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  state_type = computation_types.at_server(
      collections.OrderedDict(
          zeroing_norm=(), inner_agg=(), zeroed_count_agg=tf.int32))
  init_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_type))

  measurements_type = computation_types.at_server(
      collections.OrderedDict(
          zeroing=(),
          zeroing_norm=robust.NORM_TF_TYPE,
          zeroed_count=robust.COUNT_TF_TYPE))
  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type,
          value=computation_types.at_clients(tff_value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(tff_value_type),
          measurements=measurements_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_clip(self, clip_mechanism):
  """Checks HistogramClippingSumFactory builds a process with empty state."""
  clipper = clipping_factory.HistogramClippingSumFactory(clip_mechanism, 1)
  self.assertIsInstance(clipper, factory.UnweightedAggregationFactory)

  tensor_type = computation_types.to_type((tf.int32, (2,)))
  process = clipper.create(tensor_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  clients_value_type = computation_types.FederatedType(
      tensor_type, placements.CLIENTS)
  server_value_type = computation_types.FederatedType(
      tensor_type, placements.SERVER)
  # Both state and measurements are expected to be empty at the server.
  state_type = computation_types.at_server(())
  measurements_type = computation_types.at_server(())

  init_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_type))

  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type, value=clients_value_type),
      result=measured_process.MeasuredProcessOutput(
          state_type, server_value_type, measurements_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_type_properties(self, name, value_type):
  """Checks signatures of the Hadamard (`'hd'`) or DFT rotation sum.

  The measurements are expected to be keyed by `name` with an empty value,
  and the state pairs an empty inner state with `rotation.SEED_TFF_TYPE`.
  """
  if name == 'hd':
    rotation_agg_factory = _hadamard_sum()
  else:
    rotation_agg_factory = _dft_sum()
  tff_value_type = computation_types.to_type(value_type)
  process = rotation_agg_factory.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  state_type = computation_types.at_server(((), rotation.SEED_TFF_TYPE))
  init_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  type_test_utils.assert_types_equivalent(process.initialize.type_signature,
                                          init_type)

  measurements_type = computation_types.at_server(
      collections.OrderedDict([(name, ())]))
  clients_value_type = computation_types.at_clients(tff_value_type)
  server_value_type = computation_types.at_server(tff_value_type)
  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type, value=clients_value_type),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=server_value_type,
          measurements=measurements_type))
  type_test_utils.assert_types_equivalent(process.next.type_signature,
                                          next_type)
def test_type_properties(self, value_type):
  """Checks signatures of the `_discretization_sum()` aggregation process."""
  discretize_factory = _discretization_sum()
  tff_value_type = computation_types.to_type(value_type)
  process = discretize_factory.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # State holds two float32 scalars plus an empty inner-process state.
  state_type = computation_types.at_server(
      collections.OrderedDict(
          scale_factor=tf.float32,
          prior_norm_bound=tf.float32,
          inner_agg_process=()))
  init_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  type_test_utils.assert_types_equivalent(
      process.initialize.type_signature, init_type)

  measurements_type = computation_types.at_server(
      collections.OrderedDict(discretize=()))
  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type,
          value=computation_types.at_clients(tff_value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(tff_value_type),
          measurements=measurements_type))
  type_test_utils.assert_types_equivalent(process.next.type_signature,
                                          next_type)
def test_type_properties_simple(self, value_type, estimate_stddev):
  """Checks signatures of `_test_factory`, optionally estimating stddev.

  When `estimate_stddev` is set, a float32 `estimated_stddev` entry is
  expected to follow the `modclip` measurement.
  """
  test_agg_factory = _test_factory(estimate_stddev=estimate_stddev)
  process = test_agg_factory.create(computation_types.to_type(value_type))
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # Inner SumFactory has no state.
  state_type = computation_types.at_server(())
  init_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_type))

  measurement_fields = [('modclip', ())]
  if estimate_stddev:
    measurement_fields.append(('estimated_stddev', tf.float32))
  measurements = collections.OrderedDict(measurement_fields)

  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type,
          value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(value_type),
          measurements=computation_types.at_server(measurements)))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_type_properties(self):
  """Checks signatures of the apply-optimizer finalizer built with SGDM.

  The update parameter is expected to match only the trainable part of the
  model weights; state and measurements are empty at the server.
  """
  weights_type = computation_types.to_type(
      model_utils.ModelWeights(
          trainable=(tf.float32, tf.float32), non_trainable=tf.float32))
  finalizer = finalizers.build_apply_optimizer_finalizer(
      sgdm.build_sgdm(1.0), weights_type)
  self.assertIsInstance(finalizer, finalizers.FinalizerProcess)

  server_weights_type = computation_types.at_server(weights_type)
  server_update_type = computation_types.at_server(weights_type.trainable)
  server_result_type = computation_types.at_server(weights_type)
  state_type = computation_types.at_server(())
  measurements_type = computation_types.at_server(())

  init_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  init_type.check_equivalent_to(finalizer.initialize.type_signature)

  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type,
          weights=server_weights_type,
          update=server_update_type),
      result=MeasuredProcessOutput(state_type, server_result_type,
                                   measurements_type))
  next_type.check_equivalent_to(finalizer.next.type_signature)
def test_composition_type_properties(self, last_process):
  """Checks signatures of a `double` process chained with `last_process`.

  Args:
    last_process: A constructor called as `(state_type, init_value,
      values_type)` that returns the measured process placed last in the
      chain.
  """
  state_type = tf.float32
  values_type = tf.int32
  final_process = last_process(state_type, 0.0, values_type)
  double_process = _create_test_measured_process_double(
      state_type, 1.0, values_type)
  composite_process = measured_process.chain_measured_processes(
      collections.OrderedDict(
          double=double_process, last_process=final_process))
  self.assertIsInstance(composite_process, measured_process.MeasuredProcess)

  state_struct_type = computation_types.at_server(
      collections.OrderedDict(double=state_type, last_process=state_type))
  init_type = computation_types.FunctionType(
      parameter=None, result=state_struct_type)
  self.assertTrue(
      composite_process.initialize.type_signature.is_equivalent_to(
          init_type))

  clients_values_type = computation_types.at_clients(values_type)
  server_values_type = computation_types.at_server(values_type)
  # The last process's measurements are read from its own signature.
  final_measurements = (
      final_process.next.type_signature.result.measurements.member)
  measurements_type = computation_types.at_server(
      collections.OrderedDict(
          double=collections.OrderedDict(a=tf.int32),
          last_process=final_measurements))
  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_struct_type, values=clients_values_type),
      result=measured_process.MeasuredProcessOutput(
          state_struct_type, server_values_type, measurements_type))
  self.assertTrue(
      composite_process.next.type_signature.is_equivalent_to(next_type))
def test_type_properties_simple(self):
  """Checks signatures of ModularClippingSumFactory on an int32[2] value."""
  tensor_type = computation_types.to_type((tf.int32, (2,)))
  modclip_factory = modular_clipping_factory.ModularClippingSumFactory(
      clip_range_lower=-2, clip_range_upper=2)
  process = modclip_factory.create(tensor_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # Inner SumFactory has no state.
  state_type = computation_types.at_server(())
  init_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_type))

  measurements = computation_types.at_server(
      collections.OrderedDict(modclip=()))
  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type,
          value=computation_types.at_clients(tensor_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(tensor_type),
          measurements=measurements))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_fails_conflicting_binding_in_parameter_and_result(self):
  """`(int32 -> float32)` is not a concrete instance of `(T -> T)`.

  `T` would have to bind to int32 via the parameter and to float32 via the
  result, which is expected to raise `UnassignableConcreteTypesError`.
  """
  generic_fn_type = computation_types.FunctionType(
      parameter=computation_types.AbstractType('T'),
      result=computation_types.AbstractType('T'))
  concrete_fn_type = computation_types.FunctionType(
      parameter=tf.int32, result=tf.float32)
  with self.assertRaises(type_analysis.UnassignableConcreteTypesError):
    type_analysis.check_concrete_instance_of(concrete_fn_type,
                                             generic_fn_type)
def test_federated_secure_select(self):
  """Checks when `federated_secure_select` gets reduced.

  `replace_intrinsics_with_bodies` must leave the secure intrinsic in place
  and report no modification. `replace_secure_intrinsics_with_insecure_bodies`
  must report a modification and yield a computation that contains
  `federated_select`. Both passes must preserve the type signature.
  """
  uri = intrinsic_defs.FEDERATED_SECURE_SELECT.uri
  comp = building_blocks.Intrinsic(
      uri,
      computation_types.FunctionType(
          [
              computation_types.at_clients(tf.int32),  # client_keys
              computation_types.at_server(tf.int32),  # max_key
              computation_types.at_server(tf.float32),  # server_state
              computation_types.FunctionType([tf.float32, tf.int32],
                                             tf.float32)  # select_fn
          ],
          computation_types.at_clients(
              computation_types.SequenceType(tf.float32))))
  self.assertGreater(_count_intrinsics(comp, uri), 0)

  # First without secure intrinsics shouldn't modify anything.
  reduced, modified = intrinsic_reductions.replace_intrinsics_with_bodies(
      comp)
  self.assertFalse(modified)
  # Inspect the *output* of the pass. Re-counting on `comp` (the untouched
  # input) would always pass and verify nothing about the reduction.
  self.assertGreater(_count_intrinsics(reduced, uri), 0)
  self.assert_types_identical(comp.type_signature, reduced.type_signature)

  # Now replace bodies including secure intrinsics.
  reduced, modified = (
      intrinsic_reductions.replace_secure_intrinsics_with_insecure_bodies(
          comp))
  self.assertTrue(modified)
  self.assert_types_identical(comp.type_signature, reduced.type_signature)
  self.assertGreater(
      _count_intrinsics(reduced, intrinsic_defs.FEDERATED_SELECT.uri), 0)
def test_type_properties_constant_bounds(self, value_type, upper_bound,
                                         lower_bound, measurements_dtype):
  """Checks SecureSumFactory signatures when given constant bounds."""
  secure_sum_f = secure.SecureSumFactory(
      upper_bound_threshold=upper_bound, lower_bound_threshold=lower_bound)
  self.assertIsInstance(secure_sum_f, factory.UnweightedAggregationFactory)

  tff_value_type = computation_types.to_type(value_type)
  process = secure_sum_f.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  state_type = computation_types.at_server(computation_types.to_type(()))
  measurements_type = _measurements_type(measurements_dtype)

  init_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_type))

  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type,
          value=computation_types.at_clients(tff_value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(tff_value_type),
          measurements=measurements_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_intrinsic_class_succeeds_simple_federated_map(self):
  """Builds a concrete `federated_map` Intrinsic and checks its proto form."""
  mapped_fn_type = computation_types.FunctionType(tf.int32, tf.float32)
  clients_arg_type = computation_types.FederatedType(
      mapped_fn_type.parameter, placements.CLIENTS)
  clients_result_type = computation_types.FederatedType(
      mapped_fn_type.result, placements.CLIENTS)
  map_type = computation_types.FunctionType(
      computation_types.StructType((mapped_fn_type, clients_arg_type)),
      clients_result_type)

  map_intrinsic = building_blocks.Intrinsic(intrinsic_defs.FEDERATED_MAP.uri,
                                            map_type)
  self.assertIsInstance(map_intrinsic, building_blocks.Intrinsic)
  self.assertEqual(
      str(map_intrinsic.type_signature),
      '(<(int32 -> float32),{int32}@CLIENTS> -> {float32}@CLIENTS)')
  self.assertEqual(map_intrinsic.uri, 'federated_map')
  self.assertEqual(map_intrinsic.compact_representation(), 'federated_map')

  map_proto = map_intrinsic.proto
  self.assertEqual(
      type_serialization.deserialize_type(map_proto.type),
      map_intrinsic.type_signature)
  self.assertEqual(map_proto.WhichOneof('computation'), 'intrinsic')
  self.assertEqual(map_proto.intrinsic.uri, map_intrinsic.uri)
  self._serialize_deserialize_roundtrip_test(map_intrinsic)
def test_type_properties_unweighted(self, value_type):
  """Checks UnweightedMeanFactory signatures for `value_type`."""
  tff_value_type = computation_types.to_type(value_type)
  mean_factory = mean.UnweightedMeanFactory()
  self.assertIsInstance(mean_factory, factory.UnweightedAggregationFactory)
  process = mean_factory.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  clients_value_type = computation_types.at_clients(tff_value_type)
  server_value_type = computation_types.at_server(tff_value_type)
  # Two empty sub-states, one per inner sum.
  state_type = computation_types.at_server(((), ()))
  measurements_type = computation_types.at_server(
      collections.OrderedDict(mean_value=(), mean_count=()))

  init_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_type))

  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type, value=clients_value_type),
      result=measured_process.MeasuredProcessOutput(
          state_type, server_value_type, measurements_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_repr(self):
  """`repr` of FunctionType shows parameter (or None) and result."""
  cases = (
      (computation_types.FunctionType(tf.int32, tf.bool),
       'FunctionType(TensorType(tf.int32), TensorType(tf.bool))'),
      (computation_types.FunctionType(None, tf.bool),
       'FunctionType(None, TensorType(tf.bool))'),
  )
  for fn_type, expected_repr in cases:
    self.assertEqual(repr(fn_type), expected_repr)
def test_roundtrip(self):
  """Checks BroadcastForm extraction and reconstruction with a broadcast."""
  add_fn = tensorflow_computation.tf_computation(lambda x, y: x + y)
  server_data_type = computation_types.at_server(tf.int32)
  client_data_type = computation_types.at_clients(tf.int32)

  # Parameter names matter: they become the form's data labels below.
  @federated_computation.federated_computation(server_data_type,
                                               client_data_type)
  def add_server_number_plus_one(server_number, client_numbers):
    one = intrinsics.federated_value(1, placements.SERVER)
    server_context = intrinsics.federated_map(add_fn, (one, server_number))
    client_context = intrinsics.federated_broadcast(server_context)
    return intrinsics.federated_map(add_fn, (client_context, client_numbers))

  broadcast_form = form_utils.get_broadcast_form_for_computation(
      add_server_number_plus_one)
  self.assertEqual(broadcast_form.server_data_label, 'server_number')
  self.assertEqual(broadcast_form.client_data_label, 'client_numbers')

  type_test_utils.assert_types_equivalent(
      broadcast_form.compute_server_context.type_signature,
      computation_types.FunctionType(tf.int32, (tf.int32,)))
  self.assertEqual(2, broadcast_form.compute_server_context(1)[0])

  type_test_utils.assert_types_equivalent(
      broadcast_form.client_processing.type_signature,
      computation_types.FunctionType(((tf.int32,), tf.int32), tf.int32))
  self.assertEqual(3, broadcast_form.client_processing((1,), 2))

  rebuilt_comp = form_utils.get_computation_for_broadcast_form(
      broadcast_form)
  type_test_utils.assert_types_equivalent(
      rebuilt_comp.type_signature,
      add_server_number_plus_one.type_signature)
  # 2 (server data) + 1 (constant in comp) + 2 (client data) = 5 (output)
  self.assertEqual([5, 6, 7], rebuilt_comp(2, [2, 3, 4]))
def test_clip_type_properties_weighted(self, value_type, weight_type):
  """Checks signatures of `_concat_mean()` for weighted inputs.

  NOTE(review): the test name mentions "clip", but it exercises
  `_concat_mean()`. Confirm the name is intended.
  """
  concat_factory = _concat_mean()
  tff_value_type = computation_types.to_type(value_type)
  tff_weight_type = computation_types.to_type(weight_type)
  process = concat_factory.create(tff_value_type, tff_weight_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # State comes from the inner MeanFactory.
  state_type = computation_types.at_server(
      collections.OrderedDict(value_sum_process=(), weight_sum_process=()))
  init_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  type_test_utils.assert_types_equivalent(
      process.initialize.type_signature, init_type)

  # Measurements come from the inner mean factory.
  measurements_type = computation_types.at_server(
      collections.OrderedDict(mean_value=(), mean_weight=()))
  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type,
          value=computation_types.at_clients(tff_value_type),
          weight=computation_types.at_clients(tff_weight_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(tff_value_type),
          measurements=measurements_type))
  type_test_utils.assert_types_equivalent(process.next.type_signature,
                                          next_type)
def test_type_properties(self, modulus, value_type, symmetric_range):
  """Checks SecureModularSumFactory signatures and secure-only aggregation.

  Both state and measurements are expected to be empty at the server. The
  resulting `next` must also pass
  `static_assert.assert_not_contains_unsecure_aggregation`.
  """
  factory_ = secure.SecureModularSumFactory(
      modulus=modulus, symmetric_range=symmetric_range)
  self.assertIsInstance(factory_, factory.UnweightedAggregationFactory)
  value_type = computation_types.to_type(value_type)
  process = factory_.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  expected_state_type = computation_types.at_server(
      computation_types.to_type(()))
  expected_measurements_type = expected_state_type

  expected_initialize_type = computation_types.FunctionType(
      parameter=None, result=expected_state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(
          expected_initialize_type))

  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=expected_state_type,
          value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=expected_state_type,
          result=computation_types.at_server(value_type),
          measurements=expected_measurements_type))
  self.assertTrue(
      process.next.type_signature.is_equivalent_to(expected_next_type))

  try:
    static_assert.assert_not_contains_unsecure_aggregation(process.next)
  # Catch `Exception` rather than a bare `except:`, so that KeyboardInterrupt
  # and SystemExit are not turned into test failures. The original error
  # stays attached as implicit context on the failure.
  except Exception:  # pylint: disable=broad-except
    self.fail('Factory returned an AggregationProcess containing '
              'non-secure aggregation.')
def test_type_properties(self, value_type):
  """Checks SumFactory signatures: empty state and empty measurements."""
  sum_f = sum_factory.SumFactory()
  self.assertIsInstance(sum_f, factory.UnweightedAggregationFactory)
  tff_value_type = computation_types.to_type(value_type)
  process = sum_f.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  clients_value_type = computation_types.FederatedType(
      tff_value_type, placements.CLIENTS)
  server_value_type = computation_types.FederatedType(
      tff_value_type, placements.SERVER)
  state_type = computation_types.FederatedType((), placements.SERVER)
  measurements_type = computation_types.FederatedType((), placements.SERVER)

  init_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_type))

  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type, value=clients_value_type),
      result=measured_process.MeasuredProcessOutput(
          state_type, server_value_type, measurements_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_concat_type_properties_unweighted(self, value_type):
  """Checks signatures of `_concat_sum()` for unweighted inputs."""
  concat_factory = _concat_sum()
  tff_value_type = computation_types.to_type(value_type)
  process = concat_factory.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # Inner SumFactory has no state.
  state_type = computation_types.at_server(())
  init_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  type_test_utils.assert_types_equivalent(
      process.initialize.type_signature, init_type)

  # Inner SumFactory has no measurements.
  measurements_type = computation_types.at_server(())
  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type,
          value=computation_types.at_clients(tff_value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(tff_value_type),
          measurements=measurements_type))
  type_test_utils.assert_types_equivalent(process.next.type_signature,
                                          next_type)
def test_type_properties(self, metric_finalizers, unfinalized_metrics):
  """Checks SumThenFinalizeFactory signatures.

  The expected state pairs an empty element with the unfinalized-metrics
  type. The expected result is a pair of finalized-metrics structures, and
  the measurements are empty.
  """
  sum_then_finalize = aggregation_factory.SumThenFinalizeFactory()
  self.assertIsInstance(sum_then_finalize,
                        factory.UnweightedAggregationFactory)
  unfinalized_type = type_conversions.type_from_tensors(unfinalized_metrics)
  process = sum_then_finalize.create(metric_finalizers, unfinalized_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  state_type = computation_types.FederatedType(
      ((), unfinalized_type), placements.SERVER)
  init_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_type))

  finalized_type = _get_finalized_metrics_type(metric_finalizers,
                                               unfinalized_metrics)
  server_result_type = computation_types.FederatedType(
      (finalized_type, finalized_type), placements.SERVER)
  empty_measurements = computation_types.FederatedType((), placements.SERVER)
  clients_metrics_type = computation_types.FederatedType(
      unfinalized_type, placements.CLIENTS)
  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type, unfinalized_metrics=clients_metrics_type),
      result=measured_process.MeasuredProcessOutput(
          state_type, server_result_type, empty_measurements))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
class VisitPreorderTest(parameterized.TestCase):
  """Tests for `type_transformations.visit_preorder`."""

  # pyformat: disable
  @parameterized.named_parameters([
      ('abstract_type', computation_types.AbstractType('T'), 1),
      ('nested_function_type', computation_types.FunctionType(
          computation_types.FunctionType(
              computation_types.FunctionType(tf.int32, tf.int32),
              tf.int32),
          tf.int32), 7),
      ('named_tuple_type', computation_types.StructType(
          [tf.int32, tf.bool, computation_types.SequenceType(tf.int32)]), 5),
      ('placement_type', computation_types.PlacementType(), 1),
  ])
  # pyformat: enable
  def test_preorder_call_count(self, type_signature, expected_count):
    """The visit callback fires `expected_count` times for `type_signature`."""
    hits = 0

    def _count_hits(given_type, arg):
      nonlocal hits
      del given_type  # Unused.
      hits += 1
      return arg

    type_transformations.visit_preorder(type_signature, _count_hits, None)
    self.assertEqual(hits, expected_count)
def test_create_xla_tff_computation_with_reordered_tensor_indexes(self):
  """Tensor-index order decides which XLA parameter each TFF slot binds to.

  The same XLA Dot computation (int32[10,1] x int32[1,20]) is serialized
  twice, with indexes `[0, 1]` and `[1, 0]`. Each time the declared TFF
  parameter types are ordered to match.
  """
  builder = xla_client.XlaBuilder('comp')
  int32_etype = xla_client.dtype_to_etype(np.int32)
  lhs_shape = xla_client.Shape.array_shape(int32_etype, (10, 1))
  lhs = xla_client.ops.Parameter(builder, 0, lhs_shape)
  rhs_shape = xla_client.Shape.array_shape(int32_etype, (1, 20))
  rhs = xla_client.ops.Parameter(builder, 1, rhs_shape)
  xla_client.ops.Dot(lhs, rhs)
  xla_comp = builder.build()

  def _serialize_and_read_type(tensor_indexes, parameter_types):
    # Serializes `xla_comp`, checks the proto kind, returns the TFF type.
    comp_pb = xla_serialization.create_xla_tff_computation(
        xla_comp, tensor_indexes,
        computation_types.FunctionType(parameter_types,
                                       (np.int32, (10, 20))))
    self.assertIsInstance(comp_pb, pb.Computation)
    self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
    return type_serialization.deserialize_type(comp_pb.type)

  in_order_type = _serialize_and_read_type(
      [0, 1], ((np.int32, (10, 1)), (np.int32, (1, 20))))
  self.assertEqual(
      str(in_order_type), '(<int32[10,1],int32[1,20]> -> int32[10,20])')

  swapped_type = _serialize_and_read_type(
      [1, 0], ((np.int32, (1, 20)), (np.int32, (10, 1))))
  self.assertEqual(
      str(swapped_type), '(<int32[1,20],int32[10,1]> -> int32[10,20])')
def test_clip_type_properties_simple(self, value_type):
  """Checks signatures of the `_clipped_sum()` aggregation process."""
  clipping_agg_factory = _clipped_sum()
  tff_value_type = computation_types.to_type(value_type)
  process = clipping_agg_factory.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  state_type = computation_types.at_server(
      collections.OrderedDict(
          clipping_norm=(), inner_agg=(), clipped_count_agg=()))
  init_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_type))

  measurements_type = computation_types.at_server(
      collections.OrderedDict(
          clipping=(),
          clipping_norm=robust.NORM_TF_TYPE,
          clipped_count=robust.COUNT_TF_TYPE))
  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type,
          value=computation_types.at_clients(tff_value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(tff_value_type),
          measurements=measurements_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_type_properties(self, value_type):
  """Checks StochasticDiscretizationFactory signatures.

  The `stochastic_discretization` measurement is expected to mirror the
  value structure, with every tensor replaced by int32 of the same shape.
  The `distortion` measurement is a float32.
  """
  discretization_factory = (
      stochastic_discretization.StochasticDiscretizationFactory(
          step_size=0.1,
          inner_agg_factory=_measurement_aggregator,
          distortion_aggregation_factory=mean.UnweightedMeanFactory()))
  tff_value_type = computation_types.to_type(value_type)
  quantized_type = type_conversions.structure_from_tensor_type_tree(
      lambda tensor_type: (tf.int32, tensor_type.shape), tff_value_type)
  process = discretization_factory.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  state_type = computation_types.at_server(
      computation_types.StructType([('step_size', tf.float32),
                                    ('inner_agg_process', ())]))
  init_type = computation_types.FunctionType(
      parameter=None, result=state_type)
  type_test_utils.assert_types_equivalent(process.initialize.type_signature,
                                          init_type)

  measurements_type = computation_types.at_server(
      computation_types.StructType([
          ('stochastic_discretization', quantized_type),
          ('distortion', tf.float32),
      ]))
  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type,
          value=computation_types.at_clients(tff_value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(tff_value_type),
          measurements=measurements_type))
  type_test_utils.assert_types_equivalent(process.next.type_signature,
                                          next_type)
def test_roundtrip_no_broadcast(self):
  """Checks BroadcastForm round-tripping when nothing is broadcast."""
  add_five_fn = tensorflow_computation.tf_computation(lambda x: x + 5)
  server_data_type = computation_types.at_server(())
  client_data_type = computation_types.at_clients(tf.int32)

  # Parameter names matter: they become the form's data labels below.
  @federated_computation.federated_computation(server_data_type,
                                               client_data_type)
  def add_five_at_clients(naught_at_server, client_numbers):
    del naught_at_server
    return intrinsics.federated_map(add_five_fn, client_numbers)

  broadcast_form = form_utils.get_broadcast_form_for_computation(
      add_five_at_clients)
  self.assertEqual(broadcast_form.server_data_label, 'naught_at_server')
  self.assertEqual(broadcast_form.client_data_label, 'client_numbers')

  type_test_utils.assert_types_equivalent(
      broadcast_form.compute_server_context.type_signature,
      computation_types.FunctionType((), ()))
  type_test_utils.assert_types_equivalent(
      broadcast_form.client_processing.type_signature,
      computation_types.FunctionType(((), tf.int32), tf.int32))
  self.assertEqual(6, broadcast_form.client_processing((), 1))

  rebuilt_comp = form_utils.get_computation_for_broadcast_form(
      broadcast_form)
  type_test_utils.assert_types_equivalent(rebuilt_comp.type_signature,
                                          add_five_at_clients.type_signature)
  self.assertEqual([10, 11, 12], rebuilt_comp((), [5, 6, 7]))
def test_equality(self):
  """FunctionType equality compares both parameter and result types."""
  int_to_bool = computation_types.FunctionType(tf.int32, tf.bool)
  int_to_bool_copy = computation_types.FunctionType(tf.int32, tf.bool)
  int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
  bool_to_bool = computation_types.FunctionType(tf.bool, tf.bool)
  self.assertEqual(int_to_bool, int_to_bool_copy)
  # Differing result type.
  self.assertNotEqual(int_to_bool, int_to_int)
  # Differing parameter type.
  self.assertNotEqual(int_to_bool, bool_to_bool)
def test_returns_true_for_calls_with_no_arguments(self):
  """Independently built no-arg calls of reference `a` are equal trees."""
  calls = []
  for _ in range(2):
    fn_type = computation_types.FunctionType(None, tf.int32)
    fn_ref = building_blocks.Reference('a', fn_type)
    calls.append(building_blocks.Call(fn_ref))
  self.assertTrue(tree_analysis.trees_equal(calls[0], calls[1]))
class FnToBuildingBlockTest(parameterized.TestCase):
  """Tests tracing Python functions into federated building blocks."""

  # pyformat: disable
  @parameterized.named_parameters(
      ('nested_fn_same',
       lambda f, x: f(f(x)),
       computation_types.StructType(
           [('f', computation_types.FunctionType(tf.int32, tf.int32)),
            ('x', tf.int32)]),
       '(FEDERATED_foo -> (let fc_FEDERATED_symbol_0=FEDERATED_foo.f(FEDERATED_foo.x),fc_FEDERATED_symbol_1=FEDERATED_foo.f(fc_FEDERATED_symbol_0) in fc_FEDERATED_symbol_1))'
      ),
      ('nested_fn_different',
       lambda f, g, x: f(g(x)),
       computation_types.StructType(
           [('f', computation_types.FunctionType(tf.int32, tf.int32)),
            ('g', computation_types.FunctionType(tf.int32, tf.int32)),
            ('x', tf.int32)]),
       '(FEDERATED_foo -> (let fc_FEDERATED_symbol_0=FEDERATED_foo.g(FEDERATED_foo.x),fc_FEDERATED_symbol_1=FEDERATED_foo.f(fc_FEDERATED_symbol_0) in fc_FEDERATED_symbol_1))'
      ),
      ('selection',
       lambda x: (x[1], x[0]),
       computation_types.StructType([tf.int32, tf.int32]),
       '(FEDERATED_foo -> <FEDERATED_foo[1],FEDERATED_foo[0]>)'),
      ('constant', lambda: 'stuff', None,
       '( -> (let fc_FEDERATED_symbol_0=comp#'))
  # pyformat: enable
  def test_returns_result(self, fn, parameter_type, fn_str):
    """The serialized building block's string form starts with `fn_str`."""
    parameter_name = 'foo' if parameter_type is not None else None
    result, _ = _federated_computation_serializer(fn, parameter_name,
                                                  parameter_type)
    self.assertStartsWith(str(result), fn_str)

  # pyformat: disable
  @parameterized.named_parameters(
      ('tuple',
       lambda x: (x[1], x[0]),
       computation_types.StructType([tf.int32, tf.float32]),
       computation_types.StructWithPythonType([(None, tf.float32),
                                               (None, tf.int32)],
                                              tuple)),
      ('list',
       lambda x: [x[1], x[0]],
       computation_types.StructType([tf.int32, tf.float32]),
       computation_types.StructWithPythonType([(None, tf.float32),
                                               (None, tf.int32)],
                                              list)),
      ('odict',
       lambda x: collections.OrderedDict([('A', x[1]), ('B', x[0])]),
       computation_types.StructType([tf.int32, tf.float32]),
       computation_types.StructWithPythonType([('A', tf.float32),
                                               ('B', tf.int32)],
                                              collections.OrderedDict)),
      ('namedtuple',
       lambda x: TestNamedTuple(x=x[1], y=x[0]),
       computation_types.StructType([tf.int32, tf.float32]),
       computation_types.StructWithPythonType([('x', tf.float32),
                                               ('y', tf.int32)],
                                              TestNamedTuple)),
  )
  # pyformat: enable
  def test_returns_result_with_py_container(self, fn, parameter_type,
                                            expected_result_type):
    """The traced result type keeps the Python container of `fn`'s output.

    `expected_result_type` was previously misspelled `exepcted_result_type`.
    The rename is safe because `named_parameters` passes the tuple entries
    positionally.
    """
    _, type_signature = _federated_computation_serializer(
        fn, 'foo', parameter_type)
    self.assertIs(type(type_signature.result), type(expected_result_type))
    self.assertIs(type_signature.result.python_container,
                  expected_result_type.python_container)
    self.assertEqual(type_signature.result, expected_result_type)
def test_construction(self, weighted):
  """Checks signatures of the model-delta optimizer iterative process.

  Args:
    weighted: If true, updates are aggregated with `mean.MeanFactory` (which
      carries state and measurements). Otherwise `sum_factory.SumFactory` is
      used, with empty state and measurements.
  """
  if weighted:
    update_agg_factory = mean.MeanFactory()
    agg_state_type = collections.OrderedDict(
        value_sum_process=(), weight_sum_process=())
    agg_metrics_type = collections.OrderedDict(mean_value=(), mean_weight=())
  else:
    update_agg_factory = sum_factory.SumFactory()
    agg_state_type = ()
    agg_metrics_type = ()

  iterative_process = optimizer_utils.build_model_delta_optimizer_process(
      model_fn=model_examples.LinearRegression,
      model_to_client_delta_fn=DummyClientDeltaFn,
      server_optimizer_fn=tf.keras.optimizers.SGD,
      model_update_aggregation_factory=update_agg_factory)

  model_weights_type = model_utils.ModelWeights(
      trainable=[
          computation_types.TensorType(tf.float32, [2, 1]),
          computation_types.TensorType(tf.float32),
      ],
      non_trainable=[computation_types.TensorType(tf.float32)])
  server_state_type = computation_types.FederatedType(
      optimizer_utils.ServerState(
          model=model_weights_type,
          optimizer_state=[tf.int64],
          delta_aggregate_state=agg_state_type,
          model_broadcast_state=()),
      placements.SERVER)
  self.assert_types_equivalent(
      computation_types.FunctionType(parameter=None, result=server_state_type),
      iterative_process.initialize.type_signature)

  client_dataset_type = computation_types.FederatedType(
      computation_types.SequenceType(
          collections.OrderedDict(
              x=computation_types.TensorType(tf.float32, [None, 2]),
              y=computation_types.TensorType(tf.float32, [None, 1]))),
      placements.CLIENTS)
  train_metrics = collections.OrderedDict(
      loss=computation_types.TensorType(tf.float32),
      num_examples=computation_types.TensorType(tf.int32))
  stat_metrics = collections.OrderedDict(
      num_examples=computation_types.TensorType(tf.float32))
  server_metrics_type = computation_types.FederatedType(
      collections.OrderedDict(
          broadcast=(),
          aggregation=agg_metrics_type,
          train=train_metrics,
          stat=stat_metrics),
      placements.SERVER)
  self.assert_types_equivalent(
      computation_types.FunctionType(
          parameter=collections.OrderedDict(
              server_state=server_state_type,
              federated_dataset=client_dataset_type,
          ),
          result=(server_state_type, server_metrics_type)),
      iterative_process.next.type_signature)
def test_passes_with_federated_map(self):
  """A `federated_map` intrinsic passes the reducible-intrinsics check."""
  mapped_fn_type = computation_types.FunctionType(tf.int32, tf.float32)
  clients_ints = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  clients_floats = computation_types.FederatedType(tf.float32,
                                                   placements.CLIENTS)
  map_intrinsic = building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_MAP.uri,
      computation_types.FunctionType([mapped_fn_type, clients_ints],
                                     clients_floats))
  tree_analysis.check_contains_only_reducible_intrinsics(map_intrinsic)