Пример #1
0
    def test_zero_type_properties_with_zeroed_count_agg_factory(
            self, value_type):
        """Checks process type signatures when the zeroed count uses a custom sum.

        The zeroed-count aggregation is delegated to `SumPlusOneFactory`; its
        state is expected to surface as `zeroed_count_agg` (int32) in the
        server state, next to empty states for the constant zeroing norm and
        the plain inner sum.
        """
        agg_factory = robust.zeroing_factory(
            zeroing_norm=1.0,
            inner_agg_factory=sum_factory.SumFactory(),
            norm_order=2.0,
            zeroed_count_sum_factory=aggregator_test_utils.SumPlusOneFactory())
        value_type = computation_types.to_type(value_type)
        process = agg_factory.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        state_type = computation_types.at_server(
            collections.OrderedDict(
                zeroing_norm=(), inner_agg=(), zeroed_count_agg=tf.int32))
        init_type = computation_types.FunctionType(
            parameter=None, result=state_type)
        init_signature = process.initialize.type_signature
        self.assertTrue(init_signature.is_equivalent_to(init_type))

        # `zeroing` holds the inner aggregator's measurements (empty for a sum).
        measurements_type = computation_types.at_server(
            collections.OrderedDict(
                zeroing=(),
                zeroing_norm=robust.NORM_TF_TYPE,
                zeroed_count=robust.COUNT_TF_TYPE))
        next_parameter = collections.OrderedDict(
            state=state_type, value=computation_types.at_clients(value_type))
        next_result = measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(value_type),
            measurements=measurements_type)
        next_type = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        next_signature = process.next.type_signature
        self.assertTrue(next_signature.is_equivalent_to(next_type))
Пример #2
0
    def test_custom_model_zeroing_clipping_aggregator_factory(self):
        """Trains two rounds through a zeroing aggregator that never zeroes.

        The zeroing norm is infinite, so no client update exceeds it and the
        inner mean sees every update.
        """
        client_data = create_emnist_client_data()
        train_data = [client_data(), client_data()]

        def loss_fn():
            return tf.keras.losses.SparseCategoricalCrossentropy()

        def metrics_fn():
            return [
                counters.NumExamplesCounter(),
                counters.NumBatchesCounter(),
                tf.keras.metrics.SparseCategoricalAccuracy(),
            ]

        # An infinite zeroing norm means no client value is ever zeroed.
        aggregation_factory = robust.zeroing_factory(
            zeroing_norm=float('inf'), inner_agg_factory=mean.MeanFactory())

        # A learning rate of 0 disables reconstruction, so the post-reconstruction
        # loss matches the exact expectation in round 0 and decreases by the next
        # round.
        trainer = training_process.build_training_process(
            MnistModel,
            loss_fn=loss_fn,
            metrics_fn=metrics_fn,
            server_optimizer_fn=_get_keras_optimizer_fn(0.01),
            client_optimizer_fn=_get_keras_optimizer_fn(0.001),
            reconstruction_optimizer_fn=_get_keras_optimizer_fn(0.0),
            aggregation_factory=aggregation_factory,
            dataset_split_fn=reconstruction_utils.simple_dataset_split_fn)

        state = trainer.initialize()
        rounds = []
        for _ in range(2):
            state, output = trainer.next(state, train_data)
            rounds.append((state, output['train']))
        (first_state, first_train), (second_state, second_train) = rounds

        # All weights and biases start at 0, so the initial logits are all 0 and
        # the softmax is uniform over 10 classes: the negative log likelihood is
        # -ln(1/10) = ln(10). That only holds in expectation, hence the tolerance.
        self.assertAllClose(first_train['loss'], tf.math.log(10.0), rtol=1e-4)
        self.assertLess(second_train['loss'], first_train['loss'])
        self.assertNotAllClose(first_state.model.trainable,
                               second_state.model.trainable)

        # 6 reconstruction examples and 6 training examples are expected; only
        # the training examples are counted in the metrics.
        for train_metrics in (first_train, second_train):
            self.assertEqual(train_metrics['num_examples'], 6.0)

        # 4 reconstruction batches and 4 training batches are expected; only the
        # training batches are counted in the metrics.
        for train_metrics in (first_train, second_train):
            self.assertEqual(train_metrics['num_batches'], 4.0)
def _default_zeroing(
        inner_factory: factory.AggregationFactory,
        secure_estimation: bool = False) -> factory.AggregationFactory:
    """The default adaptive zeroing wrapper.

    Args:
        inner_factory: Factory that aggregates the values after zeroing.
        secure_estimation: Forwarded to the zeroing-norm estimation process.
            When True, the count of zeroed values is also summed with a
            `secure.SecureSumFactory` whose per-contribution bounds are
            [0, 1].

    Returns:
        The factory built by `robust.zeroing_factory` around `inner_factory`.
    """
    # Adapts very quickly to a value somewhat higher than the highest values
    # seen so far.
    norm_estimator = quantile_estimation.PrivateQuantileEstimationProcess.no_noise(
        initial_estimate=10.0,
        target_quantile=0.98,
        learning_rate=math.log(10.0),
        multiplier=2.0,
        increment=1.0,
        secure_estimation=secure_estimation)
    if not secure_estimation:
        return robust.zeroing_factory(norm_estimator, inner_factory)

    zeroed_count_factory = secure.SecureSumFactory(
        upper_bound_threshold=1, lower_bound_threshold=0)
    return robust.zeroing_factory(
        norm_estimator,
        inner_factory,
        zeroed_count_sum_factory=zeroed_count_factory)
Пример #4
0
def _default_zeroing(
        inner_factory: factory.AggregationFactory
) -> factory.AggregationFactory:
    """The default adaptive zeroing wrapper.

    Args:
        inner_factory: Factory that aggregates the values after zeroing.

    Returns:
        The factory built by `robust.zeroing_factory`, whose zeroing norm is
        produced by a noiseless quantile estimation process.
    """
    # Adapts very quickly to a value somewhat higher than the highest values
    # seen so far.
    norm_estimator = quantile_estimation.PrivateQuantileEstimationProcess.no_noise(
        initial_estimate=10.0,
        target_quantile=0.98,
        learning_rate=math.log(10.0),
        multiplier=2.0,
        increment=1.0)
    return robust.zeroing_factory(
        zeroing_norm=norm_estimator, inner_agg_factory=inner_factory)
Пример #5
0
    def test_increasing_zero_clip_sum(self):
        """Zeroing and clipping with norms that grow by non-integer steps.

        The zeroing norm grows by 0.75 each round and the clipping norm by
        0.25; both processes start from the state returned by `_test_init_fn`.
        """

        @computations.federated_computation(_float_at_server,
                                            _float_at_clients)
        def grow_zeroing_norm(state, value):
            del value
            return intrinsics.federated_map(
                computations.tf_computation(lambda x: x + 0.75, tf.float32),
                state)

        @computations.federated_computation(_float_at_server,
                                            _float_at_clients)
        def grow_clipping_norm(state, value):
            del value
            return intrinsics.federated_map(
                computations.tf_computation(lambda x: x + 0.25, tf.float32),
                state)

        zeroing_norm_process = estimation_process.EstimationProcess(
            _test_init_fn, grow_zeroing_norm, _test_report_fn)
        clipping_norm_process = estimation_process.EstimationProcess(
            _test_init_fn, grow_clipping_norm, _test_report_fn)

        agg_factory = robust.zeroing_factory(
            zeroing_norm_process, _clipped_sum(clipping_norm_process))
        process = agg_factory.create(computation_types.to_type(tf.float32))

        # One row per round:
        # (zeroing_norm, clipping_norm, zeroed_count, clipped_count, result).
        expected_rounds = (
            (1.0, 1.0, 2, 0, 1.0),
            (1.75, 1.25, 2, 0, 1.0),
            (2.5, 1.5, 1, 1, 2.5),
            (3.25, 1.75, 0, 2, 4.5),
            (4.0, 2.0, 0, 1, 5.0),
        )

        state = process.initialize()
        client_data = [1.0, 2.0, 3.0]
        for (zeroing_norm, clipping_norm, zeroed_count, clipped_count,
             result) in expected_rounds:
            output = process.next(state, client_data)
            measurements = output.measurements
            self.assertAllClose(zeroing_norm, measurements['zeroing_norm'])
            self.assertAllClose(clipping_norm,
                                measurements['zeroing']['clipping_norm'])
            self.assertEqual(zeroed_count, measurements['zeroed_count'])
            self.assertEqual(clipped_count,
                             measurements['zeroing']['clipped_count'])
            self.assertAllClose(result, output.result)
            state = output.state
Пример #6
0
def _zeroed_sum(clip=2.0, norm_order=2.0):
    """Builds a zeroing factory that wraps a plain `SumFactory`.

    Args:
        clip: The zeroing norm (despite the name, values are zeroed, not
            clipped, when their norm exceeds it).
        norm_order: Order of the norm compared against `clip`.
    """
    inner = sum_factory.SumFactory()
    return robust.zeroing_factory(
        zeroing_norm=clip, inner_agg_factory=inner, norm_order=norm_order)
Пример #7
0
def _zeroed_mean(clip=2.0, norm_order=2.0):
    """Builds a zeroing factory that wraps a `MeanFactory`.

    Args:
        clip: The zeroing norm (despite the name, values are zeroed, not
            clipped, when their norm exceeds it).
        norm_order: Order of the norm compared against `clip`.
    """
    inner = mean.MeanFactory()
    return robust.zeroing_factory(
        zeroing_norm=clip, inner_agg_factory=inner, norm_order=norm_order)