def _prepare_for_rebinding(bb):
  """Returns a semantically equivalent version of `bb` suitable for rebinding.

  Three passes are applied in sequence:
    1. `transformations.normalize_all_equal_bit`
    2. `tree_transformations.remove_mapped_or_applied_identity`
    3. `compiler_transformations.prepare_for_rebinding`
  Passes 2 and 3 each return a pair; only the first element (the transformed
  building block) is kept and fed to the next pass.

  Args:
    bb: The building block to transform.

  Returns:
    The building block produced by the final pass.
  """
  comp = transformations.normalize_all_equal_bit(bb)
  comp, _ = tree_transformations.remove_mapped_or_applied_identity(comp)
  comp, _ = compiler_transformations.prepare_for_rebinding(comp)
  return comp
def test_converts_not_all_equal_at_server_reference_to_equal(self):
  """A SERVER reference with `all_equal=False` normalizes to `all_equal=True`.

  The result must still be a `Reference` that prints as 'x'.
  """
  server_not_all_equal = computation_types.FederatedType(
      tf.int32, placements.SERVER, all_equal=False)
  ref = building_blocks.Reference('x', server_not_all_equal)

  result = transformations.normalize_all_equal_bit(ref)

  expected_type = computation_types.FederatedType(
      tf.int32, placements.SERVER, all_equal=True)
  self.assertEqual(result.type_signature, expected_type)
  self.assertIsInstance(result, building_blocks.Reference)
  self.assertEqual(str(result), 'x')
def test_converts_all_equal_at_clients_lambda_parameter_to_not_equal(self):
  """A CLIENTS lambda parameter with `all_equal=True` loses the all-equal bit.

  Checks that the input lambda's type is unchanged after normalization, and
  that the normalized lambda keeps its `(x -> x)` shape while its parameter
  and result types become the CLIENTS type built without `all_equal`.
  """
  clients_all_equal = computation_types.FederatedType(
      tf.int32, placements.CLIENTS, all_equal=True)
  # Built without an `all_equal` argument; per the test name this is the
  # expected not-all-equal CLIENTS type.
  clients_expected = computation_types.FederatedType(
      tf.int32, placements.CLIENTS)
  identity = building_blocks.Lambda(
      'x', clients_all_equal,
      building_blocks.Reference('x', clients_all_equal))

  result = transformations.normalize_all_equal_bit(identity)

  # The input lambda's type must be unchanged after normalization.
  self.assertEqual(
      identity.type_signature,
      computation_types.FunctionType(clients_all_equal, clients_all_equal))
  self.assertIsInstance(result, building_blocks.Lambda)
  self.assertEqual(str(result), '(x -> x)')
  self.assertEqual(
      result.type_signature,
      computation_types.FunctionType(clients_expected, clients_expected))
def test_converts_federated_map_all_equal_to_federated_map(self):
  """A `federated_map_all_equal` call is rewritten to a `federated_map` call.

  The rewritten call's type signature must be the CLIENTS-placed int32 type
  built without `all_equal`.
  """
  clients_all_equal = computation_types.FederatedType(
      tf.int32, placements.CLIENTS, all_equal=True)
  clients_expected = computation_types.FederatedType(
      tf.int32, placements.CLIENTS)
  identity_fn = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Reference('x', tf.int32))
  arg = building_blocks.Reference('y', clients_all_equal)
  map_all_equal_call = building_block_factory.create_federated_map_all_equal(
      identity_fn, arg)

  result = mapreduce_transformations.normalize_all_equal_bit(
      map_all_equal_call)

  # Sanity-check the input: it really is a federated_map_all_equal call.
  self.assertEqual(map_all_equal_call.function.uri,
                   intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri)
  self.assertIsInstance(result, building_blocks.Call)
  self.assertIsInstance(result.function, building_blocks.Intrinsic)
  self.assertEqual(result.function.uri, intrinsic_defs.FEDERATED_MAP.uri)
  self.assertEqual(result.type_signature, clients_expected)
def test_raises_on_none(self):
  """Normalizing `None` instead of a building block raises `TypeError`."""
  self.assertRaises(
      TypeError, mapreduce_transformations.normalize_all_equal_bit, None)