Пример #1
0
 def _generate_simple_tensorflow(comp):
     """Naively translates `comp` into TensorFlow.

     Called TF identities are first inserted at the leaves of `comp`. A fresh
     `TFParser` is then run over the tree in post-order. Only the rewritten
     computation is returned; the second element produced by each of the two
     passes is discarded.
     """
     naive_parser = tree_to_cc_transformations.TFParser()
     leaves_wrapped, _ = (
         tree_transformations.insert_called_tf_identity_at_leaves(comp))
     result, _ = transformation_utils.transform_postorder(
         leaves_wrapped, naive_parser)
     return result
Пример #2
0
def _generate_simple_tensorflow(comp):
    """Builds a naive TensorFlow representation of `comp`.

    Args:
        comp: The computation to translate.

    Returns:
        `comp` after a called TensorFlow identity has been inserted at each of
        its leaves and a `TFParser` has been applied to it in post-order.
    """
    parser = tree_to_cc_transformations.TFParser()
    identity_wrapped, _ = (
        tree_transformations.insert_called_tf_identity_at_leaves(comp))
    generated, _ = transformation_utils.transform_postorder(
        identity_wrapped, parser)
    return generated
def parse_tff_to_tf(comp):
  """Parses the TFF construct `comp` into TensorFlow.

  Called TF identities are inserted at the leaves of `comp`. The tree is then
  preprocessed with `replace_called_lambda_with_block`, `inline_block_locals`
  and `replace_selection_from_tuple_with_element`, in that order. Finally, a
  `TFParser` is applied in post-order.

  Args:
    comp: The computation to parse.

  Returns:
    The `(new_comp, transformed)` pair produced by the final
    `transformation_utils.transform_postorder` call. `transformed` reflects
    only that final parsing pass, not the preprocessing passes.
  """
  comp, _ = tree_transformations.insert_called_tf_identity_at_leaves(comp)
  parser = tree_to_cc_transformations.TFParser()
  preprocessing_passes = (
      tree_transformations.replace_called_lambda_with_block,
      tree_transformations.inline_block_locals,
      tree_transformations.replace_selection_from_tuple_with_element,
  )
  for preprocess in preprocessing_passes:
    comp, _ = preprocess(comp)
  parsed, modified = transformation_utils.transform_postorder(comp, parser)
  return parsed, modified
Пример #4
0
 def transform(self, local_function):
   """Attempts to compile `local_function` to TensorFlow.

   Args:
     local_function: The computation to transform.

   Returns:
     A `(comp, modified)` pair. When `self.should_transform(local_function)`
     is falsy, this is `(local_function, False)`. Otherwise `modified` is
     `True`, and `comp` is one of two things. If the deduplicated computation
     already is a compiled computation or a call to one, `comp` is that
     computation. Otherwise `comp` is the result of inserting called TF
     identities at its leaves and running `self._naive_tf_parser` over it in
     post-order.
   """
   if not self.should_transform(local_function):
     return local_function, False

   without_refs, _ = remove_called_lambdas_and_blocks(local_function)
   deduplicated, _ = remove_duplicate_called_graphs(without_refs)

   is_compiled = deduplicated.is_compiled_computation()
   if not is_compiled and deduplicated.is_call():
     is_compiled = deduplicated.function.is_compiled_computation()
   if is_compiled:
     return deduplicated, True

   # TODO(b/146430051): We should only end up in this case if
   # `remove_called_lambdas_and_blocks` above is in its failure mode, i.e.,
   # failing to resolve references due to too-deep indirection; we should
   # remove this extra case and simply raise if we fail here when we fix the
   # attached bug.
   with_identities, _ = (
       tree_transformations.insert_called_tf_identity_at_leaves(deduplicated))
   naively_compiled, _ = transformation_utils.transform_postorder(
       with_identities, self._naive_tf_parser)
   return naively_compiled, True
Пример #5
0
def remove_duplicate_called_graphs(comp):
  """Deduplicates called graphs for a subset of TFF AST constructs.

  The function looks at every call whose function is a
  `building_blocks.CompiledComputation`. It searches `comp` itself, or
  `comp.result` when `comp` is a lambda. Each such call is bound once as a
  local of a `building_blocks.Block`, and calls that are equal according to
  `tree_analysis.trees_equal` share a single binding. The resulting block is
  then compiled with `create_tensorflow_representing_block`.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` whose called
      graphs we wish to deduplicate, according to `tree_analysis.trees_equal`.
      For `comp` to be eligible here, it must be either a lambda itself whose
      body contains no lambdas or blocks, or another computation containing no
      lambdas or blocks. This restriction is necessary because
      `remove_duplicate_called_graphs` makes no effort to ensure that it is not
      pulling references out of their defining scope, except for the case where
      `comp` is a lambda itself. This function exits early and logs a warning if
      this assumption is violated. Additionally, `comp` must contain only
      computations which can be represented in TensorFlow, i.e., satisfy the
      type restriction in `type_analysis.is_tensorflow_compatible_type`.

  Returns:
    A `(result, modified)` pair, matching the
    `transformation_utils.TransformSpec` pattern.

    - Early exit: if the checked tree still contains a `building_blocks.Block`
      or `building_blocks.Lambda`, a warning is logged and `(comp, False)` is
      returned unchanged. The checked tree is the body when `comp` is a
      lambda, and `comp` itself otherwise.
    - Lambda `comp`: `modified` is `True`, and `result` is the output of
      `TFParser` run over the lambda wrapping the compiled body. This is
      intended to be a `building_blocks.CompiledComputation`.
    - Any other `comp`: `modified` is `True`, and `result` is the output of
      `create_tensorflow_representing_block`. This is intended to be a called
      `building_blocks.CompiledComputation`.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  tree_analysis.check_has_unique_names(comp)
  name_generator = building_block_factory.unique_name_generator(comp)
  # For a lambda only the body is rewritten, so references to the lambda
  # parameter stay inside the lambda that binds them.
  is_lambda = comp.is_lambda()
  comp_to_dedupe = comp.result if is_lambda else comp
  if tree_analysis.contains_types(comp_to_dedupe, (
      building_blocks.Block,
      building_blocks.Lambda,
  )):
    logging.warning(
        'The preprocessors have failed to remove called lambdas '
        'and blocks; falling back to less efficient, but '
        'guaranteed, TensorFlow generation with computation %s.', comp)
    return comp, False

  # `(name, called_graph)` pairs in first-seen order; these become the locals
  # of the block built below.
  leaf_called_graphs = []

  def _pack_called_graphs_into_block(inner_comp):
    """Replaces a called graph with a reference to its deduplicated binding."""
    if not (inner_comp.is_call() and
            inner_comp.function.is_compiled_computation()):
      return inner_comp, False
    # Reuse the binding of a structurally equal called graph seen earlier.
    name = next((bound_name for bound_name, seen in leaf_called_graphs
                 if tree_analysis.trees_equal(seen, inner_comp)), None)
    if name is None:
      name = next(name_generator)
      leaf_called_graphs.append((name, inner_comp))
    return building_blocks.Reference(name, inner_comp.type_signature), True

  transformed_result, _ = transformation_utils.transform_postorder(
      comp_to_dedupe, _pack_called_graphs_into_block)
  packed_into_block = building_blocks.Block(leaf_called_graphs,
                                            transformed_result)
  parsed, _ = create_tensorflow_representing_block(packed_into_block)
  if not is_lambda:
    return parsed, True

  # Re-wrap the compiled body in the original lambda signature and compile the
  # lambda itself.
  tff_func = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                    parsed)
  tf_parser_callable = tree_to_cc_transformations.TFParser()
  with_identities, _ = (
      tree_transformations.insert_called_tf_identity_at_leaves(tff_func))
  tf_generated, _ = transformation_utils.transform_postorder(
      with_identities, tf_parser_callable)
  return tf_generated, True