Ejemplo n.º 1
0
 def test_next_not_tff_computation_raises(self):
     """Checks that a plain Python callable is rejected as `next_fn`."""

     def plain_python_next(state, w, u):
         return MeasuredProcessOutput(state, w + u, ())

     expected_message = r'Expected .*\.Computation, .*'
     with self.assertRaisesRegex(TypeError, expected_message):
         finalizers.FinalizerProcess(
             initialize_fn=test_initialize_fn, next_fn=plain_python_next)
Ejemplo n.º 2
0
    def test_two_param_next_raises(self):
        """Checks that a `next_fn` taking only (state, weights) is rejected."""
        next_param_types = (SERVER_INT, MODEL_WEIGHTS_TYPE)

        @computations.federated_computation(*next_param_types)
        def two_arg_next_fn(state, weights):
            return MeasuredProcessOutput(state, weights, server_zero())

        with self.assertRaises(errors.TemplateNextFnNumArgsError):
            finalizers.FinalizerProcess(
                initialize_fn=test_initialize_fn, next_fn=two_arg_next_fn)
Ejemplo n.º 3
0
    def test_next_return_tuple_raises(self):
        """Checks that `next_fn` returning a bare tuple is rejected.

        The three values mirror a `MeasuredProcessOutput`, but they are not
        wrapped in that type.
        """

        def tuple_next(state, weights, update):
            finalized = test_finalizer_result(weights, update)
            return state, finalized, server_zero()

        tuple_next_fn = computations.federated_computation(
            SERVER_INT, MODEL_WEIGHTS_TYPE, SERVER_FLOAT)(tuple_next)

        with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError):
            finalizers.FinalizerProcess(
                initialize_fn=test_initialize_fn, next_fn=tuple_next_fn)
Ejemplo n.º 4
0
    def test_non_server_placed_next_measurements_raises(self):
        """Checks that measurements placed at CLIENTS (not SERVER) are rejected."""

        @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                            SERVER_FLOAT)
        def clients_measurements_next_fn(state, weights, update):
            finalized = test_finalizer_result(weights, update)
            client_measurements = intrinsics.federated_value(
                1.0, placements.CLIENTS)
            return MeasuredProcessOutput(state, finalized, client_measurements)

        with self.assertRaises(errors.TemplatePlacementError):
            finalizers.FinalizerProcess(
                initialize_fn=test_initialize_fn,
                next_fn=clients_measurements_next_fn)
Ejemplo n.º 5
0
    def test_next_return_odict_raises(self):
        """Checks that `next_fn` returning an OrderedDict is rejected.

        The dict has the same keys as `MeasuredProcessOutput`, but it is not
        that type.
        """

        @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                            SERVER_FLOAT)
        def odict_next(state, weights, update):
            fields = [
                ('state', state),
                ('result', test_finalizer_result(weights, update)),
                ('measurements', server_zero()),
            ]
            return collections.OrderedDict(fields)

        with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError):
            finalizers.FinalizerProcess(
                initialize_fn=test_initialize_fn, next_fn=odict_next)
Ejemplo n.º 6
0
    def test_next_return_namedtuple_raises(self):
        """Checks that a look-alike namedtuple is rejected.

        The namedtuple reuses the `MeasuredProcessOutput` name and fields, but
        it is not the genuine type.
        """
        fake_output_cls = collections.namedtuple(
            'MeasuredProcessOutput', ('state', 'result', 'measurements'))

        @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                            SERVER_FLOAT)
        def lookalike_next_fn(state, weights, update):
            return fake_output_cls(
                state=state,
                result=test_finalizer_result(weights, update),
                measurements=server_zero())

        with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError):
            finalizers.FinalizerProcess(
                initialize_fn=test_initialize_fn, next_fn=lookalike_next_fn)
Ejemplo n.º 7
0
    def test_next_state_not_assignable(self):
        """Checks that a float state from `next_fn` is rejected.

        The float state is not assignable to the `SERVER_INT` state parameter
        declared for `next_fn` (presumably also the type produced by
        `test_initialize_fn` -- confirm).
        """

        @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                            SERVER_FLOAT)
        def float_state_next_fn(state, weights, update):
            del state  # The incoming state is deliberately replaced below.
            float_state = intrinsics.federated_value(0.0, placements.SERVER)
            finalized = test_finalizer_result(weights, update)
            measurements = intrinsics.federated_value(1, placements.SERVER)
            return MeasuredProcessOutput(float_state, finalized, measurements)

        with self.assertRaises(errors.TemplateStateNotAssignableError):
            finalizers.FinalizerProcess(
                initialize_fn=test_initialize_fn, next_fn=float_state_next_fn)
Ejemplo n.º 8
0
    def test_init_tuple_of_federated_types_raises(self):
        """Checks that a state made of a tuple of federated values is rejected.

        A tuple of federated values is not itself a federated type.
        """

        def init_pair():
            return (server_zero(), server_zero())

        initialize_fn = computations.federated_computation()(init_pair)
        pair_state_type = initialize_fn.type_signature.result

        @computations.federated_computation(pair_state_type,
                                            MODEL_WEIGHTS_TYPE, SERVER_FLOAT)
        def next_fn(state, weights, update):
            finalized = test_finalizer_result(weights, update)
            return MeasuredProcessOutput(state, finalized, server_zero())

        with self.assertRaises(errors.TemplateNotFederatedError):
            finalizers.FinalizerProcess(
                initialize_fn=initialize_fn, next_fn=next_fn)
Ejemplo n.º 9
0
    def test_non_server_placed_next_weight_param_raises(self):
        """Checks that a CLIENTS-placed weights parameter is rejected."""
        clients_weights_type = computation_types.at_clients(
            MODEL_WEIGHTS_TYPE.member)

        @computations.federated_computation(SERVER_INT, clients_weights_type,
                                            SERVER_FLOAT)
        def clients_weights_next_fn(state, weights, update):
            # Sum to SERVER so the body itself type-checks; only the parameter
            # placement is wrong.
            summed_weights = intrinsics.federated_sum(weights)
            finalized = test_finalizer_result(summed_weights, update)
            return MeasuredProcessOutput(state, finalized, server_zero())

        with self.assertRaises(errors.TemplatePlacementError):
            finalizers.FinalizerProcess(
                initialize_fn=test_initialize_fn,
                next_fn=clients_weights_next_fn)
Ejemplo n.º 10
0
    def test_non_federated_init_next_raises(self):
        """Checks that TensorFlow-only computations are rejected.

        Both `initialize_fn` and `next_fn` are built with `tf_computation`, so
        neither has federated types.
        """
        initialize_fn = computations.tf_computation(lambda: 0)

        weights_type = computation_types.to_type(
            model_utils.ModelWeights(tf.float32, ()))

        @computations.tf_computation(tf.int32, weights_type, tf.float32)
        def next_fn(state, weights, update):
            new_weights = model_utils.ModelWeights(weights.trainable + update,
                                                   ())
            return MeasuredProcessOutput(state, new_weights, 0)

        with self.assertRaises(errors.TemplateNotFederatedError):
            finalizers.FinalizerProcess(initialize_fn, next_fn)
Ejemplo n.º 11
0
def test_finalizer():
    """Builds a `FinalizerProcess` that adds updates to the trainable weights.

    The process is initialized with `empty_init_fn`. Its `next_fn` takes
    server-placed weights and a server-placed float update, and returns
    `ModelWeights(trainable + updates, ())` zipped at the server, together
    with the measurements produced by `empty_at_server()`.

    Returns:
      A `finalizers.FinalizerProcess`.
    """

    @computations.federated_computation(
        empty_init_fn.type_signature.result,
        computation_types.at_server(MODEL_WEIGHTS_TYPE),
        computation_types.at_server(FLOAT_TYPE))
    def next_fn(state, weights, updates):
        add_fn = computations.tf_computation(lambda x, y: x + y)
        summed_trainable = intrinsics.federated_map(
            add_fn, (weights.trainable, updates))
        zipped_weights = intrinsics.federated_zip(
            model_utils.ModelWeights(summed_trainable, ()))
        measurements = empty_at_server()
        return measured_process.MeasuredProcessOutput(state, zipped_weights,
                                                      measurements)

    return finalizers.FinalizerProcess(initialize_fn=empty_init_fn,
                                       next_fn=next_fn)
Ejemplo n.º 12
0
    def test_result_not_assignable_to_weight_raises(self):
        """Checks that a result cast to float64 is rejected.

        After the cast, the result no longer matches the weights type that
        `next_fn` received.
        """

        def cast_to_float64(x):
            return tf.nest.map_structure(
                lambda leaf: tf.cast(leaf, tf.float64), x)

        bad_cast_fn = computations.tf_computation(cast_to_float64)

        @computations.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                            SERVER_FLOAT)
        def next_fn(state, weights, update):
            finalized = test_finalizer_result(weights, update)
            recast = intrinsics.federated_map(bad_cast_fn, finalized)
            return MeasuredProcessOutput(state, recast, server_zero())

        with self.assertRaises(finalizers.FinalizerResultTypeError):
            finalizers.FinalizerProcess(
                initialize_fn=test_initialize_fn, next_fn=next_fn)
Ejemplo n.º 13
0
    def test_construction_with_empty_state_does_not_raise(self):
        """Checks that a finalizer whose state is an empty server value builds.

        `initialize_fn` places the empty tuple `()` at SERVER, and `next_fn`
        passes that state through unchanged.
        """
        initialize_fn = computations.federated_computation()(
            lambda: intrinsics.federated_value((), placements.SERVER))

        @computations.federated_computation(
            initialize_fn.type_signature.result, MODEL_WEIGHTS_TYPE,
            SERVER_FLOAT)
        def next_fn(state, weights, update):
            return MeasuredProcessOutput(
                state, test_finalizer_result(weights, update),
                intrinsics.federated_value(1, placements.SERVER))

        # Catch `Exception` rather than using a bare `except:`, so that
        # KeyboardInterrupt/SystemExit propagate instead of being reported as
        # a test failure. The original error remains attached to the failure
        # as its implicit exception context.
        try:
            finalizers.FinalizerProcess(initialize_fn, next_fn)
        except Exception:  # pylint: disable=broad-except
            self.fail(
                'Could not construct a FinalizerProcess with empty state.')
Ejemplo n.º 14
0
    def test_bad_next_weights_param_type_raises(self):
        """Checks that an OrderedDict weights type (not `ModelWeights`) is rejected."""
        weights_struct = collections.OrderedDict(trainable=tf.float32,
                                                 non_trainable=())
        bad_model_weights_type = computation_types.at_server(
            computation_types.to_type(weights_struct))

        @computations.federated_computation(SERVER_INT, bad_model_weights_type,
                                            SERVER_FLOAT)
        def next_fn(state, weights, update):
            new_trainable = federated_add(weights['trainable'], update)
            zipped_weights = intrinsics.federated_zip(
                model_utils.ModelWeights(new_trainable, ()))
            return MeasuredProcessOutput(state, zipped_weights, server_zero())

        with self.assertRaises(finalizers.ModelWeightsTypeError):
            finalizers.FinalizerProcess(
                initialize_fn=test_initialize_fn, next_fn=next_fn)
Ejemplo n.º 15
0
 def test_init_param_not_empty_raises(self):
     """Checks that an `initialize_fn` that takes a parameter is rejected."""

     def identity(x):
         return x

     one_arg_initialize_fn = computations.federated_computation(SERVER_INT)(
         identity)
     with self.assertRaises(errors.TemplateInitFnParamNotEmptyError):
         finalizers.FinalizerProcess(
             initialize_fn=one_arg_initialize_fn, next_fn=test_next_fn)
Ejemplo n.º 16
0
 def test_init_state_not_assignable(self):
     """Checks that a float SERVER initial state is rejected.

     The float state is expected not to match the state type that
     `test_next_fn` accepts (assumed to be an int state -- confirm).
     """

     def float_init():
         return intrinsics.federated_value(0.0, placements.SERVER)

     float_initialize_fn = computations.federated_computation()(float_init)
     with self.assertRaises(errors.TemplateStateNotAssignableError):
         finalizers.FinalizerProcess(
             initialize_fn=float_initialize_fn, next_fn=test_next_fn)
Ejemplo n.º 17
0
 def test_init_not_tff_computation_raises(self):
     """Checks that a plain Python callable is rejected as `initialize_fn`."""

     def plain_python_init():
         return 0

     expected_message = r'Expected .*\.Computation, .*'
     with self.assertRaisesRegex(TypeError, expected_message):
         finalizers.FinalizerProcess(
             initialize_fn=plain_python_init, next_fn=test_next_fn)
Ejemplo n.º 18
0
 def test_construction_does_not_raise(self):
     """Checks that a well-formed FinalizerProcess can be constructed."""
     # Catch `Exception` rather than using a bare `except:`, so that
     # KeyboardInterrupt/SystemExit propagate instead of being reported as a
     # test failure. The original error stays attached as implicit context.
     try:
         finalizers.FinalizerProcess(test_initialize_fn, test_next_fn)
     except Exception:  # pylint: disable=broad-except
         self.fail('Could not construct a valid FinalizerProcess.')