def test_execute_with_preprocess_function(self):
        """Running a composed (preprocess + FedAvg) process still trains.

        Also checks that later rounds improve the loss by less than earlier
        rounds, as the assertions below require.
        """
        raw_dataset = tf.data.Dataset.range(1)

        @tff.tf_computation(tff.SequenceType(raw_dataset.element_spec))
        def preprocess_dataset(ds):
            # Every element is replaced by the same constant all-ones example.
            def make_example(unused_element):
                del unused_element  # Input value is ignored.
                return _Batch(x=np.ones([784], dtype=np.float32),
                              y=np.ones([1], dtype=np.int64))

            return ds.map(make_example).batch(1)

        process = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            TAU,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)
        process = tff.simulation.compose_dataset_computation_with_iterative_process(
            preprocess_dataset, process)

        _, metrics, _ = self._run_rounds(process, [raw_dataset], 6)
        self.assertLess(metrics[-1]['loss'], metrics[0]['loss'])
        # Later rounds are expected to improve the loss by less than
        # earlier rounds.
        early_improvement = metrics[0]['loss'] - metrics[2]['loss']
        late_improvement = metrics[3]['loss'] - metrics[5]['loss']
        self.assertLess(late_improvement, early_improvement)
# --- Example 2 ---
# (Snippet separator left over from the code-example scrape; original
# marker text was "Beispiel #2" with score 0. Kept as a comment so the
# file remains valid Python.)
    def iterative_process_builder(
        model_fn: Callable[[], tff.learning.Model]
    ) -> tff.templates.IterativeProcess:
        """Creates an iterative process using a given TFF `model_fn`.

        Args:
          model_fn: A no-arg function returning a `tff.learning.Model`.

        Returns:
          A `tff.templates.IterativeProcess`.
        """
        if FLAGS.task in ('shakespeare', 'stackoverflow_nwp'):
            # Language-modeling tasks weight each client by its token count.
            def client_weight_fn(local_outputs):
                return tf.cast(tf.squeeze(local_outputs['num_tokens']),
                               tf.float32)
        else:
            # Fall back to the builder's default weighting.
            client_weight_fn = None

        return fed_avg_schedule.build_fed_avg_process(
            model_fn=model_fn,
            tau=FLAGS.tau,
            client_optimizer_fn=client_optimizer_fn,
            client_lr=client_lr_schedule,
            server_optimizer_fn=server_optimizer_fn,
            server_lr=server_lr_schedule,
            client_weight_fn=client_weight_fn)
    def test_fed_avg_without_schedule_decreases_loss(self):
        """Plain FedAvg (no LR schedule) should reduce the training loss."""
        federated_data = [[_batch_fn()]]

        process = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            TAU,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        _, metrics, _ = self._run_rounds(process, federated_data, 5)
        # Loss after the final round must be below the first round's loss.
        self.assertLess(metrics[-1]['loss'], metrics[0]['loss'])
    def test_server_update_with_nan_data_is_noop(self):
        """A round over NaN-containing data must leave the model unchanged."""
        nan_data = [[_batch_fn(has_nan=True)]]

        process = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            TAU,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        final_state, _, start_state = self._run_rounds(process, nan_data, 1)
        # Both trainable and non-trainable weights are compared against the
        # pre-round state.
        self.assertAllClose(final_state.model.trainable,
                            start_state.model.trainable, 1e-8)
        self.assertAllClose(final_state.model.non_trainable,
                            start_state.model.non_trainable, 1e-8)
    def test_fed_avg_with_custom_client_weight_fn(self):
        """FedAvg with a user-supplied client weight function still trains."""
        federated_data = [[_batch_fn()]]

        def inverse_loss_weight(local_outputs):
            # Weight each client inversely to (1 + its final reported loss).
            return 1.0 / (1.0 + local_outputs['loss'][-1])

        process = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            TAU,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            client_weight_fn=inverse_loss_weight)

        _, metrics, _ = self._run_rounds(process, federated_data, 5)
        self.assertLess(metrics[-1]['loss'], metrics[0]['loss'])
    def test_build_with_preprocess_function(self):
        """Composing a preprocessing computation updates `next`'s type signature.

        Builds a FedAvg process, composes it with a dataset-preprocessing
        `tff.tf_computation`, and asserts that the composed `next` accepts the
        *raw* client dataset type (a sequence of int64 scalars) rather than the
        preprocessed `_Batch` type.
        """
        test_dataset = tf.data.Dataset.range(5)
        # Expected parameter type for `next`: raw (unpreprocessed) client
        # datasets, placed at CLIENTS.
        client_datasets_type = tff.type_at_clients(
            tff.SequenceType(test_dataset.element_spec))

        @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
        def preprocess_dataset(ds):
            # Maps each scalar element x to a (features, label) `_Batch`.
            def to_batch(x):
                return _Batch(
                    tf.fill(dims=(784, ), value=float(x) * 2.0),
                    tf.expand_dims(tf.cast(x + 1, dtype=tf.int64), axis=0))

            return ds.map(to_batch).batch(2)

        iterproc = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            TAU,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        iterproc = tff.simulation.compose_dataset_computation_with_iterative_process(
            preprocess_dataset, iterproc)

        # Build a throwaway model in a fresh graph purely to read off the
        # weight and metric type structure expected in the signature.
        with tf.Graph().as_default():
            test_model_for_types = _uncompiled_model_builder()

        # Expected server-state type: a `fed_avg_schedule.ServerState` of model
        # weights plus optimizer state, round number, and client drift,
        # placed at SERVER.
        server_state_type = tff.FederatedType(
            fed_avg_schedule.ServerState(model=tff.framework.type_from_tensors(
                tff.learning.ModelWeights(
                    test_model_for_types.trainable_variables,
                    test_model_for_types.non_trainable_variables)),
                                         optimizer_state=(tf.int64, ),
                                         round_num=tf.float32,
                                         client_drift=tf.float32), tff.SERVER)
        metrics_type = test_model_for_types.federated_output_computation.type_signature.result

        expected_parameter_type = collections.OrderedDict(
            server_state=server_state_type,
            federated_dataset=client_datasets_type,
        )
        expected_result_type = (server_state_type, metrics_type)

        expected_type = tff.FunctionType(parameter=expected_parameter_type,
                                         result=expected_result_type)
        # Structural equivalence, not equality: placements/orderings may be
        # represented differently but must describe the same type.
        self.assertTrue(
            iterproc.next.type_signature.is_equivalent_to(expected_type),
            msg='{s}\n!={t}'.format(s=iterproc.next.type_signature,
                                    t=expected_type))
    def test_fed_avg_with_client_and_server_schedules(self):
        """With decaying client and server LRs, later rounds improve less."""
        federated_data = [[_batch_fn()]]

        process = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            TAU,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            client_lr=lambda round_num: 0.1 / (round_num + 1)**2,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            server_lr=lambda round_num: 1.0 / (round_num + 1)**2)

        _, metrics, _ = self._run_rounds(process, federated_data, 6)
        self.assertLess(metrics[-1]['loss'], metrics[0]['loss'])
        # The loss improvement in the second half of training should be
        # smaller than in the first half.
        early_improvement = metrics[0]['loss'] - metrics[2]['loss']
        late_improvement = metrics[3]['loss'] - metrics[5]['loss']
        self.assertLess(late_improvement, early_improvement)
    def test_fed_avg_with_server_schedule(self):
        """A step-function server LR: loss moves while LR is 1, then holds."""
        federated_data = [[_batch_fn()]]

        @tf.function
        def step_lr(round_num):
            # 1.0 for rounds 0 and 1, then 0.0 afterwards.
            return 1.0 if round_num < 1.5 else 0.0

        process = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            TAU,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD,
            server_lr=step_lr)

        _, metrics, _ = self._run_rounds(process, federated_data, 4)
        # While the server LR is nonzero the loss decreases...
        self.assertLess(metrics[1]['loss'], metrics[0]['loss'])
        # ...and after the LR drops to zero consecutive rounds report
        # (nearly) identical loss.
        self.assertNear(metrics[2]['loss'],
                        metrics[3]['loss'],
                        err=1e-4)
    def test_get_model_weights(self):
        """`get_model_weights` exposes the server state's model weights."""
        federated_data = [[_batch_fn()]]

        process = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            TAU,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)
        state = process.initialize()

        def check_weights(current_state):
            # The accessor must return a `tff.learning.ModelWeights` whose
            # trainable weights match the server state's.
            weights = process.get_model_weights(current_state)
            self.assertIsInstance(weights, tff.learning.ModelWeights)
            self.assertAllClose(current_state.model.trainable,
                                weights.trainable)

        # Holds for the initial state and after each training round.
        check_weights(state)
        for _ in range(3):
            state, _ = process.next(state, federated_data)
            check_weights(state)