예제 #1
0
 def test_unweighted_aggregator_raises(self):
     """Composition rejects an aggregator created by `SumFactory`.

     The aggregator is built with `SumFactory().create(FLOAT_TYPE)`; passing
     it to `compose_learning_process` is expected to raise a `TypeError`
     whose message matches 'weighted'.
     """
     factory = sum_factory.SumFactory()
     unweighted_aggregator = factory.create(FLOAT_TYPE)
     with self.assertRaisesRegex(TypeError, 'weighted'):
         composers.compose_learning_process(
             test_init_model_weights_fn,
             test_distributor(),
             test_client_work(),
             unweighted_aggregator,
             test_finalizer(),
         )
예제 #2
0
 def test_not_finalizer_type_raises(self):
     """Composition rejects a finalizer given as a bare `MeasuredProcess`.

     The `initialize` and `next` computations of a finalizer from
     `test_finalizer()` are re-wrapped in a `MeasuredProcess`; composing with
     that wrapper is expected to raise a `TypeError` matching
     'FinalizerProcess'.
     """
     valid_finalizer = test_finalizer()
     init_fn = valid_finalizer.initialize
     next_fn = valid_finalizer.next
     rewrapped_process = measured_process.MeasuredProcess(init_fn, next_fn)
     with self.assertRaisesRegex(TypeError, 'FinalizerProcess'):
         composers.compose_learning_process(
             test_init_model_weights_fn,
             test_distributor(),
             test_client_work(),
             test_aggregator(),
             rewrapped_process,
         )
예제 #3
0
    def test_not_tff_computation_init_raises(self):
        """Composition rejects an undecorated Python init function.

        Expects a `TypeError` whose message matches 'Computation'.
        """

        # No computation decorator is applied here on purpose.
        def init_model_weights_fn():
            weights = model_utils.ModelWeights(
                trainable=tf.constant(1.0), non_trainable=())
            return weights

        with self.assertRaisesRegex(TypeError, 'Computation'):
            composers.compose_learning_process(
                init_model_weights_fn,
                test_distributor(),
                test_client_work(),
                test_aggregator(),
                test_finalizer(),
            )
예제 #4
0
    def test_one_arg_computation_init_raises(self):
        """Composition rejects an init computation that takes an argument.

        The init function is a `tf_computation` declared with one float32
        tensor parameter; a `TypeError` matching 'Computation' is expected.
        """
        float_tensor_type = computation_types.TensorType(tf.float32)

        @computations.tf_computation(float_tensor_type)
        def init_model_weights_fn(x):
            weights = model_utils.ModelWeights(trainable=x, non_trainable=())
            return weights

        with self.assertRaisesRegex(TypeError, 'Computation'):
            composers.compose_learning_process(
                init_model_weights_fn,
                test_distributor(),
                test_client_work(),
                test_aggregator(),
                test_finalizer(),
            )
예제 #5
0
    def test_not_model_weights_init_raises(self):
        """Composition rejects an init computation returning an `OrderedDict`.

        The returned dict has `trainable` and `non_trainable` entries but is
        not a `ModelWeights` instance; a `TypeError` matching 'ModelWeights'
        is expected.
        """

        @computations.tf_computation()
        def init_model_weights_fn():
            return collections.OrderedDict([
                ('trainable', tf.constant(1.0)),
                ('non_trainable', ()),
            ])

        with self.assertRaisesRegex(TypeError, 'ModelWeights'):
            composers.compose_learning_process(
                init_model_weights_fn,
                test_distributor(),
                test_client_work(),
                test_aggregator(),
                test_finalizer(),
            )
예제 #6
0
    def test_federated_init_raises(self):
        """Composition rejects an init function that is a federated computation.

        The init function evaluates `test_init_model_weights_fn` at
        `placements.SERVER`; a `TypeError` matching 'unplaced' is expected.
        """

        @computations.federated_computation()
        def init_model_weights_fn():
            server_weights = intrinsics.federated_eval(
                test_init_model_weights_fn, placements.SERVER)
            return server_weights

        with self.assertRaisesRegex(TypeError, 'unplaced'):
            composers.compose_learning_process(
                init_model_weights_fn,
                test_distributor(),
                test_client_work(),
                test_aggregator(),
                test_finalizer(),
            )
예제 #7
0
    def test_learning_process_composes(self):
        """Valid components compose into a `LearningProcess`.

        Checks the Python container and `global_model_weights` type of the
        state returned by `initialize`, and that the metrics returned by
        `next` have exactly one field per composed component.
        """
        composed = composers.compose_learning_process(
            test_init_model_weights_fn,
            test_distributor(),
            test_client_work(),
            test_aggregator(),
            test_finalizer(),
        )
        self.assertIsInstance(composed, learning_process.LearningProcess)

        initial_state_type = composed.initialize.type_signature.result.member
        self.assertEqual(initial_state_type.python_container,
                         composers.LearningAlgorithmState)
        self.assertEqual(initial_state_type.global_model_weights,
                         MODEL_WEIGHTS_TYPE)

        # One metrics field per component, and nothing else.
        reported_metrics_type = composed.next.type_signature.result.metrics.member
        expected_fields = ('distributor', 'client_work', 'aggregator',
                           'finalizer')
        for field_name in expected_fields:
            self.assertTrue(
                structure.has_field(reported_metrics_type, field_name))
        self.assertLen(reported_metrics_type, 4)