def _graph_callable_internal(func, shape_and_dtypes):
  """Defines and returns a template version of func.

  Under the hood we make two function objects, each wrapping a different version
  of the graph-mode code. One version immediately runs variable initialization
  before making the variable's Tensors available for use, while the other
  version replaces the Variables with placeholders which become function
  arguments and get the current variable's value.

  Both versions are traced into one temporary graph, so `func` is called twice:
  once under the initializing scope and once under the capturing scope.

  Limitations in (2) and (4) are because this does not implement a graph-mode
  Variable class which has a convert_to_tensor(as_ref=True) method and a
  initialized_value method. This is fixable.
  (NOTE(review): the numbering presumably refers to the list of limitations in
  the public graph_callable decorator's documentation -- confirm.)

  Args:
    func: The tfe Python function to compile. The number of named positional
      parameters it declares (as reported by `tf_inspect.getargspec`) must equal
      the number of inputs built from `shape_and_dtypes`.
    shape_and_dtypes: A list of type ShapeAndDtype, passed to
      `_get_graph_callable_inputs` to build the placeholder inputs.

  Raises:
    TypeError: If the number of arguments accepted by `func` does not match the
      number of inputs built from `shape_and_dtypes`.
    ValueError: If any one of func's outputs is not a Tensor (this includes a
      `None` returned alongside Tensor outputs).

  Returns:
    An `_InitializingFunctionObject` wrapping the capturing function object and
    the initializer function.
  """
  container = tf_ops.get_default_graph()._container  # pylint: disable=protected-access
  container_prefix = tf_ops.get_default_graph()._container_prefix  # pylint: disable=protected-access
  with context.graph_mode():
    # This graph will store both the initialization and the call version of the
    # wrapped function. It will later be used by the backprop code to build the
    # backprop graph, if necessary.
    tmp_graph = tf_ops.Graph()
    # Inherit the container from the original graph to create resources at user
    # expected containers. Also inherits the container prefix, since this is
    # used for error checking when isolating Eager execution (the container
    # prefix at creation must match the container prefix when used, and
    # variables returned from the graph callable will be used in the outside
    # context).
    tmp_graph._container = container  # pylint: disable=protected-access
    tmp_graph._container_prefix = container_prefix  # pylint: disable=protected-access
    with tmp_graph.as_default():
      # Placeholders for the non-variable inputs.
      func_inputs = _get_graph_callable_inputs(shape_and_dtypes)
      func_num_args = len(tf_inspect.getargspec(func).args)
      if len(func_inputs) != func_num_args:
        raise TypeError(
            "The number of arguments accepted by the decorated "
            "function `%s` (%d) must match the number of "
            "ShapeAndDtype objects passed to the graph_callable() "
            "decorator (%d)." %
            (func.__name__, func_num_args, len(func_inputs)))

      # First call the function to generate a graph which can initialize all
      # variables. As a side-effect this will populate the variable capturing
      # scope's view of which variables exist.
      variable_captures = _VariableCapturingScope()
      captures = {}
      with variable_captures.initializing_scope(), function.capture_tensors(
          captures):
        func_outputs = func(*func_inputs)
      outputs_list = nest.flatten(func_outputs)
      if len(outputs_list) == 1 and outputs_list[0] is None:
        # A lone None (e.g. `func` returned nothing) means "no outputs".
        outputs_list = []
      # Validate before reading `.shape`: a non-Tensor output (such as a None
      # next to real Tensors) would otherwise raise AttributeError here instead
      # of the documented ValueError.
      if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
        raise ValueError("Found non-tensor output in %s" % str(outputs_list))
      output_shapes = [x.shape for x in outputs_list]
      # Every op in the graph at this point came from the initializing trace.
      initializing_operations = tmp_graph.get_operations()

      # Call the function again, now replacing usages of variables with
      # placeholders. This assumes the variable capturing scope created above
      # knows about all variables.
      with variable_captures.capturing_scope(), function.capture_tensors(
          captures):
        captured_outputs = func(*func_inputs)
      captured_outlist = nest.flatten(captured_outputs)
      # Ops appended after the first trace make up the capturing version.
      capturing_operations = tmp_graph.get_operations()[
          len(initializing_operations):]

  # Sort by name so the variable placeholders are ordered deterministically.
  sorted_variables = sorted(variable_captures.variables.values(),
                            key=lambda x: x.name)
  variable_placeholders = [x.placeholder for x in sorted_variables]

  # Each captures entry is an (external input, placeholder) pair; iterating the
  # sorted keys keeps both sequences aligned and in a stable order.
  ids = sorted(captures)
  if ids:
    extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
  else:
    extra_inputs = []
    extra_placeholders = []

  flat_inputs = [x for x in nest.flatten(func_inputs)
                 if isinstance(x, tf_ops.Tensor)]
  placeholder_inputs = flat_inputs + list(extra_placeholders)
  # The capturing version additionally receives the variable placeholders.
  all_inputs = variable_placeholders + placeholder_inputs

  func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)]
  initializer_function_def = function.make_function_def(
      tmp_graph, initializing_operations, placeholder_inputs, func_def_outputs)
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  # Also, what about the gradient registry of these functions? Those need to be
  # addressed as well.
  for f in tmp_graph._functions.values():  # pylint: disable=protected-access
    function._register_with_name(f.name, f.definition)  # pylint: disable=protected-access
  function._register_with_name(
      function._inference_name(func.__name__),  # pylint: disable=protected-access
      initializer_function_def)
  initializer_function = function._GraphModeFunction(  # pylint: disable=protected-access
      placeholder_inputs,
      extra_inputs,
      initializer_function_def,
      tmp_graph,
      initializing_operations,
      func_outputs,
      function._map_sequence_obj_to_idx(func_def_outputs),  # pylint: disable=protected-access
      output_shapes)

  capture_func_def_outputs = [
      x for x in captured_outlist if isinstance(x, tf_ops.Tensor)
  ]
  captured_function_def = function.make_function_def(
      tmp_graph, capturing_operations, all_inputs, capture_func_def_outputs)
  function._register_with_name(
      function._inference_name(func.__name__),  # pylint: disable=protected-access
      captured_function_def)
  captured_function = _FunctionObject(
      sorted_variables,
      all_inputs,
      extra_inputs,
      captured_function_def,
      tmp_graph,
      capturing_operations,
      captured_outputs,
      function._map_sequence_obj_to_idx(capture_func_def_outputs),  # pylint: disable=protected-access
      output_shapes)

  return _InitializingFunctionObject(captured_function, initializer_function,
                                     shape_and_dtypes)
def _graph_callable_internal(func, shape_and_dtypes):
  """Defines and returns a template version of func.

  Under the hood we make two function objects, each wrapping a different version
  of the graph-mode code. One version immediately runs variable initialization
  before making the variable's Tensors available for use, while the other
  version replaces the Variables with placeholders which become function
  arguments and get the current variable's value.

  Both versions are traced into one temporary CapturingGraph, so `func` is
  called twice: once under the initializing scope and once under the capturing
  scope.

  Limitations in (2) and (4) are because this does not implement a graph-mode
  Variable class which has a convert_to_tensor(as_ref=True) method and a
  initialized_value method. This is fixable.
  (NOTE(review): the numbering presumably refers to the list of limitations in
  the public graph_callable decorator's documentation -- confirm.)

  Args:
    func: The tfe Python function to compile. The number of named positional
      parameters it declares (as reported by `tf_inspect.getargspec`) must equal
      the number of inputs built from `shape_and_dtypes`.
    shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects.

  Raises:
    TypeError: If the number of arguments accepted by `func` does not match the
      number of inputs built from `shape_and_dtypes`.
    ValueError: If any one of func's outputs is not a Tensor (this includes a
      `None` returned alongside Tensor outputs).

  Returns:
    An `_InitializingFunctionObject` wrapping the capturing GraphModeFunction
    and the initializer GraphModeFunction.
  """
  container = tf_ops.get_default_graph()._container  # pylint: disable=protected-access
  graph_key = tf_ops.get_default_graph()._graph_key  # pylint: disable=protected-access
  with context.graph_mode():
    # This graph will store both the initialization and the call version of the
    # wrapped function. It will later be used by the backprop code to build the
    # backprop graph, if necessary.
    captures = {}
    tmp_graph = function.CapturingGraph(captures)
    # Inherit the container and the graph key from the original graph; the
    # graph key is inherited to ensure optimizers don't misbehave.
    tmp_graph._container = container  # pylint: disable=protected-access
    tmp_graph._graph_key = graph_key  # pylint: disable=protected-access
    with tmp_graph.as_default():
      # Placeholders for the non-variable inputs.
      func_inputs = _get_graph_callable_inputs(shape_and_dtypes)
      func_num_args = len(tf_inspect.getargspec(func).args)
      if len(func_inputs) != func_num_args:
        raise TypeError("The number of arguments accepted by the decorated "
                        "function `%s` (%d) must match the number of "
                        "ShapeAndDtype objects passed to the graph_callable() "
                        "decorator (%d)." %
                        (func.__name__, func_num_args, len(func_inputs)))

      # First call the function to generate a graph which can initialize all
      # variables. As a side-effect this will populate the variable capturing
      # scope's view of which variables exist.
      variable_captures = _VariableCapturingScope()
      with variable_captures.initializing_scope(), function.capture_tensors(
          captures), function.AutomaticControlDependencies() as a:
        func_outputs = func(*func_inputs)
        outputs_list = nest.flatten(func_outputs)
        # Marked while `a` is still the active scope (presumably required for
        # the marking to take effect -- confirm against
        # AutomaticControlDependencies).
        for i, x in enumerate(outputs_list):
          if x is not None:
            outputs_list[i] = a.mark_as_return(x)
      if len(outputs_list) == 1 and outputs_list[0] is None:
        # A lone None (e.g. `func` returned nothing) means "no outputs".
        outputs_list = []
      # Validate before reading `.shape`: a non-Tensor output (such as a None
      # next to real Tensors) would otherwise raise AttributeError here instead
      # of the documented ValueError.
      if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
        raise ValueError("Found non-tensor output in %s" % str(outputs_list))
      output_shapes = [x.shape for x in outputs_list]
      # Every op in the graph at this point came from the initializing trace.
      initializing_operations = tmp_graph.get_operations()

      # Call the function again, now replacing usages of variables with
      # placeholders. This assumes the variable capturing scope created above
      # knows about all variables.
      tmp_graph.clear_resource_control_flow_state()
      with variable_captures.capturing_scope(), function.capture_tensors(
          captures), function.AutomaticControlDependencies() as a:
        captured_outputs = func(*func_inputs)
        captured_outlist = nest.flatten(captured_outputs)
        for i, x in enumerate(captured_outlist):
          if x is not None:
            captured_outlist[i] = a.mark_as_return(x)
      # Ops appended after the first trace make up the capturing version.
      capturing_operations = tmp_graph.get_operations()[
          len(initializing_operations):]

  # Sort by name so the variables are passed in a deterministic order.
  sorted_variables = sorted(variable_captures.variables.values(),
                            key=lambda x: x.name)
  # Each captures entry is an (external input, placeholder) pair; iterating the
  # sorted keys keeps both sequences aligned and in a stable order.
  ids = sorted(captures)
  if ids:
    extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
  else:
    extra_inputs = []
    extra_placeholders = []

  flat_inputs = [x for x in nest.flatten(func_inputs)
                 if isinstance(x, tf_ops.Tensor)]
  placeholder_inputs = flat_inputs + list(extra_placeholders)

  func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)]
  initialization_name = function._inference_name(func.__name__)  # pylint: disable=protected-access
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  # Also, what about the gradient registry of these functions? Those need to be
  # addressed as well.
  for f in tmp_graph._functions.values():  # pylint: disable=protected-access
    function._register(f._c_func.func)  # pylint: disable=protected-access
  initializer_function = function.GraphModeFunction(
      initialization_name,
      placeholder_inputs,
      extra_inputs,
      tmp_graph,
      initializing_operations,
      func_def_outputs,
      func_outputs,
      output_shapes)

  capture_func_def_outputs = [
      x for x in captured_outlist if isinstance(x, tf_ops.Tensor)]
  captured_function_name = function._inference_name(func.__name__)  # pylint: disable=protected-access
  captured_function = function.GraphModeFunction(
      captured_function_name,
      placeholder_inputs,
      extra_inputs,
      tmp_graph,
      capturing_operations,
      capture_func_def_outputs,
      captured_outputs,
      output_shapes,
      variables=[x.variable for x in sorted_variables])

  return _InitializingFunctionObject(captured_function, initializer_function,
                                     shape_and_dtypes)
def _graph_callable_internal(func, shape_and_dtypes):
  """Defines and returns a template version of func.

  Under the hood we make two function objects, each wrapping a different version
  of the graph-mode code. One version immediately runs variable initialization
  before making the variable's Tensors available for use, while the other
  version replaces the Variables with placeholders which become function
  arguments and get the current variable's value.

  Both versions are traced into one temporary graph, so `func` is called twice:
  once under the initializing scope and once under the capturing scope.

  Limitations in (2) and (4) are because this does not implement a graph-mode
  Variable class which has a convert_to_tensor(as_ref=True) method and a
  initialized_value method. This is fixable.
  (NOTE(review): the numbering presumably refers to the list of limitations in
  the public graph_callable decorator's documentation -- confirm.)

  Args:
    func: The tfe Python function to compile. The number of named positional
      parameters it declares (as reported by `tf_inspect.getargspec`) must equal
      the number of inputs built from `shape_and_dtypes`.
    shape_and_dtypes: A list of type ShapeAndDtype, passed to
      `_get_graph_callable_inputs` to build the placeholder inputs.

  Raises:
    TypeError: If the number of arguments accepted by `func` does not match the
      number of inputs built from `shape_and_dtypes`.
    ValueError: If any one of func's outputs is not a Tensor (this includes a
      `None` returned alongside Tensor outputs).

  Returns:
    An `_InitializingFunctionObject` wrapping the capturing function object and
    the initializer function.
  """
  container = tf_ops.get_default_graph()._container  # pylint: disable=protected-access
  container_prefix = tf_ops.get_default_graph()._container_prefix  # pylint: disable=protected-access
  with context.graph_mode():
    # This graph will store both the initialization and the call version of the
    # wrapped function. It will later be used by the backprop code to build the
    # backprop graph, if necessary.
    tmp_graph = tf_ops.Graph()
    # Inherit the container from the original graph to create resources at user
    # expected containers. Also inherits the container prefix, since this is
    # used for error checking when isolating Eager execution (the container
    # prefix at creation must match the container prefix when used, and
    # variables returned from the graph callable will be used in the outside
    # context).
    tmp_graph._container = container  # pylint: disable=protected-access
    tmp_graph._container_prefix = container_prefix  # pylint: disable=protected-access
    with tmp_graph.as_default():
      # Placeholders for the non-variable inputs.
      func_inputs = _get_graph_callable_inputs(shape_and_dtypes)
      func_num_args = len(tf_inspect.getargspec(func).args)
      if len(func_inputs) != func_num_args:
        raise TypeError("The number of arguments accepted by the decorated "
                        "function `%s` (%d) must match the number of "
                        "ShapeAndDtype objects passed to the graph_callable() "
                        "decorator (%d)." %
                        (func.__name__, func_num_args, len(func_inputs)))

      # First call the function to generate a graph which can initialize all
      # variables. As a side-effect this will populate the variable capturing
      # scope's view of which variables exist.
      variable_captures = _VariableCapturingScope()
      captures = {}
      with variable_captures.initializing_scope(), function.capture_tensors(
          captures):
        func_outputs = func(*func_inputs)
      outputs_list = nest.flatten(func_outputs)
      if len(outputs_list) == 1 and outputs_list[0] is None:
        # A lone None (e.g. `func` returned nothing) means "no outputs".
        outputs_list = []
      # Validate before reading `.shape`: a non-Tensor output (such as a None
      # next to real Tensors) would otherwise raise AttributeError here instead
      # of the documented ValueError.
      if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
        raise ValueError("Found non-tensor output in %s" % str(outputs_list))
      output_shapes = [x.shape for x in outputs_list]
      # Every op in the graph at this point came from the initializing trace.
      initializing_operations = tmp_graph.get_operations()

      # Call the function again, now replacing usages of variables with
      # placeholders. This assumes the variable capturing scope created above
      # knows about all variables.
      with variable_captures.capturing_scope(), function.capture_tensors(
          captures):
        captured_outputs = func(*func_inputs)
      captured_outlist = nest.flatten(captured_outputs)
      # Ops appended after the first trace make up the capturing version.
      capturing_operations = tmp_graph.get_operations()[
          len(initializing_operations):]

  # Sort by name so the variable placeholders are ordered deterministically.
  sorted_variables = sorted(variable_captures.variables.values(),
                            key=lambda x: x.name)
  variable_placeholders = [x.placeholder for x in sorted_variables]

  # Each captures entry is an (external input, placeholder) pair; iterating the
  # sorted keys keeps both sequences aligned and in a stable order.
  ids = sorted(captures)
  if ids:
    extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
  else:
    extra_inputs = []
    extra_placeholders = []

  flat_inputs = [x for x in nest.flatten(func_inputs)
                 if isinstance(x, tf_ops.Tensor)]
  placeholder_inputs = flat_inputs + list(extra_placeholders)
  # The capturing version additionally receives the variable placeholders.
  all_inputs = variable_placeholders + placeholder_inputs

  func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)]
  initializer_function_def = function.make_function_def(
      tmp_graph, initializing_operations, placeholder_inputs, func_def_outputs)
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  # Also, what about the gradient registry of these functions? Those need to be
  # addressed as well.
  for f in tmp_graph._functions.values():  # pylint: disable=protected-access
    function._register_with_name(f.name, f.definition)  # pylint: disable=protected-access
  function._register_with_name(function._inference_name(func.__name__),  # pylint: disable=protected-access
                               initializer_function_def)
  initializer_function = function._GraphModeFunction(  # pylint: disable=protected-access
      placeholder_inputs,
      extra_inputs,
      initializer_function_def,
      tmp_graph,
      initializing_operations,
      func_outputs,
      function._map_sequence_obj_to_idx(func_def_outputs),  # pylint: disable=protected-access
      output_shapes)

  capture_func_def_outputs = [
      x for x in captured_outlist if isinstance(x, tf_ops.Tensor)]
  captured_function_def = function.make_function_def(
      tmp_graph, capturing_operations, all_inputs,
      capture_func_def_outputs)
  function._register_with_name(function._inference_name(func.__name__),  # pylint: disable=protected-access
                               captured_function_def)
  captured_function = _FunctionObject(
      sorted_variables,
      all_inputs,
      extra_inputs,
      captured_function_def,
      tmp_graph,
      capturing_operations,
      captured_outputs,
      function._map_sequence_obj_to_idx(capture_func_def_outputs),  # pylint: disable=protected-access
      output_shapes)

  return _InitializingFunctionObject(captured_function, initializer_function,
                                     shape_and_dtypes)