def test_intrinsic_construction_fails_bad_type(self):
  """`construct_map_or_apply` rejects a function whose parameter type mismatches.

  `bad_lambda` accepts `tf.int32`, while the federated argument `clients_ref`
  has member type `tf.float32`, so construction is expected to raise
  `TypeError`.
  """
  x = computation_building_blocks.Reference('x', tf.int32)
  # Identity function int32 -> int32; deliberately incompatible with the
  # float32 members of `clients_ref` below.
  bad_lambda = computation_building_blocks.Lambda('x', tf.int32, x)
  clients_ref = computation_building_blocks.Reference(
      'y',
      computation_types.FederatedType(tf.float32, placement_literals.CLIENTS))
  with self.assertRaises(TypeError):
    computation_constructing_utils.construct_map_or_apply(
        bad_lambda, clients_ref)
def test_intrinsic_construction_clients(self):
  """A CLIENTS-placed argument makes `construct_map_or_apply` pick `federated_map`.

  The mapped function selects field `a` from a named tuple `<a=int32,b=bool>`,
  and it is applied to a CLIENTS-placed value of that same tuple type.
  """
  element_spec = [('a', tf.int32), ('b', tf.bool)]
  clients_value = computation_building_blocks.Reference(
      'test',
      computation_types.FederatedType(element_spec, placement_literals.CLIENTS,
                                      True))

  param = computation_building_blocks.Reference('x', element_spec)
  select_a = computation_building_blocks.Selection(param, name='a')
  mapper = computation_building_blocks.Lambda('x', param.type_signature,
                                              select_a)

  result = computation_constructing_utils.construct_map_or_apply(
      mapper, clients_value)
  self.assertEqual(str(result), 'federated_map')