def test_before_aggregate_fails_mismatch_with_before_broadcast_type(self):
  """Rejects a before_aggregate whose c2 disagrees with before_broadcast.

  The `next` and before_broadcast stages are packed from `S1_TYPE`. The
  before_aggregate candidate is then built from an int32 CLIENTS-placed c2,
  and packing it is expected to raise a `TypeError` whose message mentions
  'before_aggregate'.
  """
  after_next = canonical_form_utils.pack_next_comp_type_signature(
      _create_next_type_with_s1_type(S1_TYPE), {'initialize_type': INIT_TYPE})
  after_broadcast = (
      canonical_form_utils.check_and_pack_before_broadcast_type_signature(
          _create_before_broadcast_type_with_s1_type(S1_TYPE), after_next))

  # The c2 here deliberately differs from what before_broadcast produced.
  mismatched_before_aggregate = _create_before_aggregate_with_c2_type(
      computation_types.FederatedType(tf.int32, placements.CLIENTS))

  with self.assertRaisesRegex(TypeError, 'before_aggregate'):
    canonical_form_utils.check_and_pack_before_aggregate_type_signature(
        mismatched_before_aggregate, after_broadcast)
def test_before_broadcast_succeeds_match_with_next_type(self):
  """A before_broadcast consistent with `next` packs s2 and prepare types.

  After packing `next` built from `S1_TYPE`, checking a before_broadcast
  signature built from the same `S1_TYPE` is expected to add:
    * 's2_type': a SERVER-placed type whose member is `C2_TYPE.member`;
    * 'prepare_type': a function from `S1_TYPE.member` to `S2_TYPE.member`.
  """
  after_next = canonical_form_utils.pack_next_comp_type_signature(
      _create_next_type_with_s1_type(S1_TYPE), {'initialize_type': INIT_TYPE})
  broadcast_signature = _create_before_broadcast_type_with_s1_type(S1_TYPE)
  packed = (
      canonical_form_utils.check_and_pack_before_broadcast_type_signature(
          broadcast_signature, after_next))

  expected_s2 = computation_types.FederatedType(
      C2_TYPE.member, placements.SERVER)
  self.assertEqual(packed['s2_type'], expected_s2)

  expected_prepare = computation_types.FunctionType(
      S1_TYPE.member, S2_TYPE.member)
  self.assertEqual(packed['prepare_type'], expected_prepare)
def test_before_aggregate_succeeds_and_packs(self):
  """Packs next -> before_broadcast -> before_aggregate and checks outputs.

  With stage signatures built from `S1_TYPE` and `C2_TYPE`, the packed dict
  is expected to carry c5/accumulate/merge/report types equal to the
  module-level constants. For 'zero_type', only the function's result type
  is compared against `ZERO_TYPE`.
  """
  after_next = canonical_form_utils.pack_next_comp_type_signature(
      _create_next_type_with_s1_type(S1_TYPE), {'initialize_type': INIT_TYPE})
  after_broadcast = (
      canonical_form_utils.check_and_pack_before_broadcast_type_signature(
          _create_before_broadcast_type_with_s1_type(S1_TYPE), after_next))
  packed = (
      canonical_form_utils.check_and_pack_before_aggregate_type_signature(
          _create_before_aggregate_with_c2_type(C2_TYPE), after_broadcast))

  self.assertEqual(packed['c5_type'], C5_TYPE)
  self.assertEqual(packed['zero_type'].result, ZERO_TYPE)
  self.assertEqual(packed['accumulate_type'], ACCUMULATE_TYPE)
  self.assertEqual(packed['merge_type'], MERGE_TYPE)
  self.assertEqual(packed['report_type'], REPORT_TYPE)
def test_after_aggregate_succeeds_and_packs(self):
  """Runs every stage through after_aggregate and checks the packed types.

  Starting from a float32 SERVER-placed initialize type, the next,
  before_broadcast and before_aggregate stages are packed in order. The
  after_aggregate signature (built from `S3_TYPE`) is checked last. The
  resulting 's4_type' / 'c3_type' are expected to be the SERVER / CLIENTS
  placed pairs of the S1+S3 and C1+C2 member types. 'update_type' is
  expected to map the s4 member to the s5 member.
  """
  accumulated = canonical_form_utils.pack_initialize_comp_type_signature(
      computation_types.FederatedType(tf.float32, placements.SERVER))

  # (packing function, stage signature) pairs, applied in pipeline order;
  # each call consumes and returns the accumulated type dict.
  pipeline = (
      (canonical_form_utils.pack_next_comp_type_signature,
       _create_next_type_with_s1_type(S1_TYPE)),
      (canonical_form_utils.check_and_pack_before_broadcast_type_signature,
       _create_before_broadcast_type_with_s1_type(S1_TYPE)),
      (canonical_form_utils.check_and_pack_before_aggregate_type_signature,
       _create_before_aggregate_with_c2_type(C2_TYPE)),
  )
  for pack_stage, stage_signature in pipeline:
    accumulated = pack_stage(stage_signature, accumulated)

  packed = (
      canonical_form_utils.check_and_pack_after_aggregate_type_signature(
          _create_after_aggregate_with_s3_type(S3_TYPE), accumulated))

  expected_s4 = computation_types.FederatedType(
      [S1_TYPE.member, S3_TYPE.member], placements.SERVER)
  self.assertEqual(packed['s4_type'], expected_s4)

  expected_c3 = computation_types.FederatedType(
      [C1_TYPE.member, C2_TYPE.member], placements.CLIENTS)
  self.assertEqual(packed['c3_type'], expected_c3)

  # NOTE(review): 's5_type' is read back from the packed dict rather than
  # pinned to a constant, so this only checks update_type's shape relative
  # to the other packed entries.
  expected_update = computation_types.FunctionType(
      packed['s4_type'].member, packed['s5_type'].member)
  self.assertEqual(packed['update_type'], expected_update)
def test_after_aggregate_raises_mismatch_with_before_aggregate(self):
  """An after_aggregate with an int32 s3 is rejected after a float32 setup.

  The initialize type and the s1 used for `next` / before_broadcast are all
  float32 SERVER-placed types, and before_aggregate uses `C2_TYPE`. The
  after_aggregate candidate is then built with an int32 SERVER-placed s3,
  and packing it is expected to raise a `TypeError` whose message mentions
  'after_aggregate'.
  """

  def float32_at_server():
    # A fresh instance per use site, mirroring separate constructions.
    return computation_types.FederatedType(tf.float32, placements.SERVER)

  accumulated = canonical_form_utils.pack_initialize_comp_type_signature(
      float32_at_server())

  pipeline = (
      (canonical_form_utils.pack_next_comp_type_signature,
       _create_next_type_with_s1_type(float32_at_server())),
      (canonical_form_utils.check_and_pack_before_broadcast_type_signature,
       _create_before_broadcast_type_with_s1_type(float32_at_server())),
      (canonical_form_utils.check_and_pack_before_aggregate_type_signature,
       _create_before_aggregate_with_c2_type(C2_TYPE)),
  )
  for pack_stage, stage_signature in pipeline:
    accumulated = pack_stage(stage_signature, accumulated)

  mismatched_after_aggregate = _create_after_aggregate_with_s3_type(
      computation_types.FederatedType(tf.int32, placements.SERVER))

  with self.assertRaisesRegex(TypeError, 'after_aggregate'):
    canonical_form_utils.check_and_pack_after_aggregate_type_signature(
        mismatched_after_aggregate, accumulated)
def test_next_fails_mismatch_with_init_type(self):
  """A `next` whose s1 is int32 SERVER-placed is rejected against INIT_TYPE.

  Packing is expected to raise a `TypeError` whose message mentions 'next'.
  """
  initial_types = {'initialize_type': INIT_TYPE}
  int32_at_server = computation_types.FederatedType(
      tf.int32, placements.SERVER)
  mismatched_next = _create_next_type_with_s1_type(int32_at_server)

  with self.assertRaisesRegex(TypeError, 'next'):
    canonical_form_utils.pack_next_comp_type_signature(
        mismatched_next, initial_types)