def test_passes_with_federated_map(self):
    """A federated_map intrinsic is accepted by the reduction whitelist check.

    The intrinsic is typed as mapping an int32 -> float32 function over
    int32 values placed at CLIENTS, producing float32 values at CLIENTS.
    """
    mapped_fn_type = computation_types.FunctionType(tf.int32, tf.float32)
    clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
    clients_float = computation_types.FederatedType(tf.float32,
                                                    placements.CLIENTS)
    map_type = computation_types.FunctionType([mapped_fn_type, clients_int],
                                              clients_float)
    map_intrinsic = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_MAP.uri, map_type)
    # No explicit assertion: the check signals rejection by raising, so the
    # test passes as long as this call returns normally.
    tree_analysis.check_intrinsics_whitelisted_for_reduction(map_intrinsic)
# Example No. 2
# 0
    def test_raises_with_federated_mean(self):
        """A federated_mean intrinsic is rejected by the whitelist check.

        The check is expected to raise ValueError whose message matches the
        intrinsic's compact representation.
        """
        clients_int = computation_types.FederatedType(tf.int32,
                                                      placements.CLIENTS)
        server_int = computation_types.FederatedType(tf.int32,
                                                     placements.SERVER)
        mean_intrinsic = computation_building_blocks.Intrinsic(
            intrinsic_defs.FEDERATED_MEAN.uri,
            computation_types.FunctionType(clients_int, server_int))
        # NOTE(review): assertRaisesRegex treats this text as a regular
        # expression; re.escape it if the representation can ever contain
        # regex metacharacters.
        expected_message = mean_intrinsic.compact_representation()

        with self.assertRaisesRegex(ValueError, expected_message):
            tree_analysis.check_intrinsics_whitelisted_for_reduction(
                mean_intrinsic)
    def test_generic_divide_reduces(self):
        """Reducing a generic_divide intrinsic removes it from the tree.

        Asserts that the intrinsic is present before the transformation, that
        no occurrence of its URI remains afterwards, that the reduced tree
        passes the reduction whitelist check, and that the transformation's
        second return value (the "modified" flag) is True.
        """
        uri = intrinsic_defs.GENERIC_DIVIDE.uri
        context_stack = context_stack_impl.context_stack
        comp = computation_building_blocks.Intrinsic(
            uri,
            computation_types.FunctionType([tf.float32, tf.float32],
                                           tf.float32))

        # Check the precondition before exercising the code under test, so a
        # broken fixture fails here with a clear message instead of surfacing
        # as an unrelated error from the transformation.
        count_before_reduction = _count_intrinsics(comp, uri)
        self.assertGreater(count_before_reduction, 0)

        reduced, modified = (
            value_transformations.replace_all_intrinsics_with_bodies(
                comp, context_stack))

        count_after_reduction = _count_intrinsics(reduced, uri)
        self.assertEqual(count_after_reduction, 0)
        # Raises if any remaining intrinsic is not allowed after reduction.
        tree_analysis.check_intrinsics_whitelisted_for_reduction(reduced)
        self.assertTrue(modified)
# Example No. 4
# 0
 def test_raises_on_none(self):
     """Passing None to the whitelist check raises TypeError."""
     self.assertRaises(
         TypeError,
         tree_analysis.check_intrinsics_whitelisted_for_reduction,
         None)