def test_execution_stateful_optimizer(self):
        """Runs model-delta client work with `sgdm.build_sgdm(0.1, momentum=0.9)`.

        Builds the client work process for `model_examples.LinearRegression`,
        executes one `next` call on two clients starting from all-zero model
        weights, and compares each client's `update` and `update_weight` with
        precomputed values. The first client's dataset is a single batch of 2
        examples; the second client's is that batch repeated twice (4 examples).
        """
        client_work_process = client_works.build_model_delta_client_work(
            model_examples.LinearRegression,
            sgdm.build_sgdm(0.1, momentum=0.9))
        data = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict(
                x=[[1.0, 2.0], [3.0, 4.0]],
                y=[[5.0], [6.0]],
            )).batch(2)
        data = [data, data.repeat(2)]  # 1st client has 2 examples, 2nd has 4.
        model_weights = model_utils.ModelWeights(
            trainable=[[[0.0], [0.0]], 0.0], non_trainable=[0.0])
        # Both clients start from the same (never mutated) weights object.
        client_model_weights = [model_weights] * 2

        state = client_work_process.initialize()
        output = client_work_process.next(state, client_model_weights, data)

        expected_result = (
            client_works.ClientResult([[[-1.15], [-1.7]], -0.55], 2.0),
            client_works.ClientResult([[[-1.46], [-2.26]], -0.8], 4.0),
        )

        self.assertEqual((), output.state)
        # Check the number of client results first: the per-client loop below
        # would otherwise silently ignore extra results, and a missing one
        # would only surface as an IndexError instead of a clear failure.
        self.assertLen(output.result, len(expected_result))
        for expected, actual in zip(expected_result, output.result):
            self.assertAllClose(expected.update, actual.update)
            self.assertAllClose(expected.update_weight, actual.update_weight)
        self.assertEqual((), output.measurements)
 def next_fn(state, weights, data):
     """Builds a `next` output whose ClientResult is NOT federated-zipped.

     The per-client data is reduced with `tf_data_sum` and combined with
     `weights.trainable` via `federated_add`. The resulting ClientResult is a
     plain struct of separate federated values rather than one federated value
     (no `intrinsics.federated_zip`). Presumably this is used to check that
     such a result is rejected — confirm against the enclosing test.
     """
     summed_data = intrinsics.federated_map(tf_data_sum, data)
     update = federated_add(weights.trainable, summed_data)
     update_weight = client_one()
     unzipped_result = client_works.ClientResult(update, update_weight)
     measurements = server_zero()
     return MeasuredProcessOutput(state, unzipped_result, measurements)
 def next_fn(state, weights, data):
     """Builds a zipped `next` output from the unreduced client `data`.

     `data` is handed to `federated_add` as-is, with no `tf_data_sum`
     reduction step. Presumably the enclosing test declares `data` with a
     type that makes this meaningful (or deliberately invalid) — confirm
     there.
     """
     update = federated_add(weights.trainable, data)
     update_weight = client_one()
     zipped_result = intrinsics.federated_zip(
         client_works.ClientResult(update, update_weight))
     return MeasuredProcessOutput(state, zipped_result, server_zero())
 def next_fn(state, weights, data):
     """Builds a zipped `next` output whose update went through `bad_cast_fn`.

     The per-client data is reduced with `tf_data_sum`, combined with
     `weights.trainable` via `federated_add`, and then mapped through
     `bad_cast_fn`. Judging by the name, the resulting update is presumably
     not assignable to the trainable-weights type, so that a type check in
     the process under test fires — confirm against the enclosing test.
     """
     summed_data = intrinsics.federated_map(tf_data_sum, data)
     delta = federated_add(weights.trainable, summed_data)
     not_assignable_update = intrinsics.federated_map(bad_cast_fn, delta)
     update_weight = client_one()
     zipped_result = intrinsics.federated_zip(
         client_works.ClientResult(not_assignable_update, update_weight))
     return MeasuredProcessOutput(state, zipped_result, server_zero())
    def test_type_properties(self):
        """Checks the type signatures of a model-delta client work process.

        The process is built for `model_examples.LinearRegression` with
        `sgdm.build_sgdm(1.0)`. Its `initialize` and `next` type signatures are
        compared against hand-constructed expected function types.
        """
        linear_model_fn = model_examples.LinearRegression
        process = client_works.build_model_delta_client_work(
            linear_model_fn, sgdm.build_sgdm(1.0))
        self.assertIsInstance(process, client_works.ClientWorkProcess)

        weights_type = model_utils.ModelWeights(
            trainable=computation_types.to_type(
                [(tf.float32, (2, 1)), tf.float32]),
            non_trainable=computation_types.to_type([tf.float32]))
        client_weights_type = computation_types.at_clients(weights_type)
        input_element_type = computation_types.to_type(
            linear_model_fn().input_spec)
        client_data_type = computation_types.at_clients(
            computation_types.SequenceType(input_element_type))
        client_result_type = computation_types.at_clients(
            client_works.ClientResult(
                update=weights_type.trainable,
                update_weight=computation_types.TensorType(tf.float32)))
        state_type = computation_types.at_server(())
        measurements_type = computation_types.at_server(())

        # `initialize` takes no argument and produces the (empty) server state.
        init_signature = computation_types.FunctionType(
            parameter=None, result=state_type)
        init_signature.check_equivalent_to(process.initialize.type_signature)

        # `next` maps (state, client weights, client data) to a measured output.
        next_parameter = collections.OrderedDict(
            state=state_type,
            weights=client_weights_type,
            client_data=client_data_type)
        next_result = MeasuredProcessOutput(
            state_type, client_result_type, measurements_type)
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        next_signature.check_equivalent_to(process.next.type_signature)
Beispiel #6
0
 def make_result(value, data):
     """Pairs `value.trainable` with the total of all elements of `data`.

     The update weight is `data.reduce(0.0, ...)`, which adds every element of
     `data` onto a running total that starts at `0.0`. This assumes `data` is
     a `tf.data.Dataset` of float scalars — TODO confirm with the caller.
     """
     trainable_update = value.trainable

     def _accumulate(running_total, element):
         return running_total + element

     return client_works.ClientResult(
         update=trainable_update,
         update_weight=data.reduce(0.0, _accumulate))
def test_client_result(weights, data):
    """Returns a federated-zipped ClientResult built from `weights` and `data`.

    The update is `federated_add` applied to `weights.trainable` and the
    per-client data reduced with `tf_data_sum`. The update weight is
    `client_one()`.
    """
    summed_data = intrinsics.federated_map(tf_data_sum, data)
    combined_update = federated_add(weights.trainable, summed_data)
    result = client_works.ClientResult(
        update=combined_update, update_weight=client_one())
    return intrinsics.federated_zip(result)
 def next_fn(state, weights, data):
     """Builds a `next` output using plain arithmetic instead of intrinsics.

     The update is `weights.trainable + tf_data_sum(data)`. Both the update
     weight and the measurements are the empty tuple `()`. No federated
     intrinsic is involved, so this presumably serves as a non-federated
     fixture — confirm against the enclosing test.
     """
     update = weights.trainable + tf_data_sum(data)
     empty_weight = ()
     return MeasuredProcessOutput(
         state, client_works.ClientResult(update, empty_weight), ())