コード例 #1
0
    def test_clip_type_properties_simple(self, value_type):
        """Checks the type signatures of the clipped-sum aggregation process.

        Args:
          value_type: A type spec for the client values, converted with
            `tff.to_type`. Presumably supplied by a parameterization
            decorator outside this view.
        """
        factory = _clipped_sum()
        value_type = tff.to_type(value_type)
        process = factory.create(value_type)

        self.assertIsInstance(process, tff.templates.AggregationProcess)

        server_state_type = tff.type_at_server(
            ())  # Inner SumFactory has no state

        expected_initialize_type = tff.FunctionType(parameter=None,
                                                    result=server_state_type)
        actual_initialize_type = process.initialize.type_signature
        # Print both signatures on failure (as the other type-signature tests
        # do) so a mismatch is diagnosable from the test log alone.
        self.assertTrue(
            actual_initialize_type.is_equivalent_to(expected_initialize_type),
            msg='{s}\n!={t}'.format(s=actual_initialize_type,
                                    t=expected_initialize_type))

        # Expected measurements: a server-placed struct with an `agg_process`
        # entry that is itself an empty struct.
        expected_measurements_type = tff.type_at_server(
            collections.OrderedDict(agg_process=()))

        expected_next_type = tff.FunctionType(
            parameter=collections.OrderedDict(
                state=server_state_type,
                value=tff.type_at_clients(value_type)),
            result=tff.templates.MeasuredProcessOutput(
                state=server_state_type,
                result=tff.type_at_server(value_type),
                measurements=expected_measurements_type))
        actual_next_type = process.next.type_signature
        self.assertTrue(
            actual_next_type.is_equivalent_to(expected_next_type),
            msg='{s}\n!={t}'.format(s=actual_next_type, t=expected_next_type))
コード例 #2
0
    def test_iterative_process_type_signature(self):
        """Verifies the `initialize` and `next` signatures of adaptive FedAvg.

        Client and server learning-rate callbacks are built with identical
        settings, so a single callback type describes both state slots.
        """

        def make_lr_callback():
            # Identical configuration for the client and server callbacks.
            return callbacks.create_reduce_lr_on_plateau(
                learning_rate=0.1,
                min_delta=0.5,
                window_size=2,
                decay_factor=1.0,
                cooldown=0)

        client_lr_callback = make_lr_callback()
        server_lr_callback = make_lr_callback()

        iterative_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        callback_type = tff.framework.type_from_tensors(client_lr_callback)
        weights_type = tff.learning.ModelWeights(
            trainable=(tff.TensorType(tf.float32, [1, 1]),
                       tff.TensorType(tf.float32, [1])),
            non_trainable=())
        state_type = tff.FederatedType(
            adaptive_fed_avg.ServerState(
                model=weights_type,
                optimizer_state=[tf.int64],
                client_lr_callback=callback_type,
                server_lr_callback=callback_type),
            tff.SERVER)

        self.assertEqual(
            iterative_process.initialize.type_signature,
            tff.FunctionType(parameter=None, result=state_type))

        batch_spec = collections.OrderedDict([
            ('x', tff.TensorType(tf.float32, [None, 1])),
            ('y', tff.TensorType(tf.float32, [None, 1])),
        ])
        dataset_type = tff.FederatedType(
            tff.SequenceType(batch_spec), tff.CLIENTS)

        metrics_type = tff.FederatedType(
            collections.OrderedDict([('loss', tff.TensorType(tf.float32))]),
            tff.SERVER)
        # The same metrics structure is reported before and during training.
        output_type = collections.OrderedDict([
            ('before_training', metrics_type),
            ('during_training', metrics_type),
        ])

        expected_type = tff.FunctionType(
            parameter=collections.OrderedDict([
                ('server_state', state_type),
                ('federated_dataset', dataset_type),
            ]),
            result=(state_type, output_type))

        next_type = iterative_process.next.type_signature
        self.assertEqual(next_type,
                         expected_type,
                         msg='{s}\n!={t}'.format(s=next_type, t=expected_type))
コード例 #3
0
    def test_iterative_process_type_signature(self):
        """Checks the signatures of the process built by `from_flags`.

        The learning-rate callback part of the expected server state comes
        from a callback built with the same FLAGS values. Presumably this
        mirrors what the builder constructs internally.
        """
        iterative_process = decay_iterative_process_builder.from_flags(
            input_spec=get_input_spec(),
            model_builder=model_builder,
            loss_builder=loss_builder,
            metrics_builder=metrics_builder)

        dummy_lr_callback = callbacks.create_reduce_lr_on_plateau(
            learning_rate=FLAGS.client_learning_rate,
            decay_factor=FLAGS.client_decay_factor,
            min_delta=FLAGS.min_delta,
            min_lr=FLAGS.min_lr,
            window_size=FLAGS.window_size,
            patience=FLAGS.patience)
        lr_callback_type = tff.framework.type_from_tensors(dummy_lr_callback)

        server_state_type = tff.FederatedType(
            adaptive_fed_avg.ServerState(model=tff.learning.ModelWeights(
                trainable=(tff.TensorType(tf.float32, [1, 1]),
                           tff.TensorType(tf.float32, [1])),
                non_trainable=()),
                                         optimizer_state=[tf.int64],
                                         client_lr_callback=lr_callback_type,
                                         server_lr_callback=lr_callback_type),
            tff.SERVER)

        self.assertEqual(
            iterative_process.initialize.type_signature,
            tff.FunctionType(parameter=None, result=server_state_type))

        dataset_type = tff.FederatedType(
            tff.SequenceType(
                collections.OrderedDict(
                    x=tff.TensorType(tf.float32, [None, 1]),
                    y=tff.TensorType(tf.float32, [None, 1]))), tff.CLIENTS)

        metrics_type = tff.FederatedType(
            collections.OrderedDict(
                mean_squared_error=tff.TensorType(tf.float32),
                loss=tff.TensorType(tf.float32)), tff.SERVER)
        output_type = collections.OrderedDict(before_training=metrics_type,
                                              during_training=metrics_type)

        expected_result_type = (server_state_type, output_type)
        expected_type = tff.FunctionType(parameter=(server_state_type,
                                                    dataset_type),
                                         result=expected_result_type)

        actual_type = iterative_process.next.type_signature
        # Report both signatures on failure, matching the other
        # type-signature tests. A bare assertTrue only says "False is not true".
        self.assertTrue(actual_type.is_equivalent_to(expected_type),
                        msg='{s}\n!={t}'.format(s=actual_type,
                                                t=expected_type))
コード例 #4
0
    def test_eval_fn_has_correct_type_signature(self):
        """Checks the type signature of the centralized evaluation computation.

        The check is that the expected type is assignable to the actual
        signature, not that the two are strictly equivalent.
        """

        def metrics_builder():
            return [tf.keras.metrics.MeanSquaredError()]

        eval_fn = evaluation.build_centralized_evaluation(
            tff_model_fn, metrics_builder)
        actual_type = eval_fn.type_signature

        model_type = tff.FederatedType(
            tff.learning.ModelWeights(
                trainable=(
                    tff.TensorType(tf.float32, [1, 1]),
                    tff.TensorType(tf.float32, [1]),
                ),
                non_trainable=(),
            ), tff.SERVER)

        # The centralized dataset is server-placed, unlike the client datasets
        # used by the training processes.
        dataset_type = tff.FederatedType(
            tff.SequenceType(
                collections.OrderedDict(
                    x=tff.TensorType(tf.float32, [None, 1]),
                    y=tff.TensorType(tf.float32, [None, 1]))), tff.SERVER)

        metrics_type = tff.FederatedType(
            collections.OrderedDict(
                mean_squared_error=tff.TensorType(tf.float32),
                num_examples=tff.TensorType(tf.float32)), tff.SERVER)

        expected_type = tff.FunctionType(parameter=collections.OrderedDict(
            model_weights=model_type, centralized_dataset=dataset_type),
                                         result=metrics_type)

        # Same check as `check_assignable_from`, but expressed as an assertion.
        # A mismatch is then reported as a test failure with both types shown,
        # rather than as an uncaught-exception error.
        self.assertTrue(
            actual_type.is_assignable_from(expected_type),
            msg='{t}\nis not assignable to\n{s}'.format(s=actual_type,
                                                       t=expected_type))
コード例 #5
0
  def test_federated_mean_masked(self):
    """Checks the signature and zero-masking of the federated masked mean."""
    struct_type = tff.StructType([
        ('a', tff.TensorType(tf.float32, shape=[3])),
        ('b', tff.TensorType(tf.float32, shape=[2, 1])),
    ])
    scalar_weight_type = tff.TensorType(tf.float32)

    mean_fn = fed_pa_schedule.build_federated_mean_masked(
        struct_type, scalar_weight_type)

    # Signature: client-placed (value, weight) -> server-placed value.
    expected_signature = tff.FunctionType(
        parameter=collections.OrderedDict([
            ('value', tff.FederatedType(struct_type, tff.CLIENTS)),
            ('weight', tff.FederatedType(scalar_weight_type, tff.CLIENTS)),
        ]),
        result=tff.FederatedType(struct_type, tff.SERVER))
    actual_signature = mean_fn.type_signature
    self.assertTrue(
        actual_signature.is_equivalent_to(expected_signature),
        msg='{s}\n!={t}'.format(s=actual_signature, t=expected_signature))

    # The expected values below correspond to a weighted mean in which zero
    # entries are excluded, coordinate by coordinate.
    client_values = [
        collections.OrderedDict(
            a=tf.constant([0.0, 1.0, 1.0]), b=tf.constant([[1.0], [2.0]])),
        collections.OrderedDict(
            a=tf.constant([1.0, 3.0, 0.0]), b=tf.constant([[3.0], [0.0]])),
    ]
    client_weights = [tf.constant(1.0), tf.constant(3.0)]
    actual_mean = mean_fn(client_values, client_weights)
    self.assertAllClose(
        actual_mean,
        collections.OrderedDict(
            a=tf.constant([1.0, 2.5, 1.0]), b=tf.constant([[2.5], [2.0]])))
コード例 #6
0
 def test_types(self):
     """Checks the signatures of the trainer's per-batch computations.

     Both computations take a (model, batch) struct. Training returns a
     model; loss evaluation returns a float32 scalar. Signatures are
     compared through their string forms.
     """
     trainer = create_trainer(batch_size=100, step_size=0.01)
     model_type = trainer.create_initial_model.type_signature.result
     first_batch = next(trainer.generate_random_batches(1))
     batch_type = tff.experimental.jax_computation(
         lambda: first_batch).type_signature.result

     cases = (
         (trainer.train_on_one_batch, model_type),
         (trainer.compute_loss_on_one_batch, np.float32),
     )
     for computation, result_type in cases:
         param_type = collections.OrderedDict([('model', model_type),
                                               ('batch', batch_type)])
         self.assertEqual(
             str(computation.type_signature),
             str(tff.FunctionType(param_type, result_type)))
コード例 #7
0
  def test_build_with_preprocess_function(self):
    """Checks the FedPA `next` signature after composing dataset preprocessing.

    The composed process should accept the raw (unpreprocessed) client
    datasets while keeping the original server state and metrics types.
    """
    raw_dataset = tf.data.Dataset.range(5)
    client_datasets_type = tff.FederatedType(
        tff.SequenceType(raw_dataset.element_spec), tff.CLIENTS)

    @tff.tf_computation(tff.SequenceType(raw_dataset.element_spec))
    def preprocess_dataset(ds):

      def to_batch(x):
        features = tf.fill(dims=(784,), value=float(x) * 2.0)
        label = tf.expand_dims(tf.cast(x + 1, dtype=tf.int64), axis=0)
        return _Batch(features, label)

      return ds.map(to_batch).batch(2)

    base_process = _build_simple_fed_pa_process(
        _uncompiled_model_builder,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD)
    composed_process = (
        tff.simulation.compose_dataset_computation_with_iterative_process(
            preprocess_dataset, base_process))

    # The model is built in a throwaway graph only to read its variable types.
    with tf.Graph().as_default():
      reference_model = _uncompiled_model_builder()

    model_weights_type = tff.framework.type_from_tensors(
        tff.learning.ModelWeights(reference_model.trainable_variables,
                                  reference_model.non_trainable_variables))
    server_state_type = tff.FederatedType(
        fed_pa_schedule.ServerState(
            model=model_weights_type,
            optimizer_state=(tf.int64,),
            round_num=tf.float32),
        tff.SERVER)
    metrics_type = tff.FederatedType(
        tff.StructType([
            ('loss', tf.float32),
            ('model_delta_zeros_percent', tf.float32),
            ('model_delta_correction_l2_norm', tf.float32),
        ]),
        tff.SERVER)

    expected_type = tff.FunctionType(
        parameter=collections.OrderedDict([
            ('server_state', server_state_type),
            ('federated_dataset', client_datasets_type),
        ]),
        result=(server_state_type, metrics_type))
    actual_type = composed_process.next.type_signature
    self.assertTrue(
        actual_type.is_equivalent_to(expected_type),
        msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
コード例 #8
0
    def test_build_with_preprocess_function(self):
        """Checks the FedAvg `next` signature after composing preprocessing.

        Expected metrics are taken from the reference model's
        `federated_output_computation` result type.
        """
        element_spec = tf.data.Dataset.range(5).element_spec
        client_datasets_type = tff.type_at_clients(
            tff.SequenceType(element_spec))

        @tff.tf_computation(tff.SequenceType(element_spec))
        def preprocess_dataset(ds):
            def create_batch(x):
                as_float = tf.cast(x, dtype=tf.float32)
                return collections.OrderedDict(x=[as_float], y=[2.0])

            return ds.map(create_batch).batch(2)

        base_process = fed_avg_schedule.build_fed_avg_process(
            model_builder,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            client_lr=0.01,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        composed_process = (
            tff.simulation.compose_dataset_computation_with_iterative_process(
                preprocess_dataset, base_process))

        # The model is built in a separate graph only to read its types.
        with tf.Graph().as_default():
            reference_model = model_builder()

        model_weights_type = tff.framework.type_from_tensors(
            tff.learning.ModelWeights(
                reference_model.trainable_variables,
                reference_model.non_trainable_variables))
        server_state_type = tff.FederatedType(
            fed_avg_schedule.ServerState(
                model=model_weights_type,
                optimizer_state=(tf.int64, ),
                round_num=tf.float32),
            tff.SERVER)
        output_computation = reference_model.federated_output_computation
        metrics_type = output_computation.type_signature.result

        expected_type = tff.FunctionType(
            parameter=collections.OrderedDict([
                ('server_state', server_state_type),
                ('federated_dataset', client_datasets_type),
            ]),
            result=(server_state_type, metrics_type))
        actual_type = composed_process.next.type_signature
        self.assertTrue(
            actual_type.is_equivalent_to(expected_type),
            msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
コード例 #9
0
    def test_clip_by_global_norm(self):
        """Checks the fixed-clip-norm mean process end to end.

        Covers the `next` type signature, the equally weighted mean of the
        globally clipped deltas, and the reported clipping metrics.
        """
        clip_norm = 20.0
        small_delta = create_weights_delta()
        large_delta = create_weights_delta(constant=10)
        test_deltas = [small_delta, large_delta]

        update_type = tff.framework.type_from_tensors(small_delta)
        aggregate_fn = aggregate_fns.build_fixed_clip_norm_mean_process(
            clip_norm=clip_norm, model_update_type=update_type)

        empty_server_state = tff.FederatedType((), tff.SERVER)
        metrics_struct = aggregate_fns.NormClippedAggregationMetrics(
            max_global_norm=tf.float32, num_clipped=tf.int32)
        expected_next_type = tff.FunctionType(
            parameter=collections.OrderedDict([
                ('state', empty_server_state),
                ('deltas', tff.FederatedType(update_type, tff.CLIENTS)),
                ('weights', tff.FederatedType(tf.float32, tff.CLIENTS)),
            ]),
            result=collections.OrderedDict([
                ('state', empty_server_state),
                ('result', tff.FederatedType(update_type, tff.SERVER)),
                ('measurements',
                 tff.FederatedType(metrics_struct, tff.SERVER)),
            ]),
        )
        self.assertEqual(aggregate_fn.next.type_signature, expected_next_type)

        state = aggregate_fn.initialize()
        output = aggregate_fn.next(state, test_deltas, [1., 1.])

        def clip_delta(delta):
            # Clip all tensors of one delta jointly by their global norm.
            flat_clipped, _ = tf.clip_by_global_norm(tf.nest.flatten(delta),
                                                     clip_norm)
            return tf.nest.pack_sequence_as(delta, flat_clipped)

        clipped_deltas = [clip_delta(delta) for delta in test_deltas]
        # With equal weights, the expected result is the plain average.
        expected_mean = tf.nest.map_structure(lambda a, b: (a + b) / 2,
                                              *clipped_deltas)
        self.assertAllClose(expected_mean, output['result'])

        # Global l2 norms [17.74824, 53.99074].
        measurements = output['measurements']
        self.assertAlmostEqual(measurements.max_global_norm, 53.99074,
                               places=5)
        self.assertEqual(measurements.num_clipped, 1)
コード例 #10
0
  def test_build_with_preprocess_function(self):
    """Checks the wrapped process's `next` signature when preprocessing is set.

    The builder returns an adapter. The signature checked here is that of
    its private `_iterative_process`.
    """
    source_dataset = tf.data.Dataset.range(5)
    client_datasets_type = tff.FederatedType(
        tff.SequenceType(source_dataset.element_spec), tff.CLIENTS)

    @tff.tf_computation(tff.SequenceType(source_dataset.element_spec))
    def preprocess_dataset(ds):

      def to_batch(x):
        features = tf.fill(dims=(784,), value=float(x) * 2.0)
        label = tf.expand_dims(tf.cast(x + 1, dtype=tf.int64), axis=0)
        return _Batch(features, label)

      return ds.map(to_batch).batch(2)

    adapter = fed_avg_schedule.build_fed_avg_process(
        _uncompiled_model_builder,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        dataset_preprocess_comp=preprocess_dataset)

    with tf.Graph().as_default():
      reference_model = _uncompiled_model_builder()

    wrapped_process = adapter._iterative_process

    weights_type = tff.framework.type_from_tensors(
        fed_avg_schedule.ModelWeights(
            reference_model.trainable_variables,
            reference_model.non_trainable_variables))
    server_state_type = tff.FederatedType(
        fed_avg_schedule.ServerState(
            model=weights_type,
            optimizer_state=(tf.int64,),
            round_num=tf.float32),
        tff.SERVER)
    output_computation = reference_model.federated_output_computation
    metrics_type = output_computation.type_signature.result

    expected_type = tff.FunctionType(
        parameter=(server_state_type, client_datasets_type),
        result=(server_state_type, metrics_type))
    actual_type = wrapped_process.next.type_signature
    self.assertEqual(
        actual_type,
        expected_type,
        msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
コード例 #11
0
    def test_build_with_preprocess_funtion(self):
        """Checks adaptive FedAvg signatures when a preprocessing step is given.

        Note: the method name misspells "function". It is kept unchanged so
        the test's identifier stays stable.
        """
        source_dataset = tf.data.Dataset.range(5)
        client_datasets_type = tff.FederatedType(
            tff.SequenceType(source_dataset.element_spec), tff.CLIENTS)

        @tff.tf_computation(tff.SequenceType(source_dataset.element_spec))
        def preprocess_dataset(ds):
            def to_batch(x):
                features = [float(x) * 1.0]
                targets = [float(x) * 3.0 + 1.0]
                return collections.OrderedDict(x=features, y=targets)

            return ds.map(to_batch).repeat().batch(2).take(3)

        def make_lr_callback():
            # Client and server callbacks share this exact configuration.
            return callbacks.create_reduce_lr_on_plateau(
                learning_rate=0.1,
                min_delta=0.5,
                window_size=2,
                decay_factor=1.0,
                cooldown=0)

        client_lr_callback = make_lr_callback()
        server_lr_callback = make_lr_callback()

        iterative_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            dataset_preprocess_comp=preprocess_dataset)

        callback_type = tff.framework.type_from_tensors(client_lr_callback)
        weights_type = tff.learning.ModelWeights(
            trainable=(tff.TensorType(tf.float32, [1, 1]),
                       tff.TensorType(tf.float32, [1])),
            non_trainable=())
        state_type = tff.FederatedType(
            adaptive_fed_avg.ServerState(
                model=weights_type,
                optimizer_state=[tf.int64],
                client_lr_callback=callback_type,
                server_lr_callback=callback_type),
            tff.SERVER)

        self.assertEqual(
            iterative_process.initialize.type_signature,
            tff.FunctionType(parameter=None, result=state_type))

        metrics_type = tff.FederatedType(
            collections.OrderedDict([('loss', tff.TensorType(tf.float32))]),
            tff.SERVER)
        output_type = collections.OrderedDict([
            ('before_training', metrics_type),
            ('during_training', metrics_type),
        ])

        expected_type = tff.FunctionType(
            parameter=collections.OrderedDict([
                ('server_state', state_type),
                ('federated_dataset', client_datasets_type),
            ]),
            result=(state_type, output_type))

        next_type = iterative_process.next.type_signature
        self.assertEqual(next_type,
                         expected_type,
                         msg='{s}\n!={t}'.format(s=next_type, t=expected_type))