def _prepare_for_rebinding(bb):
  """Returns a semantically equivalent rewrite of `bb` ready for rebinding.

  `bb` is passed through, in order: `compiler.normalize_all_equal_bit`,
  `tree_transformations.remove_mapped_or_applied_identity`,
  `transformations.to_call_dominant`, and
  `tree_transformations.remove_unused_block_locals`. The "modified" flags
  returned by the two tree transformations are discarded.
  """
  normalized = compiler.normalize_all_equal_bit(bb)
  identity_free, _ = tree_transformations.remove_mapped_or_applied_identity(
      normalized)
  call_dominant = transformations.to_call_dominant(identity_free)
  pruned, _ = tree_transformations.remove_unused_block_locals(call_dominant)
  return pruned
def _prepare_for_rebinding(bb):
  """Returns a semantically equivalent rewrite of `bb` ready for rebinding.

  Applies `transformations.normalize_all_equal_bit`, then
  `tree_transformations.remove_mapped_or_applied_identity`, then
  `compiler_transformations.prepare_for_rebinding`. The "modified" flags
  returned by the latter two are discarded.
  """
  result = transformations.normalize_all_equal_bit(bb)
  result, _ = tree_transformations.remove_mapped_or_applied_identity(result)
  result, _ = compiler_transformations.prepare_for_rebinding(result)
  return result
def test_removes_intrinsic(self, uri, factory):
  # Per the expected representation below, `factory` yields the intrinsic
  # named by `uri` applied to `<(a -> a),data>`. Removing the identity
  # function should leave just `data`.
  comp = factory(parameter_name='a')

  result, changed = (
      tree_transformations.remove_mapped_or_applied_identity(comp))

  self.assertEqual(comp.compact_representation(), f'{uri}(<(a -> a),data>)')
  self.assertEqual(result.compact_representation(), 'data')
  self.assertEqual(result.type_signature, comp.type_signature)
  self.assertTrue(changed)
def test_does_not_remove_whimsy_intrinsic(self):
  # An intrinsic called directly on a reference (`intrinsic(a)`, with no
  # mapped/applied function) must come back unchanged and unmodified.
  comp = building_block_test_utils.create_whimsy_called_intrinsic(
      parameter_name='a')

  result, changed = (
      tree_transformations.remove_mapped_or_applied_identity(comp))

  result_repr = result.compact_representation()
  self.assertEqual(result_repr, comp.compact_representation())
  self.assertEqual(result_repr, 'intrinsic(a)')
  self.assertEqual(result.type_signature, comp.type_signature)
  self.assertFalse(changed)
def test_does_not_remove_called_lambda(self):
  # Calling the identity lambda directly, `(a -> a)(data)`, is not a
  # mapped/applied identity, so the computation must be left as-is.
  identity_fn = building_block_test_utils.create_identity_function(
      'a', tf.int32)
  data = building_blocks.Data('data', tf.int32)
  comp = building_blocks.Call(identity_fn, data)

  result, changed = (
      tree_transformations.remove_mapped_or_applied_identity(comp))

  result_repr = result.compact_representation()
  self.assertEqual(result_repr, comp.compact_representation())
  self.assertEqual(result_repr, '(a -> a)(data)')
  self.assertEqual(result.type_signature, comp.type_signature)
  self.assertFalse(changed)
def test_removes_nested_federated_map(self):
  # An identity federated_map nested as the result of a block must be
  # collapsed while the surrounding `let` binding is preserved.
  federated_map_call = (
      building_block_test_utils.create_whimsy_called_federated_map(
          parameter_name='a'))
  comp = building_block_test_utils.create_whimsy_block(
      federated_map_call, variable_name='b')

  result, changed = (
      tree_transformations.remove_mapped_or_applied_identity(comp))

  self.assertEqual(comp.compact_representation(),
                   '(let b=data in federated_map(<(a -> a),data>))')
  self.assertEqual(result.compact_representation(), '(let b=data in data)')
  self.assertEqual(result.type_signature, comp.type_signature)
  self.assertTrue(changed)
def test_removes_chained_federated_maps(self):
  # Two stacked identity federated_maps over CLIENTS-placed data should both
  # be removed, leaving only the underlying data.
  identity_fn = building_block_test_utils.create_identity_function(
      'a', tf.int32)
  clients_int_type = computation_types.FederatedType(
      tf.int32, placements.CLIENTS)
  data = building_blocks.Data('data', clients_int_type)
  comp = _create_chained_whimsy_federated_maps([identity_fn, identity_fn],
                                               data)

  result, changed = (
      tree_transformations.remove_mapped_or_applied_identity(comp))

  self.assertEqual(
      comp.compact_representation(),
      'federated_map(<(a -> a),federated_map(<(a -> a),data>)>)')
  self.assertEqual(result.compact_representation(), 'data')
  self.assertEqual(result.type_signature, comp.type_signature)
  self.assertTrue(changed)
def test_removes_federated_map_with_named_result(self):
  # The identity is over a named-field parameter (`a`, `b`). The named
  # result type must not prevent the federated_map from being removed.
  named_param_type = [('a', tf.int32), ('b', tf.int32)]
  identity_fn = building_block_test_utils.create_identity_function(
      'c', named_param_type)
  clients_type = computation_types.FederatedType(named_param_type,
                                                 placements.CLIENTS)
  data = building_blocks.Data('data', clients_type)
  comp = building_block_factory.create_federated_map(identity_fn, data)

  result, changed = (
      tree_transformations.remove_mapped_or_applied_identity(comp))

  self.assertEqual(comp.compact_representation(),
                   'federated_map(<(c -> c),data>)')
  self.assertEqual(result.compact_representation(), 'data')
  self.assertEqual(result.type_signature, comp.type_signature)
  self.assertTrue(changed)
def transformation_fn(x):
  """Runs three tree transformations on `x`, discarding "modified" flags.

  The order is `remove_mapped_or_applied_identity`, `inline_block_locals`,
  then `replace_selection_from_tuple_with_element`.
  """
  passes = (
      tree_transformations.remove_mapped_or_applied_identity,
      tree_transformations.inline_block_locals,
      tree_transformations.replace_selection_from_tuple_with_element,
  )
  for tree_pass in passes:
    x, _ = tree_pass(x)
  return x
def transformation_fn(x):
  """Returns `x` after `remove_mapped_or_applied_identity`, dropping the flag."""
  transformed, _modified = (
      tree_transformations.remove_mapped_or_applied_identity(x))
  return transformed
def transformation_fn(x):
  """Runs three tree transformations on `x`, discarding "modified" flags.

  The order is `uniquify_reference_names`, `inline_block_locals`, then
  `remove_mapped_or_applied_identity`.
  """
  passes = (
      tree_transformations.uniquify_reference_names,
      tree_transformations.inline_block_locals,
      tree_transformations.remove_mapped_or_applied_identity,
  )
  for tree_pass in passes:
    x, _ = tree_pass(x)
  return x
def transformation_fn(x):
  """Removes mapped/applied identities from `x`, then makes it call-dominant."""
  without_identity, _ = (
      tree_transformations.remove_mapped_or_applied_identity(x))
  call_dominant = transformations.to_call_dominant(without_identity)
  return call_dominant
def transformation_fn(x):
  """Uniquifies names and removes identities in `x`, then makes it call-dominant.

  The "modified" flags from the two tree transformations are discarded.
  """
  uniquified, _ = tree_transformations.uniquify_reference_names(x)
  without_identity, _ = (
      tree_transformations.remove_mapped_or_applied_identity(uniquified))
  return transformations.to_call_dominant(without_identity)
def test_raises_type_error(self):
  # Passing `None` instead of a building block must be rejected with TypeError.
  self.assertRaises(
      TypeError, tree_transformations.remove_mapped_or_applied_identity, None)