def test_type_properties(self, value_type, factory_cons):
  """Checks the `initialize` and `next` type signatures of the built process.

  Args:
    value_type: Anything `computation_types.to_type` accepts; used as the
      client value type passed to `create`.
    factory_cons: Zero-argument callable returning the factory under test.
  """
  agg_factory = factory_cons()
  value_type = computation_types.to_type(value_type)
  process = agg_factory.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # State is a pair: an empty tuple followed by the two inner process states.
  inner_entries = collections.OrderedDict(
      value_sum_process=(), weight_sum_process=())
  state_type = computation_types.at_server(((), inner_entries))
  init_type = computation_types.FunctionType(parameter=None, result=state_type)
  self.assertTrue(process.initialize.type_signature.is_equivalent_to(init_type))

  metrics_type = computation_types.at_server(
      collections.OrderedDict(value_sum_process=(), weight_sum_process=()))
  next_param = collections.OrderedDict(
      state=state_type,
      value=computation_types.at_clients(value_type),
      weight=computation_types.at_clients(tf.float32))
  next_result = measured_process.MeasuredProcessOutput(
      state=state_type,
      result=computation_types.at_server(value_type),
      measurements=metrics_type)
  next_type = computation_types.FunctionType(
      parameter=next_param, result=next_result)
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_concat_type_properties_unweighted(self, value_type):
  """Checks type signatures of the process built by `_concat_sum()`.

  Args:
    value_type: Client value type spec, converted with `to_type`.
  """
  concat_factory = _concat_sum()
  value_type = computation_types.to_type(value_type)
  process = concat_factory.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # The wrapped SumFactory has neither state nor measurements, so both are
  # empty server-placed structures.
  empty_at_server = computation_types.at_server(())

  type_test_utils.assert_types_equivalent(
      process.initialize.type_signature,
      computation_types.FunctionType(parameter=None, result=empty_at_server))

  next_param = collections.OrderedDict(
      state=empty_at_server, value=computation_types.at_clients(value_type))
  next_result = measured_process.MeasuredProcessOutput(
      state=empty_at_server,
      result=computation_types.at_server(value_type),
      measurements=empty_at_server)
  type_test_utils.assert_types_equivalent(
      process.next.type_signature,
      computation_types.FunctionType(parameter=next_param, result=next_result))
def next_fn(state, value):
  """Modular-clips client values, aggregates them, then clips the aggregate.

  Args:
    state: Server-placed state handed directly to `inner_agg_next`.
    value: Client-placed values to aggregate.

  Returns:
    A `MeasuredProcessOutput` with the following parts:
      - state: the inner aggregation's state.
      - result: the aggregate, modular-clipped to the same range.
      - measurements: the inner measurements under `modclip`, plus
        `estimated_stddev` when `self._estimate_stddev` is set.
  """
  server_lower = intrinsics.federated_value(self._clip_range_lower,
                                            placements.SERVER)
  server_upper = intrinsics.federated_value(self._clip_range_upper,
                                            placements.SERVER)

  # Modular clip on the clients, before aggregation.
  client_lower = intrinsics.federated_broadcast(server_lower)
  client_upper = intrinsics.federated_broadcast(server_upper)
  clipped_on_clients = intrinsics.federated_map(
      modular_clip_by_value_fn, (value, client_lower, client_upper))

  inner_output = inner_agg_next(state, clipped_on_clients)

  # Clip the aggregate back into the same range (summands are not re-clipped).
  aggregate = intrinsics.federated_map(
      modular_clip_by_value_fn,
      (inner_output.result, server_lower, server_upper))

  metrics = collections.OrderedDict(modclip=inner_output.measurements)
  if self._estimate_stddev:
    metrics['estimated_stddev'] = intrinsics.federated_map(
        estimator_fn, (aggregate, server_lower, server_upper))

  return measured_process.MeasuredProcessOutput(
      state=inner_output.state,
      result=aggregate,
      measurements=intrinsics.federated_zip(metrics))
def test_type_properties_constant_bounds(self, value_type, upper_bound,
                                         lower_bound, measurements_dtype):
  """Checks process types when both secure-sum bounds are constants.

  Args:
    value_type: Client value type spec, converted with `to_type`.
    upper_bound: Passed as `upper_bound_threshold`.
    lower_bound: Passed as `lower_bound_threshold`.
    measurements_dtype: Forwarded to `_measurements_type` to build the
      expected measurements type.
  """
  secure_sum_f = secure_factory.SecureSumFactory(
      upper_bound_threshold=upper_bound, lower_bound_threshold=lower_bound)
  self.assertIsInstance(secure_sum_f, factory.UnweightedAggregationFactory)

  value_type = computation_types.to_type(value_type)
  process = secure_sum_f.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # The process keeps an empty server-placed state.
  state_type = computation_types.at_server(computation_types.to_type(()))
  metrics_type = _measurements_type(measurements_dtype)

  init_type = computation_types.FunctionType(parameter=None, result=state_type)
  self.assertTrue(process.initialize.type_signature.is_equivalent_to(init_type))

  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type, value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(value_type),
          measurements=metrics_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def next_fn(state, value):
  """Discretizes client values, aggregates them, and undiscretizes the sum.

  Args:
    state: Server-placed struct read via `['scale_factor']`,
      `['prior_norm_bound']` and `['inner_agg_process']`.
    value: Client-placed values to aggregate.

  Returns:
    A `MeasuredProcessOutput` with the following parts:
      - state: the zipped new state.
      - result: the undiscretized aggregate.
      - measurements: the inner measurements under `discretize`.
  """
  scale_on_server = state['scale_factor']
  scale_on_clients = intrinsics.federated_broadcast(scale_on_server)
  norm_bound_on_server = state['prior_norm_bound']
  norm_bound_on_clients = intrinsics.federated_broadcast(norm_bound_on_server)

  discretized = intrinsics.federated_map(
      discretize_fn, (value, scale_on_clients, norm_bound_on_clients))
  inner_output = inner_agg_process.next(state['inner_agg_process'],
                                        discretized)
  aggregate = intrinsics.federated_map(
      undiscretize_fn, (inner_output.result, scale_on_server))

  # Scale factor and norm bound are carried forward unchanged; only the
  # inner aggregation state advances.
  next_state = collections.OrderedDict(
      scale_factor=scale_on_server,
      prior_norm_bound=norm_bound_on_server,
      inner_agg_process=inner_output.state)
  metrics = collections.OrderedDict(discretize=inner_output.measurements)

  return measured_process.MeasuredProcessOutput(
      state=intrinsics.federated_zip(next_state),
      result=aggregate,
      measurements=intrinsics.federated_zip(metrics))
def next_fn(state, value):
  """Discretizes with the current step size, aggregates, and undiscretizes.

  Args:
    state: Server-placed struct read via `['step_size']` and
      `['inner_agg_process']`.
    value: Client-placed values to aggregate.

  Returns:
    A `MeasuredProcessOutput` with the following parts:
      - state: the zipped new state.
      - result: the undiscretized aggregate.
      - measurements: the inner measurements under
        `deterministic_discretization`, plus `distortion` when a distortion
        aggregation factory was configured.
  """
  step_on_server = state['step_size']
  step_on_clients = intrinsics.federated_broadcast(step_on_server)
  discretized = intrinsics.federated_map(discretize_fn,
                                         (value, step_on_clients))

  inner_output = inner_agg_process.next(state['inner_agg_process'],
                                        discretized)
  aggregate = intrinsics.federated_map(
      undiscretize_fn, (inner_output.result, step_on_server))

  next_state = collections.OrderedDict(
      step_size=step_on_server, inner_agg_process=inner_output.state)
  metrics = collections.OrderedDict(
      deterministic_discretization=inner_output.measurements)

  if self._distortion_aggregation_factory is not None:
    per_client_distortion = intrinsics.federated_map(
        distortion_measurement_fn, (value, step_on_clients))
    # The distortion process is re-initialized on every call; its state is
    # discarded and only the aggregated result is reported.
    fresh_state = distortion_aggregation_process.initialize()
    metrics['distortion'] = distortion_aggregation_process.next(
        fresh_state, per_client_distortion).result

  return measured_process.MeasuredProcessOutput(
      state=intrinsics.federated_zip(next_state),
      result=aggregate,
      measurements=intrinsics.federated_zip(metrics))
def test_type_properties(self, value_type):
  """Checks the types of the `SumPlusOneFactory` test process.

  Args:
    value_type: Client value type spec, converted with `to_type`.
  """
  sum_f = test_utils.SumPlusOneFactory()
  self.assertIsInstance(sum_f, factory.AggregationProcessFactory)

  value_type = computation_types.to_type(value_type)
  process = sum_f.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  clients_value = computation_types.FederatedType(value_type,
                                                  placements.CLIENTS)
  server_value = computation_types.FederatedType(value_type, placements.SERVER)
  # Both state and measurements are server-placed int32 scalars.
  state_type = computation_types.FederatedType(tf.int32, placements.SERVER)
  metrics_type = computation_types.FederatedType(tf.int32, placements.SERVER)

  init_type = computation_types.FunctionType(parameter=None, result=state_type)
  self.assertTrue(process.initialize.type_signature.is_equivalent_to(init_type))

  next_param = collections.OrderedDict(state=state_type, value=clients_value)
  next_result = measured_process.MeasuredProcessOutput(
      state_type, server_value, metrics_type)
  next_type = computation_types.FunctionType(
      parameter=next_param, result=next_result)
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_type_properties(self, value_type, weight_type):
  """Checks the types of the weighted process built by `MeanFactory`.

  Args:
    value_type: Client value type spec, converted with `to_type`.
    weight_type: Client weight type spec, converted with `to_type`.
  """
  mean_f = mean_factory.MeanFactory()
  self.assertIsInstance(mean_f, factory.WeightedAggregationFactory)

  value_type = computation_types.to_type(value_type)
  weight_type = computation_types.to_type(weight_type)
  process = mean_f.create_weighted(value_type, weight_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  clients_value = computation_types.FederatedType(value_type,
                                                  placements.CLIENTS)
  server_value = computation_types.FederatedType(value_type, placements.SERVER)
  # State and measurements share the same shape: one entry per inner sum.
  state_type = computation_types.at_server(
      collections.OrderedDict(value_sum_process=(), weight_sum_process=()))
  metrics_type = computation_types.at_server(
      collections.OrderedDict(value_sum_process=(), weight_sum_process=()))

  init_type = computation_types.FunctionType(parameter=None, result=state_type)
  self.assertTrue(process.initialize.type_signature.is_equivalent_to(init_type))

  next_param = collections.OrderedDict(
      state=state_type,
      value=clients_value,
      weight=computation_types.at_clients(weight_type))
  next_result = measured_process.MeasuredProcessOutput(
      state_type, server_value, metrics_type)
  next_type = computation_types.FunctionType(
      parameter=next_param, result=next_result)
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def stateful_broadcast(state, value):
  """Broadcasts `value` to clients and passes `state` through unchanged.

  Args:
    state: Returned as-is as the process state.
    value: Server-placed value to broadcast.

  Returns:
    A `MeasuredProcessOutput` whose measurements are the server-placed
    constant `3.0`.
  """
  constant_metric = intrinsics.federated_value(3.0, placements.SERVER)
  broadcast_value = intrinsics.federated_broadcast(value)
  return measured_process_lib.MeasuredProcessOutput(
      state=state, result=broadcast_value, measurements=constant_metric)
def test_non_server_placed_next_result_raises(self):
  """A `next` whose result is CLIENTS-placed must be rejected with TypeError."""

  def initial_state():
    return intrinsics.federated_value(0, placements.SERVER)

  def client_result_next(x, y):
    # `y` is CLIENTS-placed and is returned as the result, which is invalid.
    return measured_process.MeasuredProcessOutput(x, y, x)

  init_fn = computations.federated_computation(initial_state)
  next_fn = computations.federated_computation(SERVER_INT, CLIENTS_INT)(
      client_result_next)

  with self.assertRaises(TypeError):
    aggregation_process.AggregationProcess(init_fn, next_fn)
def test_type_properties(self, value_type):
  """Checks types of the stochastic discretization process.

  The process is built with step size 0.1, `_measurement_aggregator` as the
  inner aggregator, and an unweighted mean for distortion.

  Args:
    value_type: Client value type spec, converted with `to_type`.
  """
  discretization_factory = (
      stochastic_discretization.StochasticDiscretizationFactory(
          step_size=0.1,
          inner_agg_factory=_measurement_aggregator,
          distortion_aggregation_factory=mean.UnweightedMeanFactory()))
  value_type = computation_types.to_type(value_type)
  # Expected inner measurement type: int32 tensors with the input shapes.
  quantized_type = type_conversions.structure_from_tensor_type_tree(
      lambda t: (tf.int32, t.shape), value_type)
  process = discretization_factory.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  state_struct = computation_types.StructType([('step_size', tf.float32),
                                               ('inner_agg_process', ())])
  state_type = computation_types.at_server(state_struct)
  type_test_utils.assert_types_equivalent(
      process.initialize.type_signature,
      computation_types.FunctionType(parameter=None, result=state_type))

  metrics_struct = computation_types.StructType([
      ('stochastic_discretization', quantized_type),
      ('distortion', tf.float32),
  ])
  metrics_type = computation_types.at_server(metrics_struct)

  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type, value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(value_type),
          measurements=metrics_type))
  type_test_utils.assert_types_equivalent(process.next.type_signature,
                                          next_type)
def test_single_param_next_raises(self):
  """A `next` taking only the state (no value) must be rejected."""

  def initial_state():
    return intrinsics.federated_value(0, placements.SERVER)

  def state_only_next(x):
    # Only one parameter: there is no value argument to aggregate.
    return measured_process.MeasuredProcessOutput(x, x, x)

  init_fn = computations.federated_computation(initial_state)
  next_fn = computations.federated_computation(SERVER_INT)(state_only_next)

  with self.assertRaises(TypeError):
    aggregation_process.AggregationProcess(init_fn, next_fn)
def test_type_properties_simple(self, value_type, estimate_stddev):
  """Checks process types with and without the stddev estimate.

  Args:
    value_type: Client value type spec.
    estimate_stddev: When true, measurements also carry a float32
      `estimated_stddev` entry.
  """
  test_factory = _test_factory(estimate_stddev=estimate_stddev)
  process = test_factory.create(computation_types.to_type(value_type))
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # The inner SumFactory is stateless.
  state_type = computation_types.at_server(())
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(
          computation_types.FunctionType(parameter=None, result=state_type)))

  metric_fields = collections.OrderedDict(modclip=())
  if estimate_stddev:
    metric_fields['estimated_stddev'] = tf.float32

  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type, value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(value_type),
          measurements=computation_types.at_server(metric_fields)))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_clip_type_properties_with_clipped_count_agg_factory(
    self, value_type):
  """Checks types when a custom clipped-count aggregator is supplied.

  With `SumPlusOneFactory` as the clipped-count aggregator, the expected
  `clipped_count_agg` state is an int32.

  Args:
    value_type: Client value type spec, converted with `to_type`.
  """
  clip_factory = robust.clipping_factory(
      clipping_norm=1.0,
      inner_agg_factory=sum_factory.SumFactory(),
      clipped_count_sum_factory=aggregator_test_utils.SumPlusOneFactory())

  value_type = computation_types.to_type(value_type)
  process = clip_factory.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  state_type = computation_types.at_server(
      collections.OrderedDict(
          clipping_norm=(), inner_agg=(), clipped_count_agg=tf.int32))
  init_type = computation_types.FunctionType(parameter=None, result=state_type)
  self.assertTrue(process.initialize.type_signature.is_equivalent_to(init_type))

  metrics_type = computation_types.at_server(
      collections.OrderedDict(
          clipping=(),
          clipping_norm=robust.NORM_TF_TYPE,
          clipped_count=robust.COUNT_TF_TYPE))
  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type, value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(value_type),
          measurements=metrics_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_type_properties(self, metric_finalizers, unfinalized_metrics):
  """Checks types of the SumThenFinalize process for the given metrics.

  Args:
    metric_finalizers: Finalizers passed to `create`.
    unfinalized_metrics: Example unfinalized metric tensors; their type is
      used as the client-side input type.
  """
  sum_then_finalize = aggregation_factory.SumThenFinalizeFactory()
  self.assertIsInstance(sum_then_finalize,
                        factory.UnweightedAggregationFactory)

  client_metrics_type = type_conversions.type_from_tensors(unfinalized_metrics)
  process = sum_then_finalize.create(metric_finalizers, client_metrics_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # State pairs an empty inner state with a value of the metrics type.
  state_type = computation_types.FederatedType(((), client_metrics_type),
                                               placements.SERVER)
  init_type = computation_types.FunctionType(parameter=None, result=state_type)
  self.assertTrue(process.initialize.type_signature.is_equivalent_to(init_type))

  finalized_type = _get_finalized_metrics_type(metric_finalizers,
                                               unfinalized_metrics)
  # The result pairs two finalized-metrics structures.
  result_type = computation_types.FederatedType(
      (finalized_type, finalized_type), placements.SERVER)
  empty_metrics = computation_types.FederatedType((), placements.SERVER)

  next_param = collections.OrderedDict(
      state=state_type,
      unfinalized_metrics=computation_types.FederatedType(
          client_metrics_type, placements.CLIENTS))
  next_type = computation_types.FunctionType(
      parameter=next_param,
      result=measured_process.MeasuredProcessOutput(
          state_type, result_type, empty_metrics))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def next_comp(state):
  """Returns a `MeasuredProcessOutput` zipped into one server-placed value.

  The three server-placed fields (state, constant `0` result, empty
  measurements) are combined with `federated_zip`. The result is therefore a
  single federated struct, not a `MeasuredProcessOutput` of federated values.
  NOTE(review): this presumably exists to exercise a rejection path in a
  process constructor; confirm against the calling test.

  Args:
    state: Server-placed state, placed into the `state` field.

  Returns:
    The federated-zipped `MeasuredProcessOutput`. A stray trailing comma after
    the `return` expression previously wrapped this in a 1-tuple.
  """
  output = measured_process.MeasuredProcessOutput(
      state=state,
      result=intrinsics.federated_value(0, placements.SERVER),
      measurements=intrinsics.federated_value((), placements.SERVER))
  return intrinsics.federated_zip(output)
def next_fn_impl(state, value, clip_fn, inner_agg_process, weight=None):
  """Clips client values by the current norm, aggregates, and counts clips.

  Args:
    state: Three-element state unpacked as (clipping-norm state, inner
      aggregation state, clipped-count aggregation state).
    value: Client-placed values to clip and aggregate.
    clip_fn: Computation applied to (value, broadcast clipping norm). Its
      result is unpacked as (clipped value, global norm, was-clipped flag).
    inner_agg_process: Process whose `next` aggregates the clipped values.
    weight: Optional client weights, forwarded to `inner_agg_process.next`
      when not None.

  Returns:
    A `MeasuredProcessOutput` with the following parts:
      - state: the zipped new state.
      - result: the inner aggregate.
      - measurements: the zipped measurements; keys are built with `prefix`.
  """
  norm_state, inner_state, count_state = state

  norm = clipping_norm_process.report(norm_state)
  clipped, global_norm, was_clipped = intrinsics.federated_map(
      clip_fn, (value, intrinsics.federated_broadcast(norm)))

  updated_norm_state = clipping_norm_process.next(norm_state, global_norm)

  inner_args = ((inner_state, clipped) if weight is None else
                (inner_state, clipped, weight))
  inner_output = inner_agg_process.next(*inner_args)

  count_output = clipped_count_agg_process.next(count_state, was_clipped)

  next_state = collections.OrderedDict([
      (prefix('ing_norm'), updated_norm_state),
      ('inner_agg', inner_output.state),
      (prefix('ed_count_agg'), count_output.state),
  ])
  metrics = collections.OrderedDict([
      (prefix('ing'), inner_output.measurements),
      (prefix('ing_norm'), norm),
      (prefix('ed_count'), count_output.result),
  ])
  return measured_process.MeasuredProcessOutput(
      state=intrinsics.federated_zip(next_state),
      result=inner_output.result,
      measurements=intrinsics.federated_zip(metrics))
def next_fn(state,
            unfinalized_metrics) -> measured_process.MeasuredProcessOutput:
  """Sums client metrics, adds them to the running total, and finalizes both.

  Args:
    state: Pair of (inner summation state, accumulated unfinalized metrics).
    unfinalized_metrics: Client-placed unfinalized metrics for this round.

  Returns:
    A `MeasuredProcessOutput` with the following parts:
      - state: the zipped (new inner state, new accumulators).
      - result: the zipped (this round's finalized metrics, finalized
        accumulated metrics).
      - measurements: the inner summation's measurements.
  """
  prior_inner_state, prior_accumulators = state

  summation_output = inner_summation_process.next(prior_inner_state,
                                                  unfinalized_metrics)
  round_sum = summation_output.result

  @tensorflow_computation.tf_computation(local_unfinalized_metrics_type,
                                         local_unfinalized_metrics_type)
  def add_unfinalized_metrics(unfinalized_metrics, summed_unfinalized_metrics):
    # Element-wise addition over the (possibly nested) metrics structure.
    return tf.nest.map_structure(tf.add, unfinalized_metrics,
                                 summed_unfinalized_metrics)

  updated_accumulators = intrinsics.federated_map(
      add_unfinalized_metrics, (prior_accumulators, round_sum))

  finalize = _build_finalizer_computation(metric_finalizers,
                                          local_unfinalized_metrics_type)
  this_round = intrinsics.federated_map(finalize, round_sum)
  all_rounds = intrinsics.federated_map(finalize, updated_accumulators)

  return measured_process.MeasuredProcessOutput(
      state=intrinsics.federated_zip(
          (summation_output.state, updated_accumulators)),
      result=intrinsics.federated_zip((this_round, all_rounds)),
      measurements=summation_output.measurements)
def test_zero_type_properties_weighted(self, value_type, weight_type):
  """Checks types of the weighted zeroing process built by `_zeroed_mean()`.

  Args:
    value_type: Client value type spec, converted with `to_type`.
    weight_type: Client weight type spec, converted with `to_type`.
  """
  zeroing_factory = _zeroed_mean()
  value_type = computation_types.to_type(value_type)
  weight_type = computation_types.to_type(weight_type)
  process = zeroing_factory.create(value_type, weight_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # The inner mean's state is nested under `inner_agg`.
  inner_mean_state = collections.OrderedDict(
      value_sum_process=(), weight_sum_process=())
  state_type = computation_types.at_server(
      collections.OrderedDict(
          zeroing_norm=(), inner_agg=inner_mean_state, zeroed_count_agg=()))
  init_type = computation_types.FunctionType(parameter=None, result=state_type)
  self.assertTrue(process.initialize.type_signature.is_equivalent_to(init_type))

  metrics_type = computation_types.at_server(
      collections.OrderedDict(
          zeroing=collections.OrderedDict(mean_value=(), mean_weight=()),
          zeroing_norm=robust_factory.NORM_TF_TYPE,
          zeroed_count=robust_factory.COUNT_TF_TYPE))
  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type,
          value=computation_types.at_clients(value_type),
          weight=computation_types.at_clients(weight_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(value_type),
          measurements=metrics_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_type_properties_simple(self):
  """Checks types for an int32 vector of length 2 with clip range [-2, 2]."""
  value_type = computation_types.to_type((tf.int32, (2,)))
  modclip_factory = modular_clipping_factory.ModularClippingSumFactory(
      clip_range_lower=-2, clip_range_upper=2)
  process = modclip_factory.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # The inner SumFactory is stateless.
  state_type = computation_types.at_server(())
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(
          computation_types.FunctionType(parameter=None, result=state_type)))

  metrics_type = computation_types.at_server(
      collections.OrderedDict(modclip=()))
  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type, value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(value_type),
          measurements=metrics_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def next_fn(state, value, weight):
  """Computes the weighted mean of client values.

  Clients multiply value by weight, and two inner processes sum the products
  and the weights. The server then divides the two sums, using `_div_no_nan`
  when `self._no_nan_division` is set and `_div` otherwise.

  Args:
    state: Server-placed struct with `value_sum_process` and
      `weight_sum_process` entries.
    value: Client-placed values.
    weight: Client-placed weights.

  Returns:
    A `MeasuredProcessOutput` with the following parts:
      - state: the zipped inner states.
      - result: the weighted mean.
      - measurements: the zipped inner measurements.
  """
  # Client side: weight each value.
  products = intrinsics.federated_map(_mul, (value, weight))

  # Inner aggregations.
  value_out = value_sum_process.next(state['value_sum_process'], products)
  weight_out = weight_sum_process.next(state['weight_sum_process'], weight)

  # Server side: divide the two sums.
  divide = _div_no_nan if self._no_nan_division else _div
  mean_value = intrinsics.federated_map(
      divide, (value_out.result, weight_out.result))

  new_state = collections.OrderedDict(
      value_sum_process=value_out.state,
      weight_sum_process=weight_out.state)
  metrics = collections.OrderedDict(
      value_sum_process=value_out.measurements,
      weight_sum_process=weight_out.measurements)
  return measured_process.MeasuredProcessOutput(
      state=intrinsics.federated_zip(new_state),
      result=mean_value,
      measurements=intrinsics.federated_zip(metrics))
def next_fn(global_state, value, weight):
  """Defines next_fn for MeasuredProcess.

  Sample params derived from the global state are broadcast, and each client
  record is preprocessed with them. The records are then combined with
  `federated_aggregate` and post-processed together with the global state.

  Args:
    global_state: Server-placed query state.
    value: Client-placed records.
    weight: Ignored; weighted aggregation is not supported.

  Returns:
    A `MeasuredProcessOutput` with the following parts:
      - state: the updated state.
      - result: the aggregated result.
      - measurements: metrics derived from the updated state.
  """
  # Weighted aggregation is not supported.
  # TODO(b/140236959): Add an assertion that weight is None here, so the
  # contract of this method is better established. Will likely cause some
  # downstream breaks.
  del weight

  server_params = intrinsics.federated_map(derive_sample_params, global_state)
  preprocessed = intrinsics.federated_map(
      preprocess_record,
      (intrinsics.federated_broadcast(server_params), value))
  aggregate = intrinsics.federated_aggregate(preprocessed, zero(), accumulate,
                                             merge, report)
  new_state, result = intrinsics.federated_map(post_process,
                                               (aggregate, global_state))
  return measured_process.MeasuredProcessOutput(
      state=new_state,
      result=result,
      measurements=intrinsics.federated_map(derive_metrics, new_state))
def test_clip_type_properties_simple(self, value_type):
  """Checks types of the unweighted process built by `_clipped_sum()`.

  Args:
    value_type: Client value type spec, converted with `to_type`.
  """
  clipped_sum_factory = _clipped_sum()
  value_type = computation_types.to_type(value_type)
  process = clipped_sum_factory.create_unweighted(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  state_type = computation_types.at_server(
      collections.OrderedDict(
          clipping_norm=(), inner_agg=(), clipped_count_agg=()))
  init_type = computation_types.FunctionType(parameter=None, result=state_type)
  self.assertTrue(process.initialize.type_signature.is_equivalent_to(init_type))

  metrics_type = computation_types.at_server(
      collections.OrderedDict(
          agg_process=(),
          clipping_norm=clipping_factory.NORM_TF_TYPE,
          clipped_count=clipping_factory.COUNT_TF_TYPE))
  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type, value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(value_type),
          measurements=metrics_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_type_properties_unweighted(self, value_type):
  """Checks types of the process built by `UnweightedMeanFactory`.

  Args:
    value_type: Client value type spec, converted with `to_type`.
  """
  value_type = computation_types.to_type(value_type)
  mean_f = mean.UnweightedMeanFactory()
  self.assertIsInstance(mean_f, factory.UnweightedAggregationFactory)
  process = mean_f.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  clients_value = computation_types.at_clients(value_type)
  server_value = computation_types.at_server(value_type)
  state_type = computation_types.at_server(())
  metrics_type = computation_types.at_server(
      collections.OrderedDict(mean_value=()))

  init_type = computation_types.FunctionType(parameter=None, result=state_type)
  self.assertTrue(process.initialize.type_signature.is_equivalent_to(init_type))

  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(state=state_type, value=clients_value),
      result=measured_process.MeasuredProcessOutput(
          state_type, server_value, metrics_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def next_fn(state, value):
  """Encodes client values, aggregates them, decodes the sum, updates state.

  Args:
    state: Server-placed encoder state; `get_params_fn` maps it to (encode
      params, decode-before-sum params, decode-after-sum params).
    value: Client-placed values to encode and aggregate.

  Returns:
    A `MeasuredProcessOutput` with the following parts:
      - state: the updated state.
      - result: the decoded aggregate.
      - measurements: an empty server-placed struct.
  """
  encode_params, before_sum_params, after_sum_params = (
      intrinsics.federated_map(get_params_fn, state))
  client_encode_params = intrinsics.federated_broadcast(encode_params)
  client_before_sum_params = intrinsics.federated_broadcast(before_sum_params)

  encoded = intrinsics.federated_map(
      encode_fn, [value, client_encode_params, client_before_sum_params])
  aggregated = intrinsics.federated_aggregate(encoded, zero_fn(),
                                              accumulate_fn, merge_fn,
                                              report_fn)
  decoded = intrinsics.federated_map(
      decode_after_sum_fn, [aggregated.values, after_sum_params])
  new_state = intrinsics.federated_map(
      update_state_fn, [state, aggregated.state_update_tensors])

  # No measurements are reported.
  return measured_process.MeasuredProcessOutput(
      state=new_state,
      result=decoded,
      measurements=intrinsics.federated_value((), placements.SERVER))
def next_fn_impl(state, value, weight=None):
  """Zeroes client values over the current norm bound, then aggregates.

  Args:
    state: Three-element state unpacked as (zeroing-norm state, inner
      aggregation state, zeroed-count aggregation state).
    value: Client-placed values.
    weight: Optional client weights, forwarded to `inner_agg_next` when not
      None.

  Returns:
    A `MeasuredProcessOutput` with the following parts:
      - state: the zipped new state.
      - result: the inner aggregate.
      - measurements: the zipped `agg_process`, `zeroing_norm` and
        `zeroed_count` entries.
  """
  zeroing_state, inner_state, count_state = state

  norm_bound = self._zeroing_norm_process.report(zeroing_state)
  zeroed, value_norm, was_zeroed = intrinsics.federated_map(
      zero_fn, (value, intrinsics.federated_broadcast(norm_bound)))

  updated_zeroing_state = self._zeroing_norm_process.next(
      zeroing_state, value_norm)

  extra_args = () if weight is None else (weight,)
  inner_output = inner_agg_next(inner_state, zeroed, *extra_args)

  count_output = zeroed_count_agg_next(count_state, was_zeroed)

  next_state = collections.OrderedDict(
      zeroing_norm=updated_zeroing_state,
      inner_agg=inner_output.state,
      zeroed_count_agg=count_output.state)
  metrics = collections.OrderedDict(
      agg_process=inner_output.measurements,
      zeroing_norm=norm_bound,
      zeroed_count=count_output.result)
  return measured_process.MeasuredProcessOutput(
      state=intrinsics.federated_zip(next_state),
      result=inner_output.result,
      measurements=intrinsics.federated_zip(metrics))
def test_clip_type_properties_weighted(self, value_type, weight_type):
  """Checks types of the weighted process built by `_concat_mean()`.

  NOTE(review): the name says "clip", but the factory under test is
  `_concat_mean()`. Consider renaming the test.

  Args:
    value_type: Client value type spec, converted with `to_type`.
    weight_type: Client weight type spec, converted with `to_type`.
  """
  concat_factory = _concat_mean()
  value_type = computation_types.to_type(value_type)
  weight_type = computation_types.to_type(weight_type)
  process = concat_factory.create(value_type, weight_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # State and measurements both come from the inner mean factory.
  state_type = computation_types.at_server(
      collections.OrderedDict(value_sum_process=(), weight_sum_process=()))
  type_test_utils.assert_types_equivalent(
      process.initialize.type_signature,
      computation_types.FunctionType(parameter=None, result=state_type))

  metrics_type = computation_types.at_server(
      collections.OrderedDict(mean_value=(), mean_weight=()))
  next_param = collections.OrderedDict(
      state=state_type,
      value=computation_types.at_clients(value_type),
      weight=computation_types.at_clients(weight_type))
  next_result = measured_process.MeasuredProcessOutput(
      state=state_type,
      result=computation_types.at_server(value_type),
      measurements=metrics_type)
  type_test_utils.assert_types_equivalent(
      process.next.type_signature,
      computation_types.FunctionType(parameter=next_param, result=next_result))
def test_type_properties(self, modulus, value_type, symmetric_range):
  """Checks process types and that only secure aggregation is used.

  Args:
    modulus: Passed to `SecureModularSumFactory`.
    value_type: Client value type spec, converted with `to_type`.
    symmetric_range: Passed to `SecureModularSumFactory`.
  """
  factory_ = secure.SecureModularSumFactory(
      modulus=modulus, symmetric_range=symmetric_range)
  self.assertIsInstance(factory_, factory.UnweightedAggregationFactory)
  value_type = computation_types.to_type(value_type)
  process = factory_.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # State and measurements are both empty server-placed structures.
  expected_state_type = computation_types.at_server(
      computation_types.to_type(()))
  expected_measurements_type = expected_state_type

  expected_initialize_type = computation_types.FunctionType(
      parameter=None, result=expected_state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(
          expected_initialize_type))

  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=expected_state_type,
          value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=expected_state_type,
          result=computation_types.at_server(value_type),
          measurements=expected_measurements_type))
  self.assertTrue(
      process.next.type_signature.is_equivalent_to(expected_next_type))

  try:
    static_assert.assert_not_contains_unsecure_aggregation(process.next)
  except Exception as e:  # pylint: disable=broad-except
    # Catch `Exception` instead of using a bare `except:`, which would also
    # swallow KeyboardInterrupt/SystemExit. The original error is included in
    # the message so the failure can be diagnosed.
    self.fail('Factory returned an AggregationProcess containing '
              'non-secure aggregation: {}'.format(e))
def next_comp(state, value, weight):
  """Advances state with `_add_one` and returns the weighted mean.

  Args:
    state: Server-placed state, mapped through `_add_one`.
    value: Client-placed values.
    weight: Client-placed weights for `federated_mean`.

  Returns:
    A `MeasuredProcessOutput` whose measurements hold `num_clients`. This is
    the federated sum of a constant 1 placed at every client.
  """
  next_state = intrinsics.federated_map(_add_one, state)
  weighted_mean = intrinsics.federated_mean(value, weight)
  client_count = intrinsics.federated_sum(
      intrinsics.federated_value(1, placements.CLIENTS))
  return measured_process.MeasuredProcessOutput(
      state=next_state,
      result=weighted_mean,
      measurements=intrinsics.federated_zip(
          collections.OrderedDict(num_clients=client_count)))
def test_clip(self, clip_mechanism):
  """Checks types of `HistogramClippingSumFactory` on an int32 vector of size 2.

  Args:
    clip_mechanism: First positional argument of
      `HistogramClippingSumFactory`.
  """
  clip_factory = clipping_factory.HistogramClippingSumFactory(
      clip_mechanism, 1)
  self.assertIsInstance(clip_factory, factory.UnweightedAggregationFactory)

  value_type = computation_types.to_type((tf.int32, (2,)))
  process = clip_factory.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  clients_value = computation_types.FederatedType(value_type,
                                                  placements.CLIENTS)
  server_value = computation_types.FederatedType(value_type, placements.SERVER)
  # Neither state nor measurements carry any data.
  state_type = computation_types.at_server(())
  metrics_type = computation_types.at_server(())

  init_type = computation_types.FunctionType(parameter=None, result=state_type)
  self.assertTrue(process.initialize.type_signature.is_equivalent_to(init_type))

  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(state=state_type, value=clients_value),
      result=measured_process.MeasuredProcessOutput(
          state_type, server_value, metrics_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))