Пример #1
0
    def test_orchestration_type_signature(self):
        # Verifies the TFF type signatures of the `initialize` and `next`
        # computations of the process returned by
        # `optimizer_utils.build_model_delta_optimizer_process`.
        process = optimizer_utils.build_model_delta_optimizer_process(
            model_fn=model_examples.TrainableLinearRegression,
            model_to_client_delta_fn=DummyClientDeltaFn,
            server_optimizer_fn=lambda: gradient_descent.SGD(
                learning_rate=1.0))

        # Expected model weights: two ordered groups of tensors, passed
        # positionally to `ModelWeights` in this order.
        first_weight_group = collections.OrderedDict([
            ('a', tff.TensorType(tf.float32, [2, 1])),
            ('b', tf.float32),
        ])
        second_weight_group = collections.OrderedDict([('c', tf.float32)])
        model_weights_type = model_utils.ModelWeights(first_weight_group,
                                                      second_weight_group)

        # Only the model weights of the ServerState are spelled out. The
        # optimizer state comes from TensorFlow and its concrete value is of
        # no interest here, so the remaining fields use `test.AnyType()`.
        server_state = optimizer_utils.ServerState(
            model_weights_type, test.AnyType(), test.AnyType(), test.AnyType())
        server_state_type = tff.FederatedType(
            server_state, placement=tff.SERVER, all_equal=True)

        # Client data: a sequence of elements shaped like the model's
        # `input_spec`, placed at the clients and not required to be equal.
        element_spec = model_examples.TrainableLinearRegression().input_spec
        client_data_type = tff.FederatedType(
            tff.SequenceType(element_spec),
            placement=tff.CLIENTS,
            all_equal=False)

        # Model outputs reported at the server: two scalar tensors.
        model_outputs = collections.OrderedDict([
            ('loss', tff.TensorType(tf.float32, [])),
            ('num_examples', tff.TensorType(tf.int32, [])),
        ])
        model_outputs_type = tff.FederatedType(
            model_outputs, placement=tff.SERVER, all_equal=True)

        # `initialize` is expected to be a function of no arguments that
        # returns a ServerState.
        expected_initialize_type = tff.FunctionType(
            parameter=None, result=server_state_type)
        self.assertEqual(expected_initialize_type,
                         process.initialize.type_signature)

        # `next` is expected to be a function of (ServerState, Datasets) that
        # returns a (ServerState, model outputs) pair.
        expected_next_type = tff.FunctionType(
            parameter=[server_state_type, client_data_type],
            result=(server_state_type, model_outputs_type))
        self.assertEqual(expected_next_type, process.next.type_signature)
Пример #2
0
  def test_orchestration_typecheck(self):
    # Verifies the TFF type signatures of the `initialize` and `next`
    # computations of the process returned by
    # `federated_sgd.build_federated_sgd_process`.
    process = federated_sgd.build_federated_sgd_process(
        model_fn=model_examples.LinearRegression)

    # Expected model weights: two ordered groups of tensors, passed
    # positionally to `ModelWeights` in this order.
    first_weight_group = collections.OrderedDict([
        ('a', tff.TensorType(tf.float32, [2, 1])),
        ('b', tf.float32),
    ])
    second_weight_group = collections.OrderedDict([('c', tf.float32)])
    model_weights_type = model_utils.ModelWeights(first_weight_group,
                                                  second_weight_group)

    # Only the model weights of the ServerState are spelled out. The optimizer
    # state comes from TensorFlow and its concrete value is of no interest
    # here, so it is given as `test.AnyType()`.
    server_state = optimizer_utils.ServerState(model_weights_type,
                                               test.AnyType())
    server_state_type = tff.FederatedType(
        server_state, placement=tff.SERVER, all_equal=True)

    # Client data: a sequence of batches built by `LinearRegression.make_batch`
    # from two float32 tensor types of shapes [None, 2] and [None, 1].
    batch_type = model_examples.LinearRegression.make_batch(
        tff.TensorType(tf.float32, [None, 2]),
        tff.TensorType(tf.float32, [None, 1]))
    client_data_type = tff.FederatedType(
        tff.SequenceType(batch_type),
        placement=tff.CLIENTS,
        all_equal=False)

    # Model outputs reported at the server: two scalar tensors.
    model_outputs = collections.OrderedDict([
        ('loss', tff.TensorType(tf.float32, [])),
        ('num_examples', tff.TensorType(tf.int32, [])),
    ])
    model_outputs_type = tff.FederatedType(
        model_outputs, placement=tff.SERVER, all_equal=True)

    # `initialize` is expected to be a function of no arguments that returns a
    # ServerState.
    expected_initialize_type = tff.FunctionType(
        parameter=None, result=server_state_type)
    self.assertEqual(expected_initialize_type,
                     process.initialize.type_signature)

    # `next` is expected to be a function of (ServerState, Datasets) that
    # returns a (ServerState, model outputs) pair.
    expected_next_type = tff.FunctionType(
        parameter=[server_state_type, client_data_type],
        result=(server_state_type, model_outputs_type))
    self.assertEqual(expected_next_type, process.next.type_signature)