def _can_extract_intrinsics_to_top_level_lambda(comp, uri):
  """Checks whether the called intrinsics for `uri` can be hoisted out of `comp`.

  Hoisting is allowed exactly when none of the called intrinsics matching
  `uri` close over intermediate variables. In other words, every reference
  they contain is either bound inside the intrinsic call itself or is the
  top-level parameter declared by `comp`.

  Args:
    comp: The `building_blocks.Lambda` to inspect. The names of lambda
      parameters and block variables in `comp` must be unique.
    uri: A Python `list` of intrinsic URI strings.

  Returns:
    `True` if the intrinsics can be extracted, otherwise `False`.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(uri, list)
  for intrinsic_uri in uri:
    py_typecheck.check_type(intrinsic_uri, str)
  tree_analysis.check_has_unique_names(comp)
  called_intrinsics = _get_called_intrinsics(comp, uri)
  # Short-circuits on the first intrinsic that refers to anything other than
  # the top-level lambda parameter.
  for called_intrinsic in called_intrinsics:
    is_closed = tree_analysis.contains_no_unbound_references(
        called_intrinsic, comp.parameter_name)
    if not is_closed:
      return False
  return True
def test_returns_false(self):
  """A lambda whose body references a name it does not bind is not closed."""
  free_reference = building_blocks.Reference('a', tf.int32)
  lambda_with_free_ref = building_blocks.Lambda('b', tf.int32, free_reference)
  result = tree_analysis.contains_no_unbound_references(lambda_with_free_ref)
  self.assertFalse(result)
def test_returns_true_with_excluded_reference(self):
  """An unbound name is tolerated when it is passed as `excluding`."""
  free_reference = building_blocks.Reference('a', tf.int32)
  lambda_with_free_ref = building_blocks.Lambda('b', tf.int32, free_reference)
  result = tree_analysis.contains_no_unbound_references(
      lambda_with_free_ref, excluding='a')
  self.assertTrue(result)
def test_returns_true(self):
  """A lambda that binds the only name its body references is closed."""
  bound_reference = building_blocks.Reference('a', tf.int32)
  identity_lambda = building_blocks.Lambda(
      bound_reference.name, bound_reference.type_signature, bound_reference)
  self.assertTrue(
      tree_analysis.contains_no_unbound_references(identity_lambda))
def test_raises_type_error_with_int_excluding(self):
  """A non-string `excluding` argument is rejected with a `TypeError`."""
  bound_reference = building_blocks.Reference('a', tf.int32)
  identity_lambda = building_blocks.Lambda(
      bound_reference.name, bound_reference.type_signature, bound_reference)
  bad_excluding = 1
  with self.assertRaises(TypeError):
    tree_analysis.contains_no_unbound_references(identity_lambda, bad_excluding)
def test_raises_type_error_with_none_tree(self):
  """Passing `None` instead of a building block raises `TypeError`."""
  self.assertRaises(
      TypeError, tree_analysis.contains_no_unbound_references, None)
def _extract_intrinsics_to_top_level_lambda(comp, uri):
  r"""Extracts intrinsics in `comp` for the given `uri`.

  This transformation creates an AST such that all the called intrinsics for
  the given `uri` in the body of the `building_blocks.Block` returned by the
  top-level lambda have been extracted to the top-level lambda. Each one is
  replaced by a selection from a reference to the constructed variable (or by
  the reference itself when only one intrinsic is extracted).

                       Lambda
                         |
                       Block
                      /     \
        [x=Struct, ...]       Comp
            |
            [Call,                 Call,                  Call]
            /    \                 /    \                 /    \
   Intrinsic      Comp    Intrinsic      Comp    Intrinsic      Comp

  The order of the extracted called intrinsics matches the order of `uri`.

  Note: if this function is passed an AST which contains nested called
  intrinsics, it will fail, as it will mutate the subcomputation containing
  the lower-level called intrinsics on the way back up the tree.

  Args:
    comp: The `building_blocks.Lambda` to transform. The names of lambda
      parameters and block variables in `comp` must be unique.
    uri: A Python `list` of intrinsic URI strings.

  Returns:
    A tuple `(comp, modified)`. If `comp` contains called intrinsics for
    `uri`, `comp` is a new computation with the transformation applied and
    `modified` is `True`. Otherwise the original `comp` is returned with
    `modified` set to `False`.

  Raises:
    ValueError: If all the intrinsics for the given `uri` in `comp` are not
      exclusively bound by `comp`.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(uri, list)
  for x in uri:
    py_typecheck.check_type(x, str)
  tree_analysis.check_has_unique_names(comp)
  name_generator = building_block_factory.unique_name_generator(comp)

  intrinsics = _get_called_intrinsics(comp, uri)
  if not intrinsics:
    # Nothing to hoist: leave `comp` untouched instead of indexing into an
    # empty list below.
    return comp, False

  for intrinsic in intrinsics:
    if not tree_analysis.contains_no_unbound_references(
        intrinsic, comp.parameter_name):
      raise ValueError(
          'Expected a computation which binds all the references in all the '
          'intrinsic with the uri: {}.'.format(uri))

  if len(intrinsics) > 1:
    # Rank each URI by its first position in `uri`. `sorted` is stable, so
    # intrinsics sharing a URI keep the order `_get_called_intrinsics` gave.
    order = {}
    for index, element in enumerate(uri):
      order.setdefault(element, index)
    intrinsics = sorted(intrinsics, key=lambda x: order[x.function.uri])
    extracted_comp = building_blocks.Struct(intrinsics)
  else:
    extracted_comp = intrinsics[0]

  ref_name = next(name_generator)
  ref_type = computation_types.to_type(extracted_comp.type_signature)
  ref = building_blocks.Reference(ref_name, ref_type)

  def _transform(inner_comp):
    """Replaces a matching called intrinsic with a reference/selection."""
    if not building_block_analysis.is_called_intrinsic(inner_comp, uri):
      return inner_comp, False
    if len(intrinsics) > 1:
      # NOTE(review): relies on `intrinsics.index` finding the node handed to
      # us by `transform_postorder`; presumably true because nested called
      # intrinsics are unsupported (see docstring) — TODO confirm.
      selection = building_blocks.Selection(
          ref, index=intrinsics.index(inner_comp))
      return selection, True
    return ref, True

  comp, _ = transformation_utils.transform_postorder(comp, _transform)
  comp = _insert_comp_in_top_level_lambda(
      comp, name=ref.name, comp_to_insert=extracted_comp)
  return comp, True