def test_after_aggregate_raises_mismatch_with_before_aggregate(self):
  """Packing an after_aggregate with an int32 s3 must raise a TypeError.

  The packed signature dict is built step by step: initialize, next,
  before_broadcast and before_aggregate. The initialize, next and
  before_broadcast stages use float32 server-placed types. The after_aggregate
  signature is then created with an int32 server-placed s3. Packing it is
  expected to fail with a TypeError whose message matches 'after_aggregate'.
  """
  cfu = canonical_form_utils

  def at_server(dtype):
    # Returns a brand-new server-placed federated type on every call.
    return computation_types.FederatedType(dtype, placements.SERVER)

  packed = cfu.pack_initialize_comp_type_signature(at_server(tf.float32))
  next_comp_type = _create_next_type_with_s1_type(at_server(tf.float32))
  broadcast_type = _create_before_broadcast_type_with_s1_type(
      at_server(tf.float32))
  packed = cfu.pack_next_comp_type_signature(next_comp_type, packed)
  packed = cfu.check_and_pack_before_broadcast_type_signature(
      broadcast_type, packed)
  aggregate_type = _create_before_aggregate_with_c2_type(C2_TYPE)
  packed = cfu.check_and_pack_before_aggregate_type_signature(
      aggregate_type, packed)

  # The int32 s3 is the deliberate mismatch this test exercises; presumably it
  # conflicts with what the before_aggregate stage implies (per the test name).
  mismatched_type = _create_after_aggregate_with_s3_type(at_server(tf.int32))
  with self.assertRaisesRegex(TypeError, 'after_aggregate'):
    cfu.check_and_pack_after_aggregate_type_signature(mismatched_type, packed)
def test_init_passes_with_float_at_server(self):
  """A float32 SERVER-placed initialize type packs into a federated SERVER type.

  Checks that the 'initialize_type' entry of the packed dict is a
  `computation_types.FederatedType` whose placement is `placements.SERVER`.
  """
  init_type = computation_types.FederatedType(tf.float32, placements.SERVER)
  packed = canonical_form_utils.pack_initialize_comp_type_signature(init_type)
  packed_init = packed['initialize_type']
  self.assertIsInstance(packed_init, computation_types.FederatedType)
  self.assertEqual(packed_init.placement, placements.SERVER)
def test_after_aggregate_succeeds_and_packs(self):
  """A consistent chain of signatures packs cleanly through after_aggregate.

  The packed dict is threaded through initialize, next, before_broadcast,
  before_aggregate and after_aggregate. It is built from a float32 SERVER
  initialize type and the module-level S1/C2/S3 types. The test then verifies
  the derived 's4_type', 'c3_type' and 'update_type' entries of the result.
  """
  cfu = canonical_form_utils
  fed = computation_types.FederatedType

  packed = cfu.pack_initialize_comp_type_signature(
      fed(tf.float32, placements.SERVER))
  next_comp_type = _create_next_type_with_s1_type(S1_TYPE)
  broadcast_type = _create_before_broadcast_type_with_s1_type(S1_TYPE)
  packed = cfu.pack_next_comp_type_signature(next_comp_type, packed)
  packed = cfu.check_and_pack_before_broadcast_type_signature(
      broadcast_type, packed)
  aggregate_type = _create_before_aggregate_with_c2_type(C2_TYPE)
  packed = cfu.check_and_pack_before_aggregate_type_signature(
      aggregate_type, packed)
  after_aggregate_type = _create_after_aggregate_with_s3_type(S3_TYPE)
  result = cfu.check_and_pack_after_aggregate_type_signature(
      after_aggregate_type, packed)

  # s4 is expected to be the server-placed pair of the S1 and S3 members.
  self.assertEqual(
      result['s4_type'],
      fed([S1_TYPE.member, S3_TYPE.member], placements.SERVER))
  # c3 is expected to be the client-placed pair of the C1 and C2 members.
  self.assertEqual(
      result['c3_type'],
      fed([C1_TYPE.member, C2_TYPE.member], placements.CLIENTS))
  # update is checked against a function type built from the packed s4 and
  # s5 members themselves, i.e. internal consistency rather than fixed types.
  self.assertEqual(
      result['update_type'],
      computation_types.FunctionType(result['s4_type'].member,
                                     result['s5_type'].member))
def test_init_raises_non_federated_type(self):
  """A bare tf.float32 (not federated) is rejected as the initialize type.

  The packer must raise a TypeError whose message matches 'init'.
  """
  not_federated = tf.float32
  with self.assertRaisesRegex(TypeError, 'init'):
    canonical_form_utils.pack_initialize_comp_type_signature(not_federated)