def test_next_not_tff_computation_raises(self):
     """A plain Python lambda passed as `next_fn` raises `TypeError`."""
     expected_message = r'Expected .*\.Computation, .*'
     with self.assertRaisesRegex(TypeError, expected_message):
         # The lambda is deliberately not wrapped by any TFF decorator.
         client_works.ClientWorkProcess(
             initialize_fn=test_initialize_fn,
             next_fn=lambda state, w, d: MeasuredProcessOutput(
                 state, w + d, ()))
    def test_next_return_tuple_raises(self):
        """A `next_fn` returning a bare 3-tuple is rejected.

        Expects `errors.TemplateNotMeasuredProcessOutputError` on construction.
        """

        @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                            CLIENTS_FLOAT_SEQUENCE)
        def tuple_next_fn(state, weights, data):
            client_result = test_client_result(weights, data)
            measurements = server_zero()
            # Plain tuple on purpose, instead of MeasuredProcessOutput.
            return (state, client_result, measurements)

        expected_error = errors.TemplateNotMeasuredProcessOutputError
        with self.assertRaises(expected_error):
            client_works.ClientWorkProcess(initialize_fn=test_initialize_fn,
                                           next_fn=tuple_next_fn)
    def test_two_param_next_raises(self):
        """A `next_fn` declared with only two parameters is rejected.

        Expects `errors.TemplateNextFnNumArgsError` on construction.
        """

        @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE)
        def next_fn(state, weights):
            result = weights.trainable
            measurements = server_zero()
            return MeasuredProcessOutput(state, result, measurements)

        expected_error = errors.TemplateNextFnNumArgsError
        with self.assertRaises(expected_error):
            client_works.ClientWorkProcess(initialize_fn=test_initialize_fn,
                                           next_fn=next_fn)
    def test_non_server_placed_next_measurements_raises(self):
        """Measurements placed at CLIENTS are rejected.

        Expects `errors.TemplatePlacementError` on construction.
        """

        @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                            CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            result = test_client_result(weights, data)
            # The offending value: measurements built at CLIENTS.
            clients_measurements = intrinsics.federated_value(
                1.0, placements.CLIENTS)
            return MeasuredProcessOutput(state, result, clients_measurements)

        expected_error = errors.TemplatePlacementError
        with self.assertRaises(expected_error):
            client_works.ClientWorkProcess(initialize_fn=test_initialize_fn,
                                           next_fn=next_fn)
    def test_next_state_not_assignable(self):
        """A `next_fn` that returns a float state for a SERVER_INT parameter fails.

        Expects `errors.TemplateStateNotAssignableError` on construction.
        """

        @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                            CLIENTS_FLOAT_SEQUENCE)
        def float_next_fn(state, weights, data):
            # The incoming state is discarded and replaced by a float value.
            del state
            float_state = intrinsics.federated_value(0.0, placements.SERVER)
            result = test_client_result(weights, data)
            measurements = intrinsics.federated_value(1, placements.SERVER)
            return MeasuredProcessOutput(float_state, result, measurements)

        expected_error = errors.TemplateStateNotAssignableError
        with self.assertRaises(expected_error):
            client_works.ClientWorkProcess(initialize_fn=test_initialize_fn,
                                           next_fn=float_next_fn)
    def test_non_sequence_next_data_param_raises(self):
        """A data parameter typed CLIENTS_FLOAT (not a sequence) is rejected.

        Expects `client_works.ClientDataTypeError` on construction.
        """

        @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                            CLIENTS_FLOAT)
        def next_fn(state, weights, data):
            update = federated_add(weights.trainable, data)
            update_weight = client_one()
            zipped_result = intrinsics.federated_zip(
                client_works.ClientResult(update, update_weight))
            measurements = server_zero()
            return MeasuredProcessOutput(state, zipped_result, measurements)

        expected_error = client_works.ClientDataTypeError
        with self.assertRaises(expected_error):
            client_works.ClientWorkProcess(initialize_fn=test_initialize_fn,
                                           next_fn=next_fn)
    def test_non_zipped_next_result_raises(self):
        """A `ClientResult` returned without `federated_zip` is rejected.

        Expects `errors.TemplatePlacementError` on construction.
        """

        @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                            CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            reduced_data = intrinsics.federated_map(tf_data_sum, data)
            update = federated_add(weights.trainable, reduced_data)
            update_weight = client_one()
            # Left unzipped on purpose: no intrinsics.federated_zip here.
            unzipped_result = client_works.ClientResult(update, update_weight)
            measurements = server_zero()
            return MeasuredProcessOutput(state, unzipped_result, measurements)

        expected_error = errors.TemplatePlacementError
        with self.assertRaises(expected_error):
            client_works.ClientWorkProcess(initialize_fn=test_initialize_fn,
                                           next_fn=next_fn)
    def test_non_clients_placed_next_weights_param_raises(self):
        """A weights parameter placed at SERVER is rejected.

        Expects `errors.TemplatePlacementError` on construction.
        """
        server_weights_type = computation_types.at_server(
            MODEL_WEIGHTS_TYPE.member)

        @computations.federated_computation(SERVER_INT, server_weights_type,
                                            CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            broadcast_weights = intrinsics.federated_broadcast(weights)
            result = test_client_result(broadcast_weights, data)
            measurements = server_zero()
            return MeasuredProcessOutput(state, result, measurements)

        expected_error = errors.TemplatePlacementError
        with self.assertRaises(expected_error):
            client_works.ClientWorkProcess(initialize_fn=test_initialize_fn,
                                           next_fn=next_fn)
    def test_non_server_placed_init_state_raises(self):
        """An initial state placed at CLIENTS is rejected.

        Expects `errors.TemplatePlacementError` on construction.
        """
        clients_init_fn = computations.federated_computation(
            lambda: intrinsics.federated_value(0, placements.CLIENTS))
        # next_fn takes whatever state type the initializer produces.
        state_type = clients_init_fn.type_signature.result

        @computations.federated_computation(state_type, MODEL_WEIGHTS_TYPE,
                                            CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            result = test_client_result(weights, data)
            measurements = server_zero()
            return MeasuredProcessOutput(state, result, measurements)

        expected_error = errors.TemplatePlacementError
        with self.assertRaises(expected_error):
            client_works.ClientWorkProcess(initialize_fn=clients_init_fn,
                                           next_fn=next_fn)
예제 #10
0
    def test_init_tuple_of_federated_types_raises(self):
        """An initializer returning a tuple of two server values is rejected.

        Expects `errors.TemplateNotFederatedError` on construction.
        """
        tuple_init_fn = computations.federated_computation()(
            lambda: (server_zero(), server_zero()))
        # next_fn takes whatever state type the initializer produces.
        state_type = tuple_init_fn.type_signature.result

        @computations.federated_computation(state_type, MODEL_WEIGHTS_TYPE,
                                            CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            result = test_client_result(weights, data)
            measurements = server_zero()
            return MeasuredProcessOutput(state, result, measurements)

        expected_error = errors.TemplateNotFederatedError
        with self.assertRaises(expected_error):
            client_works.ClientWorkProcess(initialize_fn=tuple_init_fn,
                                           next_fn=next_fn)
예제 #11
0
    def test_non_federated_init_next_raises(self):
        """`tf_computation`-built init and next functions are rejected.

        Expects `errors.TemplateNotFederatedError` on construction.
        """
        tf_init_fn = computations.tf_computation(lambda: 0)
        float_sequence_type = computation_types.SequenceType(tf.float32)

        @computations.tf_computation(tf.int32, MODEL_WEIGHTS_TYPE.member,
                                     float_sequence_type)
        def next_fn(state, weights, data):
            update = weights.trainable + tf_data_sum(data)
            result = client_works.ClientResult(update, ())
            return MeasuredProcessOutput(state, result, ())

        expected_error = errors.TemplateNotFederatedError
        with self.assertRaises(expected_error):
            client_works.ClientWorkProcess(initialize_fn=tf_init_fn,
                                           next_fn=next_fn)
예제 #12
0
    def test_next_return_namedtuple_raises(self):
        """A look-alike namedtuple is not accepted as the `next_fn` result.

        The namedtuple shares the name and fields of `MeasuredProcessOutput`
        but is created locally; construction is expected to raise
        `errors.TemplateNotMeasuredProcessOutputError`.
        """
        field_names = ['state', 'result', 'measurements']
        lookalike_output = collections.namedtuple('MeasuredProcessOutput',
                                                  field_names)

        @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                            CLIENTS_FLOAT_SEQUENCE)
        def namedtuple_next_fn(state, weights, data):
            result = test_client_result(weights, data)
            measurements = server_zero()
            return lookalike_output(state, result, measurements)

        expected_error = errors.TemplateNotMeasuredProcessOutputError
        with self.assertRaises(expected_error):
            client_works.ClientWorkProcess(initialize_fn=test_initialize_fn,
                                           next_fn=namedtuple_next_fn)
예제 #13
0
    def test_incorrect_client_result_container_raises(self):
        """A zipped `OrderedDict` in place of `ClientResult` is rejected.

        Expects `client_works.ClientResultTypeError` on construction.
        """

        @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                            CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            reduced_data = intrinsics.federated_map(tf_data_sum, data)
            update = federated_add(weights.trainable, reduced_data)
            update_weight = client_one()
            # Same keys as in this test's original dict, wrong container.
            wrong_container = collections.OrderedDict(
                update=update, update_weight=update_weight)
            bad_client_result = intrinsics.federated_zip(wrong_container)
            measurements = server_zero()
            return MeasuredProcessOutput(state, bad_client_result,
                                         measurements)

        expected_error = client_works.ClientResultTypeError
        with self.assertRaises(expected_error):
            client_works.ClientWorkProcess(initialize_fn=test_initialize_fn,
                                           next_fn=next_fn)
예제 #14
0
    def test_non_clients_placed_next_data_param_raises(self):
        """A data parameter placed at SERVER is rejected.

        Expects `errors.TemplatePlacementError` on construction.
        """
        float_sequence_type = computation_types.SequenceType(tf.float32)
        server_sequence_float_type = computation_types.at_server(
            float_sequence_type)

        @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                            server_sequence_float_type)
        def next_fn(state, weights, data):
            broadcast_data = intrinsics.federated_broadcast(data)
            result = test_client_result(weights, broadcast_data)
            measurements = server_zero()
            return MeasuredProcessOutput(state, result, measurements)

        expected_error = errors.TemplatePlacementError
        with self.assertRaises(expected_error):
            client_works.ClientWorkProcess(initialize_fn=test_initialize_fn,
                                           next_fn=next_fn)
예제 #15
0
    def test_bad_next_weights_param_type_raises(self):
        """A weights parameter typed as a plain `OrderedDict` struct is rejected.

        Expects `client_works.ModelWeightsTypeError` on construction.
        """
        dict_weights = collections.OrderedDict(trainable=tf.float32,
                                               non_trainable=())
        bad_model_weights_type = computation_types.at_clients(
            computation_types.to_type(dict_weights))

        @computations.federated_computation(SERVER_INT, bad_model_weights_type,
                                            CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            result = test_client_result(weights, data)
            measurements = server_zero()
            return MeasuredProcessOutput(state, result, measurements)

        expected_error = client_works.ModelWeightsTypeError
        with self.assertRaises(expected_error):
            client_works.ClientWorkProcess(initialize_fn=test_initialize_fn,
                                           next_fn=next_fn)
예제 #16
0
def test_client_work():
    """Build and return a `ClientWorkProcess` for use as a test fixture.

    Each client's result uses the trainable weights as the update and the
    sum of its data (reduced from 0.0) as the update weight. State is passed
    through unchanged and measurements come from `empty_at_server()`.
    """

    @computations.tf_computation()
    def make_result(value, data):
        update = value.trainable
        update_weight = data.reduce(0.0, lambda total, item: total + item)
        return client_works.ClientResult(update=update,
                                         update_weight=update_weight)

    state_type = empty_init_fn.type_signature.result
    weights_type = computation_types.at_clients(MODEL_WEIGHTS_TYPE)

    @computations.federated_computation(state_type, weights_type,
                                        CLIENTS_SEQUENCE_FLOAT_TYPE)
    def next_fn(state, value, client_data):
        result = intrinsics.federated_map(make_result, (value, client_data))
        measurements = empty_at_server()
        return measured_process.MeasuredProcessOutput(state, result,
                                                      measurements)

    return client_works.ClientWorkProcess(empty_init_fn, next_fn)
예제 #17
0
    def test_construction_with_empty_state_does_not_raise(self):
        """Constructing a `ClientWorkProcess` whose state is `()` succeeds.

        Any exception raised by the constructor is reported as a test
        failure via `self.fail`.
        """
        initialize_fn = computations.federated_computation()(
            lambda: intrinsics.federated_value((), placements.SERVER))

        @computations.federated_computation(
            initialize_fn.type_signature.result, MODEL_WEIGHTS_TYPE,
            CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            return MeasuredProcessOutput(
                state, test_client_result(weights, data),
                intrinsics.federated_value(1, placements.SERVER))

        try:
            client_works.ClientWorkProcess(initialize_fn, next_fn)
        except Exception:  # pylint: disable=broad-except
            # `Exception` rather than a bare `except:` so KeyboardInterrupt
            # and SystemExit propagate instead of becoming a test failure.
            self.fail(
                'Could not construct an ClientWorkProcess with empty state.')
예제 #18
0
    def test_trainable_weights_not_asignable_from_update_raises(self):
        """An update cast to `tf.float64` is rejected.

        Expects `client_works.ClientResultTypeError` on construction.
        """
        bad_cast_fn = computations.tf_computation(
            lambda x: tf.cast(x, tf.float64))

        @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                            CLIENTS_FLOAT_SEQUENCE)
        def next_fn(state, weights, data):
            reduced_data = intrinsics.federated_map(tf_data_sum, data)
            update = federated_add(weights.trainable, reduced_data)
            # The cast is what makes the update's type differ.
            not_assignable_update = intrinsics.federated_map(
                bad_cast_fn, update)
            update_weight = client_one()
            zipped_result = intrinsics.federated_zip(
                client_works.ClientResult(not_assignable_update,
                                          update_weight))
            measurements = server_zero()
            return MeasuredProcessOutput(state, zipped_result, measurements)

        expected_error = client_works.ClientResultTypeError
        with self.assertRaises(expected_error):
            client_works.ClientWorkProcess(initialize_fn=test_initialize_fn,
                                           next_fn=next_fn)
예제 #19
0
 def test_init_not_tff_computation_raises(self):
     """A plain Python lambda passed as `initialize_fn` raises `TypeError`."""
     expected_message = r'Expected .*\.Computation, .*'
     with self.assertRaisesRegex(TypeError, expected_message):
         # The lambda is deliberately not wrapped by any TFF decorator.
         client_works.ClientWorkProcess(
             initialize_fn=lambda: 0, next_fn=test_next_fn)
예제 #20
0
 def test_construction_does_not_raise(self):
     """Constructing a `ClientWorkProcess` from the shared fixtures succeeds.

     Any exception raised by the constructor is reported as a test failure
     via `self.fail`.
     """
     try:
         client_works.ClientWorkProcess(test_initialize_fn, test_next_fn)
     except Exception:  # pylint: disable=broad-except
         # `Exception` rather than a bare `except:` so KeyboardInterrupt and
         # SystemExit propagate instead of becoming a test failure.
         self.fail('Could not construct a valid ClientWorkProcess.')
예제 #21
0
 def test_init_param_not_empty_raises(self):
     """An `initialize_fn` that declares a parameter is rejected.

     Expects `errors.TemplateInitFnParamNotEmptyError` on construction.
     """
     wrap_with_int_param = computations.federated_computation(SERVER_INT)
     init_with_param = wrap_with_int_param(lambda x: x)

     expected_error = errors.TemplateInitFnParamNotEmptyError
     with self.assertRaises(expected_error):
         client_works.ClientWorkProcess(initialize_fn=init_with_param,
                                        next_fn=test_next_fn)
예제 #22
0
 def test_init_state_not_assignable(self):
     """A float initial state paired with `test_next_fn` is rejected.

     Expects `errors.TemplateStateNotAssignableError` on construction.
     """
     wrap_no_arg = computations.federated_computation()
     float_init_fn = wrap_no_arg(
         lambda: intrinsics.federated_value(0.0, placements.SERVER))

     expected_error = errors.TemplateStateNotAssignableError
     with self.assertRaises(expected_error):
         client_works.ClientWorkProcess(initialize_fn=float_init_fn,
                                        next_fn=test_next_fn)