def test_unweighted_aggregator_raises(self):
  """Composing with an unweighted aggregator must be rejected.

  The process built by `SumFactory().create(FLOAT_TYPE)` is used here as the
  unweighted aggregator. The composer is expected to raise a `TypeError` whose
  message matches 'weighted'.
  """
  unweighted_aggregator = sum_factory.SumFactory().create(FLOAT_TYPE)
  with self.assertRaisesRegex(TypeError, 'weighted'):
    composers.compose_learning_process(
        test_init_model_weights_fn,
        test_distributor(),
        test_client_work(),
        unweighted_aggregator,
        test_finalizer(),
    )
def test_not_finalizer_type_raises(self):
  """A generic MeasuredProcess cannot stand in for a FinalizerProcess.

  The finalizer's own `initialize` and `next` computations are re-wrapped in
  a plain `MeasuredProcess`. The composer is expected to reject it with a
  `TypeError` whose message matches 'FinalizerProcess'.
  """
  reference_finalizer = test_finalizer()
  generic_process = measured_process.MeasuredProcess(
      reference_finalizer.initialize, reference_finalizer.next)
  with self.assertRaisesRegex(TypeError, 'FinalizerProcess'):
    composers.compose_learning_process(
        test_init_model_weights_fn,
        test_distributor(),
        test_client_work(),
        test_aggregator(),
        generic_process,
    )
def test_not_tff_computation_init_raises(self):
  """An init function that is plain Python (not a TFF Computation) fails.

  The helper below is undecorated, so it is an ordinary Python callable. The
  composer is expected to raise a `TypeError` whose message matches
  'Computation'.
  """

  def python_init_fn():
    return model_utils.ModelWeights(
        trainable=tf.constant(1.0), non_trainable=())

  with self.assertRaisesRegex(TypeError, 'Computation'):
    composers.compose_learning_process(
        python_init_fn,
        test_distributor(),
        test_client_work(),
        test_aggregator(),
        test_finalizer(),
    )
def test_one_arg_computation_init_raises(self):
  """An init computation that takes an argument is rejected.

  The TF computation below accepts a float32 tensor parameter. The composer
  is expected to raise a `TypeError` whose message matches 'Computation'.
  """

  @computations.tf_computation(computation_types.TensorType(tf.float32))
  def init_model_weights_fn(x):
    return model_utils.ModelWeights(trainable=x, non_trainable=())

  with self.assertRaisesRegex(TypeError, 'Computation'):
    composers.compose_learning_process(
        init_model_weights_fn,
        test_distributor(),
        test_client_work(),
        test_aggregator(),
        test_finalizer(),
    )
def test_not_model_weights_init_raises(self):
  """An init computation that does not return ModelWeights is rejected.

  The computation below returns an `OrderedDict` with the same field names
  instead of a `ModelWeights`. The composer is expected to raise a
  `TypeError` whose message matches 'ModelWeights'.
  """

  @computations.tf_computation()
  def init_model_weights_fn():
    return collections.OrderedDict(
        trainable=tf.constant(1.0), non_trainable=())

  with self.assertRaisesRegex(TypeError, 'ModelWeights'):
    composers.compose_learning_process(
        init_model_weights_fn,
        test_distributor(),
        test_client_work(),
        test_aggregator(),
        test_finalizer(),
    )
def test_federated_init_raises(self):
  """A federated (placed) init computation is rejected.

  The init computation below is a federated computation. It evaluates
  `test_init_model_weights_fn` at `placements.SERVER`, so its result carries
  a placement. The composer is expected to raise a `TypeError` whose message
  matches 'unplaced'.
  """

  @computations.federated_computation()
  def init_model_weights_fn():
    return intrinsics.federated_eval(
        test_init_model_weights_fn, placements.SERVER)

  with self.assertRaisesRegex(TypeError, 'unplaced'):
    composers.compose_learning_process(
        init_model_weights_fn,
        test_distributor(),
        test_client_work(),
        test_aggregator(),
        test_finalizer(),
    )
def test_learning_process_composes(self):
  """Valid building blocks compose into a well-typed LearningProcess.

  Checks the following:
  - The result is a `LearningProcess`.
  - Its server state is a `LearningAlgorithmState` whose
    `global_model_weights` has type `MODEL_WEIGHTS_TYPE`.
  - The reported metrics have exactly four fields, one per composed
    component.
  """
  process = composers.compose_learning_process(
      test_init_model_weights_fn,
      test_distributor(),
      test_client_work(),
      test_aggregator(),
      test_finalizer(),
  )
  self.assertIsInstance(process, learning_process.LearningProcess)

  state_type = process.initialize.type_signature.result.member
  self.assertEqual(state_type.python_container,
                   composers.LearningAlgorithmState)
  self.assertEqual(state_type.global_model_weights, MODEL_WEIGHTS_TYPE)

  # Reported metrics have one field per component and nothing else.
  metrics_type = process.next.type_signature.result.metrics.member
  for component_name in ('distributor', 'client_work', 'aggregator',
                         'finalizer'):
    self.assertTrue(structure.has_field(metrics_type, component_name))
  self.assertLen(metrics_type, 4)