Example #1
0
def _can_extract_intrinsics_to_top_level_lambda(comp, uri):
    """Returns whether the called intrinsics for `uri` can be hoisted out of `comp`.

    Hoisting is considered possible exactly when none of the called intrinsics
    in `comp` whose URI appears in `uri` closes over an intermediate variable.
    That is, each such call may reference no unbound names other than
    (possibly) the top-level parameter that `comp` itself declares.

    Args:
      comp: The `building_blocks.Lambda` to inspect. The names of lambda
        parameters and block variables in `comp` must be unique.
      uri: A Python `list` of intrinsic URI strings.

    Returns:
      `True` if the intrinsics can be extracted, otherwise `False`.
    """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(uri, list)
    for intrinsic_uri in uri:
        py_typecheck.check_type(intrinsic_uri, str)
    tree_analysis.check_has_unique_names(comp)

    called_intrinsics = _get_called_intrinsics(comp, uri)
    for called in called_intrinsics:
        # A free name other than the top-level parameter would no longer be in
        # scope once the call is moved up to the top-level lambda.
        if not tree_analysis.contains_no_unbound_references(
                called, comp.parameter_name):
            return False
    return True
Example #2
0
 def test_returns_false(self):
     """A lambda whose body references a name it does not bind is not closed."""
     free_ref = building_blocks.Reference('a', tf.int32)
     lam = building_blocks.Lambda('b', tf.int32, free_ref)
     is_closed = tree_analysis.contains_no_unbound_references(lam)
     self.assertFalse(is_closed)
Example #3
0
 def test_returns_true_with_excluded_reference(self):
     """The free name 'a' is ignored when it is passed as `excluding`."""
     free_ref = building_blocks.Reference('a', tf.int32)
     lam = building_blocks.Lambda('b', tf.int32, free_ref)
     is_closed = tree_analysis.contains_no_unbound_references(
         lam, excluding='a')
     self.assertTrue(is_closed)
Example #4
0
 def test_returns_true(self):
     """A lambda that binds the only name its body references is closed."""
     body = building_blocks.Reference('a', tf.int32)
     identity = building_blocks.Lambda(body.name, body.type_signature, body)
     is_closed = tree_analysis.contains_no_unbound_references(identity)
     self.assertTrue(is_closed)
Example #5
0
 def test_raises_type_error_with_int_excluding(self):
     """An integer passed as the excluded name is rejected with a TypeError."""
     body = building_blocks.Reference('a', tf.int32)
     identity = building_blocks.Lambda(body.name, body.type_signature, body)
     self.assertRaises(
         TypeError, tree_analysis.contains_no_unbound_references, identity, 1)
Example #6
0
 def test_raises_type_error_with_none_tree(self):
     """Passing `None` instead of a building block raises a TypeError."""
     self.assertRaises(
         TypeError, tree_analysis.contains_no_unbound_references, None)
Example #7
0
def _extract_intrinsics_to_top_level_lambda(comp, uri):
    r"""Extracts the called intrinsics in `comp` for the given `uri`.

    This transformation creates an AST such that all the called intrinsics for
    the given `uri` in the body of the `building_blocks.Block` returned by the
    top-level lambda have been extracted to the top-level lambda and replaced by
    selections from a reference to the constructed variable.

                         Lambda
                         |
                         Block
                        /     \
          [x=Struct, ...]       Comp
             |
             [Call,                  Call                   Call]
             /    \                 /    \                 /    \
    Intrinsic      Comp    Intrinsic      Comp    Intrinsic      Comp

    The order of the extracted called intrinsics follows the first occurrence of
    each intrinsic's URI in `uri`. Calls sharing a URI keep the order in which
    they were found. If exactly one called intrinsic is found, it is bound
    directly (not wrapped in a `Struct`), and the call is replaced by a plain
    reference rather than a selection.

    Note: if this function is passed an AST which contains nested called
    intrinsics, it will fail. It mutates the subcomputation containing the
    lower-level called intrinsics on the way back up the tree, so the enclosing
    call can no longer be located among the extracted intrinsics.

    Args:
      comp: The `building_blocks.Lambda` to transform. The names of lambda
        parameters and block variables in `comp` must be unique.
      uri: A Python `list` of intrinsic URI strings.

    Returns:
      A tuple `(comp, modified)`. If `comp` contains called intrinsics for
      `uri`, this is the transformed computation and `True`. Otherwise it is
      the original `comp` and `False`.

    Raises:
      ValueError: If any called intrinsic for `uri` in `comp` references a
        name that it does not bind itself, other than the parameter of `comp`.
    """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(uri, list)
    for x in uri:
        py_typecheck.check_type(x, str)
    tree_analysis.check_has_unique_names(comp)

    intrinsics = _get_called_intrinsics(comp, uri)
    if not intrinsics:
        # Nothing to extract: leave `comp` untouched instead of failing on
        # `intrinsics[0]` below.
        return comp, False

    for intrinsic in intrinsics:
        if not tree_analysis.contains_no_unbound_references(
                intrinsic, comp.parameter_name):
            raise ValueError(
                'Expected a computation which binds all the references in all the '
                'intrinsic with the uri: {}.'.format(uri))

    if len(intrinsics) > 1:
        # Map each URI to the index of its first occurrence in `uri`; `sorted`
        # is stable, so calls with the same URI keep their discovery order.
        order = {}
        for index, element in enumerate(uri):
            order.setdefault(element, index)
        intrinsics = sorted(intrinsics, key=lambda x: order[x.function.uri])
        extracted_comp = building_blocks.Struct(intrinsics)
    else:
        extracted_comp = intrinsics[0]

    name_generator = building_block_factory.unique_name_generator(comp)
    ref_name = next(name_generator)
    ref_type = computation_types.to_type(extracted_comp.type_signature)
    ref = building_blocks.Reference(ref_name, ref_type)

    def _transform(subcomp):
        """Replaces a matching called intrinsic with a reference to its value."""
        if not building_block_analysis.is_called_intrinsic(subcomp, uri):
            return subcomp, False
        if len(intrinsics) > 1:
            # Locate this call's slot in the extracted struct. Per the docstring
            # note, nested called intrinsics presumably break this lookup
            # (`list.index` then raises ValueError).
            index = intrinsics.index(subcomp)
            return building_blocks.Selection(ref, index=index), True
        return ref, True

    comp, _ = transformation_utils.transform_postorder(comp, _transform)
    comp = _insert_comp_in_top_level_lambda(
        comp, name=ref.name, comp_to_insert=extracted_comp)
    return comp, True