示例#1
0
def _graph_callable_internal(func, shape_and_dtypes):
  """Defines and returns a template version of func.

  Under the hood we make two function objects, each wrapping a different version
  of the graph-mode code. One version immediately runs variable initialization
  before making the variable's Tensors available for use, while the other
  version replaces the Variables with placeholders which become function
  arguments and get the current variable's value.

  Limitations (2) and (4) of graph_callable() (presumably enumerated in the
  decorator's public documentation, which is not part of this function) exist
  because this does not implement a graph-mode Variable class which has a
  convert_to_tensor(as_ref=True) method and an initialized_value method. This
  is fixable.

  Args:
    func: The tfe Python function to compile. The number of named arguments
      in its signature must equal the number of inputs built from
      shape_and_dtypes.
    shape_and_dtypes: A list of type ShapeAndDtype.

  Raises:
    TypeError: If the number of arguments accepted by func does not match the
      number of inputs built from shape_and_dtypes.
    ValueError: If any one of func's outputs is not a Tensor.

  Returns:
    An _InitializingFunctionObject wrapping both the variable-capturing
    function object and the variable-initializing function object.
  """
  container = tf_ops.get_default_graph()._container  # pylint: disable=protected-access
  container_prefix = tf_ops.get_default_graph()._container_prefix  # pylint: disable=protected-access
  with context.graph_mode():
    # This graph will store both the initialization and the call version of the
    # wrapped function. It will later be used by the backprop code to build the
    # backprop graph, if necessary.
    captures = {}
    tmp_graph = function.CapturingGraph(captures)
    # Inherit the container from the original graph to create resources at user
    # expected containers. Also inherits the container prefix, since this is
    # used for error checking when isolating Eager execution (the container
    # prefix at creation must match the container prefix when used, and
    # variables returned from the graph callable will be used in the outside
    # context).
    tmp_graph._container = container  # pylint: disable=protected-access
    tmp_graph._container_prefix = container_prefix  # pylint: disable=protected-access
    with tmp_graph.as_default():
      # Placeholders for the non-variable inputs.
      func_inputs = _get_graph_callable_inputs(shape_and_dtypes)
      func_num_args = len(tf_inspect.getargspec(func).args)
      if len(func_inputs) != func_num_args:
        raise TypeError("The number of arguments accepted by the decorated "
                        "function `%s` (%d) must match the number of "
                        "ShapeAndDtype objects passed to the graph_callable() "
                        "decorator (%d)." %
                        (func.__name__, func_num_args, len(func_inputs)))

      # First call the function to generate a graph which can initialize all
      # variables. As a side-effect this will populate the variable capturing
      # scope's view of which variables exist.
      variable_captures = _VariableCapturingScope()
      with variable_captures.initializing_scope(), function.capture_tensors(
          captures):
        func_outputs = func(*func_inputs)
      outputs_list = nest.flatten(func_outputs)
      if len(outputs_list) == 1 and outputs_list[0] is None:
        # A function that returns None produces no outputs.
        outputs_list = []
      # Validate before reading `.shape`: a non-Tensor output (e.g. a Python
      # number) may have no `shape` attribute, which would otherwise surface
      # as an AttributeError instead of the documented ValueError.
      if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
        raise ValueError("Found non-tensor output in %s" % str(outputs_list))
      output_shapes = [x.shape for x in outputs_list]
      # Snapshot the ops created so far so that the ops added by the second
      # call below can be isolated by slicing.
      initializing_operations = tmp_graph.get_operations()

      # Call the function again, now replacing usages of variables with
      # placeholders. This assumes the variable capturing scope created above
      # knows about all variables.
      with variable_captures.capturing_scope(), function.capture_tensors(
          captures):
        captured_outputs = func(*func_inputs)
      captured_outlist = nest.flatten(captured_outputs)
      capturing_operations = tmp_graph.get_operations()[
          len(initializing_operations):]

  # Sort by name so the variable order handed to _FunctionObject is
  # deterministic.
  sorted_variables = sorted(variable_captures.variables.values(),
                            key=lambda x: x.name)
  # Each `captures` entry is unpacked as an (extra input, placeholder) pair;
  # sorting the keys fixes the order of the extra inputs.
  ids = sorted(captures)
  if ids:
    extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
  else:
    extra_inputs = []
    extra_placeholders = []

  flat_inputs = [x for x in nest.flatten(func_inputs)
                 if isinstance(x, tf_ops.Tensor)]
  # The user-facing inputs come first, followed by the captured placeholders.
  placeholder_inputs = flat_inputs + list(extra_placeholders)

  func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)]
  initializer_function_def = function.make_function_def(
      tmp_graph,
      initializing_operations,
      placeholder_inputs,
      func_def_outputs)
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  # Also, what about the gradient registry of these functions? Those need to be
  # addressed as well.
  for f in tmp_graph._functions.values():  # pylint: disable=protected-access
    function._register_with_name(f.name, f.definition)  # pylint: disable=protected-access
  function._register_with_name(function._inference_name(func.__name__),  # pylint: disable=protected-access
                               initializer_function_def)
  initializer_function = function._GraphModeFunction(  # pylint: disable=protected-access
      placeholder_inputs,
      extra_inputs,
      initializer_function_def,
      tmp_graph,
      initializing_operations,
      func_outputs,
      function._map_sequence_obj_to_idx(func_def_outputs),  # pylint: disable=protected-access
      output_shapes)

  capture_func_def_outputs = [
      x for x in captured_outlist if isinstance(x, tf_ops.Tensor)]
  captured_function_def = function.make_function_def(
      tmp_graph,
      capturing_operations,
      placeholder_inputs,
      capture_func_def_outputs)
  function._register_with_name(function._inference_name(func.__name__),  # pylint: disable=protected-access
                               captured_function_def)
  captured_function = _FunctionObject(
      sorted_variables,
      placeholder_inputs,
      extra_inputs,
      captured_function_def,
      tmp_graph,
      capturing_operations,
      captured_outputs,
      function._map_sequence_obj_to_idx(capture_func_def_outputs),  # pylint: disable=protected-access
      output_shapes)

  return _InitializingFunctionObject(captured_function, initializer_function,
                                     shape_and_dtypes)
示例#2
0
def _graph_callable_internal(func, shape_and_dtypes):
    """Defines and returns a template version of func.

    Under the hood we make two function objects, each wrapping a different
    version of the graph-mode code. One version immediately runs variable
    initialization before making the variable's Tensors available for use,
    while the other version replaces the Variables with placeholders which
    become function arguments and get the current variable's value.

    Limitations (2) and (4) of graph_callable() (presumably enumerated in the
    decorator's public documentation, which is not part of this function)
    exist because this does not implement a graph-mode Variable class which
    has a convert_to_tensor(as_ref=True) method and an initialized_value
    method. This is fixable.

    Args:
      func: The tfe Python function to compile. The number of named arguments
        in its signature must equal the number of top-level inputs built from
        shape_and_dtypes.
      shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype
        objects.

    Raises:
      TypeError: If the number of arguments accepted by func does not match
        the number of inputs built from shape_and_dtypes.
      ValueError: If any one of func's outputs is not a Tensor.

    Returns:
      An _InitializingFunctionObject wrapping both the variable-capturing
      function object and the variable-initializing function object.
    """
    container = tf_ops.get_default_graph()._container  # pylint: disable=protected-access
    container_prefix = tf_ops.get_default_graph()._container_prefix  # pylint: disable=protected-access
    with context.graph_mode():
        # This graph will store both the initialization and the call version of the
        # wrapped function. It will later be used by the backprop code to build the
        # backprop graph, if necessary.
        captures = {}
        tmp_graph = function.CapturingGraph(captures)
        # Inherit the container from the original graph to create resources at user
        # expected containers. Also inherits the container prefix, since this is
        # used for error checking when isolating Eager execution (the container
        # prefix at creation must match the container prefix when used, and
        # variables returned from the graph callable will be used in the outside
        # context).
        tmp_graph._container = container  # pylint: disable=protected-access
        tmp_graph._container_prefix = container_prefix  # pylint: disable=protected-access
        with tmp_graph.as_default():
            # Placeholders for the non-variable inputs.
            func_inputs = _get_graph_callable_inputs(shape_and_dtypes)
            func_num_args = len(tf_inspect.getargspec(func).args)
            if len(func_inputs) != func_num_args:
                raise TypeError(
                    "The number of arguments accepted by the decorated "
                    "function `%s` (%d) must match the number of "
                    "ShapeAndDtype objects passed to the graph_callable() "
                    "decorator (%d)." %
                    (func.__name__, func_num_args, len(func_inputs)))

            # First call the function to generate a graph which can initialize all
            # variables. As a side-effect this will populate the variable capturing
            # scope's view of which variables exist.
            variable_captures = _VariableCapturingScope()
            with variable_captures.initializing_scope(
            ), function.capture_tensors(captures):
                func_outputs = func(*func_inputs)
            outputs_list = nest.flatten(func_outputs)
            if len(outputs_list) == 1 and outputs_list[0] is None:
                # A function that returns None produces no outputs.
                outputs_list = []
            # Validate before reading `.shape`: a non-Tensor output (e.g. a
            # Python number) may have no `shape` attribute, which would
            # otherwise surface as an AttributeError instead of the
            # documented ValueError.
            if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
                raise ValueError("Found non-tensor output in %s" %
                                 str(outputs_list))
            output_shapes = [x.shape for x in outputs_list]
            # Snapshot the ops created so far so that the ops added by the
            # second call below can be isolated by slicing.
            initializing_operations = tmp_graph.get_operations()

            # Call the function again, now replacing usages of variables with
            # placeholders. This assumes the variable capturing scope created above
            # knows about all variables.
            with variable_captures.capturing_scope(), function.capture_tensors(
                    captures):
                captured_outputs = func(*func_inputs)
            captured_outlist = nest.flatten(captured_outputs)
            capturing_operations = tmp_graph.get_operations(
            )[len(initializing_operations):]

    # Sort by name so the variable order handed to _FunctionObject is
    # deterministic.
    sorted_variables = sorted(variable_captures.variables.values(),
                              key=lambda x: x.name)
    # Each `captures` entry is unpacked as an (extra input, placeholder) pair;
    # sorting the keys fixes the order of the extra inputs.
    ids = sorted(captures)
    if ids:
        extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
    else:
        extra_inputs = []
        extra_placeholders = []

    flat_inputs = [
        x for x in nest.flatten(func_inputs) if isinstance(x, tf_ops.Tensor)
    ]
    # The user-facing inputs come first, followed by the captured placeholders.
    placeholder_inputs = flat_inputs + list(extra_placeholders)

    func_def_outputs = [
        x for x in outputs_list if isinstance(x, tf_ops.Tensor)
    ]
    initializer_function_def = function.make_function_def(
        tmp_graph, initializing_operations, placeholder_inputs,
        func_def_outputs)
    # TODO(ashankar): Oh lord, forgive me for this lint travesty.
    # Also, what about the gradient registry of these functions? Those need to be
    # addressed as well.
    for f in tmp_graph._functions.values():  # pylint: disable=protected-access
        function._register_with_name(f.name, f.definition)  # pylint: disable=protected-access
    function._register_with_name(
        function._inference_name(func.__name__),  # pylint: disable=protected-access
        initializer_function_def)
    initializer_function = function._GraphModeFunction(  # pylint: disable=protected-access
        placeholder_inputs,
        extra_inputs,
        initializer_function_def,
        tmp_graph,
        initializing_operations,
        func_outputs,
        function._map_sequence_obj_to_idx(func_def_outputs),  # pylint: disable=protected-access
        output_shapes)

    capture_func_def_outputs = [
        x for x in captured_outlist if isinstance(x, tf_ops.Tensor)
    ]
    captured_function_def = function.make_function_def(
        tmp_graph, capturing_operations, placeholder_inputs,
        capture_func_def_outputs)
    function._register_with_name(
        function._inference_name(func.__name__),  # pylint: disable=protected-access
        captured_function_def)
    captured_function = _FunctionObject(
        sorted_variables,
        placeholder_inputs,
        extra_inputs,
        captured_function_def,
        tmp_graph,
        capturing_operations,
        captured_outputs,
        function._map_sequence_obj_to_idx(capture_func_def_outputs),  # pylint: disable=protected-access
        output_shapes)

    return _InitializingFunctionObject(captured_function, initializer_function,
                                       shape_and_dtypes)