def test_after_aggregate_raises_mismatch_with_before_aggregate(self):
  """Rejects an after_aggregate signature built from an int32 s3 type.

  The initialize, next, before_broadcast and before_aggregate stages are
  packed first, using a float32-at-SERVER type for initialize and s1 and
  `C2_TYPE` for before_aggregate. Packing an after_aggregate whose s3 type
  is int32 at SERVER is then expected to raise a `TypeError` whose message
  matches 'after_aggregate'.
  """
  cfu = canonical_form_utils
  float_at_server = computation_types.FederatedType(
      tf.float32, placements.SERVER)

  # Pack every stage preceding after_aggregate with well-formed types.
  signatures = cfu.pack_initialize_comp_type_signature(float_at_server)
  next_comp_type = _create_next_type_with_s1_type(float_at_server)
  before_broadcast = _create_before_broadcast_type_with_s1_type(
      float_at_server)
  signatures = cfu.pack_next_comp_type_signature(next_comp_type, signatures)
  signatures = cfu.check_and_pack_before_broadcast_type_signature(
      before_broadcast, signatures)
  before_aggregate = _create_before_aggregate_with_c2_type(C2_TYPE)
  signatures = cfu.check_and_pack_before_aggregate_type_signature(
      before_aggregate, signatures)

  # The test expects this int32 s3 to be inconsistent with the stages
  # packed above.
  mismatched_after_aggregate = _create_after_aggregate_with_s3_type(
      computation_types.FederatedType(tf.int32, placements.SERVER))

  with self.assertRaisesRegex(TypeError, 'after_aggregate'):
    cfu.check_and_pack_after_aggregate_type_signature(
        mismatched_after_aggregate, signatures)
 def test_init_passes_with_float_at_server(self):
   """Packing a float32-at-SERVER initialize type keeps it SERVER-federated.

   The packed dict's 'initialize_type' entry must be a `FederatedType`
   whose placement is `placements.SERVER`.
   """
   float_at_server = computation_types.FederatedType(
       tf.float32, placements.SERVER)
   packed = canonical_form_utils.pack_initialize_comp_type_signature(
       float_at_server)
   init_type = packed['initialize_type']
   self.assertIsInstance(init_type, computation_types.FederatedType)
   self.assertEqual(init_type.placement, placements.SERVER)
Example #3
0
 def test_after_aggregate_succeeds_and_packs(self):
   """Packs a well-formed after_aggregate signature and checks derived types.

   After initialize, next, before_broadcast and before_aggregate are packed
   with consistent types, packing an after_aggregate built from `S3_TYPE`
   must succeed. The returned dict is then checked for:
     * 's4_type': `[S1_TYPE.member, S3_TYPE.member]` placed at SERVER;
     * 'c3_type': `[C1_TYPE.member, C2_TYPE.member]` placed at CLIENTS;
     * 'update_type': a function from the s4 member to the s5 member.
   """
   cfu = canonical_form_utils
   init_type = computation_types.FederatedType(tf.float32, placements.SERVER)

   # Pack the stages preceding after_aggregate, in pipeline order.
   signatures = cfu.pack_initialize_comp_type_signature(init_type)
   next_comp_type = _create_next_type_with_s1_type(S1_TYPE)
   before_broadcast = _create_before_broadcast_type_with_s1_type(S1_TYPE)
   signatures = cfu.pack_next_comp_type_signature(next_comp_type, signatures)
   signatures = cfu.check_and_pack_before_broadcast_type_signature(
       before_broadcast, signatures)
   before_aggregate = _create_before_aggregate_with_c2_type(C2_TYPE)
   signatures = cfu.check_and_pack_before_aggregate_type_signature(
       before_aggregate, signatures)
   after_aggregate = _create_after_aggregate_with_s3_type(S3_TYPE)
   packed = cfu.check_and_pack_after_aggregate_type_signature(
       after_aggregate, signatures)

   # s4 pairs the S1 and S3 members at SERVER.
   expected_s4 = computation_types.FederatedType(
       [S1_TYPE.member, S3_TYPE.member], placements.SERVER)
   self.assertEqual(packed['s4_type'], expected_s4)

   # c3 pairs the C1 and C2 members at CLIENTS.
   expected_c3 = computation_types.FederatedType(
       [C1_TYPE.member, C2_TYPE.member], placements.CLIENTS)
   self.assertEqual(packed['c3_type'], expected_c3)

   # update_type maps the packed s4 member to the packed s5 member.
   expected_update = computation_types.FunctionType(
       packed['s4_type'].member, packed['s5_type'].member)
   self.assertEqual(packed['update_type'], expected_update)
 def test_init_raises_non_federated_type(self):
   """A bare, non-federated initialize type is rejected.

   Passing plain `tf.float32` is expected to raise a `TypeError` whose
   message matches 'init'.
   """
   non_federated_type = tf.float32
   with self.assertRaisesRegex(TypeError, 'init'):
     canonical_form_utils.pack_initialize_comp_type_signature(
         non_federated_type)