예제 #1
0
 def test_before_aggregate_fails_mismatch_with_before_broadcast_type(self):
     """Rejects a `before_aggregate` built from a conflicting client type.

     `next` and `before_broadcast` are packed first, so the `before_aggregate`
     check has earlier type information to validate against. The
     `before_aggregate` signature is then built from a `tf.int32` client type,
     which is expected to disagree with that information. Packing it should
     raise a `TypeError` whose message mentions 'before_aggregate'.
     """
     utils = canonical_form_utils
     packed = {'initialize_type': INIT_TYPE}

     next_comp_type = _create_next_type_with_s1_type(S1_TYPE)
     broadcast_type = _create_before_broadcast_type_with_s1_type(S1_TYPE)

     packed = utils.pack_next_comp_type_signature(next_comp_type, packed)
     packed = utils.check_and_pack_before_broadcast_type_signature(
         broadcast_type, packed)

     # An int32 client value; presumably differs from the packed c2 type.
     mismatched_c2_type = computation_types.FederatedType(
         tf.int32, placements.CLIENTS)
     bad_aggregate_type = _create_before_aggregate_with_c2_type(
         mismatched_c2_type)

     with self.assertRaisesRegex(TypeError, 'before_aggregate'):
         utils.check_and_pack_before_aggregate_type_signature(
             bad_aggregate_type, packed)
 def test_before_broadcast_succeeds_match_with_next_type(self):
     """Packs `before_broadcast` once a `next` with the same S1 type is packed.

     The returned dict should expose:
       * 's2_type': a server-placed type with the member type of C2_TYPE.
       * 'prepare_type': a function from S1_TYPE's member to S2_TYPE's member.
     """
     utils = canonical_form_utils
     packed = utils.pack_next_comp_type_signature(
         _create_next_type_with_s1_type(S1_TYPE),
         {'initialize_type': INIT_TYPE})

     broadcast_type = _create_before_broadcast_type_with_s1_type(S1_TYPE)
     result = utils.check_and_pack_before_broadcast_type_signature(
         broadcast_type, packed)

     # Checking contents of the returned dict.
     actual_s2 = result['s2_type']
     self.assertEqual(
         actual_s2,
         computation_types.FederatedType(C2_TYPE.member, placements.SERVER))

     actual_prepare = result['prepare_type']
     self.assertEqual(
         actual_prepare,
         computation_types.FunctionType(S1_TYPE.member, S2_TYPE.member))
예제 #3
0
  def test_before_aggregate_succeeds_and_packs(self):
    """Packs `before_aggregate` after `next` and `before_broadcast` succeed.

    All signatures are built from the module-level S1_TYPE and C2_TYPE
    constants. The returned dict should carry the expected c5, zero,
    accumulate, merge and report types.
    """
    utils = canonical_form_utils
    packed = {'initialize_type': INIT_TYPE}

    next_comp_type = _create_next_type_with_s1_type(S1_TYPE)
    broadcast_type = _create_before_broadcast_type_with_s1_type(S1_TYPE)

    packed = utils.pack_next_comp_type_signature(next_comp_type, packed)
    packed = utils.check_and_pack_before_broadcast_type_signature(
        broadcast_type, packed)

    aggregate_type = _create_before_aggregate_with_c2_type(C2_TYPE)
    result = utils.check_and_pack_before_aggregate_type_signature(
        aggregate_type, packed)

    # Checking contents of the returned dict.
    self.assertEqual(result['c5_type'], C5_TYPE)
    # For 'zero_type', only its `.result` is compared against ZERO_TYPE.
    self.assertEqual(result['zero_type'].result, ZERO_TYPE)
    for key, expected in (('accumulate_type', ACCUMULATE_TYPE),
                          ('merge_type', MERGE_TYPE),
                          ('report_type', REPORT_TYPE)):
      self.assertEqual(result[key], expected)
 def test_after_aggregate_succeeds_and_packs(self):
     """Packs `after_aggregate` after the full chain of earlier signatures.

     The initialize type is a server-placed float32. `next`,
     `before_broadcast` and `before_aggregate` are packed in turn before the
     `after_aggregate` check runs. The returned dict should contain:
       * 's4_type': server-placed pair of the S1 and S3 member types.
       * 'c3_type': client-placed pair of the C1 and C2 member types.
       * 'update_type': a function from the s4 member to the s5 member.
     """
     utils = canonical_form_utils
     init_type = computation_types.FederatedType(tf.float32, placements.SERVER)
     packed = utils.pack_initialize_comp_type_signature(init_type)

     next_comp_type = _create_next_type_with_s1_type(S1_TYPE)
     broadcast_type = _create_before_broadcast_type_with_s1_type(S1_TYPE)
     packed = utils.pack_next_comp_type_signature(next_comp_type, packed)
     packed = utils.check_and_pack_before_broadcast_type_signature(
         broadcast_type, packed)

     aggregate_type = _create_before_aggregate_with_c2_type(C2_TYPE)
     packed = utils.check_and_pack_before_aggregate_type_signature(
         aggregate_type, packed)

     after_type = _create_after_aggregate_with_s3_type(S3_TYPE)
     result = utils.check_and_pack_after_aggregate_type_signature(
         after_type, packed)

     # Checking contents of the returned dict.
     actual_s4 = result['s4_type']
     self.assertEqual(
         actual_s4,
         computation_types.FederatedType(
             [S1_TYPE.member, S3_TYPE.member], placements.SERVER))

     actual_c3 = result['c3_type']
     self.assertEqual(
         actual_c3,
         computation_types.FederatedType(
             [C1_TYPE.member, C2_TYPE.member], placements.CLIENTS))

     actual_update = result['update_type']
     s4_member = result['s4_type'].member
     s5_member = result['s5_type'].member
     self.assertEqual(
         actual_update, computation_types.FunctionType(s4_member, s5_member))
예제 #5
0
  def test_after_aggregate_raises_mismatch_with_before_aggregate(self):
    """Rejects an `after_aggregate` whose s3 type conflicts with earlier types.

    Every earlier signature is packed from a server-placed float32 type and
    the module-level C2_TYPE. The `after_aggregate` signature is then built
    from a server-placed int32 s3 type, which should raise a `TypeError`
    whose message mentions 'after_aggregate'.
    """
    utils = canonical_form_utils

    def server_float32():
      # A fresh server-placed float32 type on every call.
      return computation_types.FederatedType(tf.float32, placements.SERVER)

    packed = utils.pack_initialize_comp_type_signature(server_float32())
    next_comp_type = _create_next_type_with_s1_type(server_float32())
    broadcast_type = _create_before_broadcast_type_with_s1_type(
        server_float32())
    packed = utils.pack_next_comp_type_signature(next_comp_type, packed)
    packed = utils.check_and_pack_before_broadcast_type_signature(
        broadcast_type, packed)

    aggregate_type = _create_before_aggregate_with_c2_type(C2_TYPE)
    packed = utils.check_and_pack_before_aggregate_type_signature(
        aggregate_type, packed)

    mismatched_s3 = computation_types.FederatedType(
        tf.int32, placements.SERVER)
    bad_after_type = _create_after_aggregate_with_s3_type(mismatched_s3)

    with self.assertRaisesRegex(TypeError, 'after_aggregate'):
      utils.check_and_pack_after_aggregate_type_signature(
          bad_after_type, packed)
예제 #6
0
 def test_next_fails_mismatch_with_init_type(self):
   """Rejects a `next` whose state type conflicts with the initialize type.

   The initialize type is the module-level INIT_TYPE. `next` is built from a
   server-placed int32 S1 type, which is expected to disagree with it.
   Packing should raise a `TypeError` whose message mentions 'next'.
   """
   int32_server_state = computation_types.FederatedType(
       tf.int32, placements.SERVER)
   mismatched_next = _create_next_type_with_s1_type(int32_server_state)
   with self.assertRaisesRegex(TypeError, 'next'):
     canonical_form_utils.pack_next_comp_type_signature(
         mismatched_next, {'initialize_type': INIT_TYPE})