Esempio n. 1
0
        def _make_callable(self, feed_arrays, feed_symbols, symbol_vals,
                           session):
            """
            Generates a callable that runs the graph.

            Builds a `CallableOptions` proto describing feeds, tensor
            connections, fetches, the update target and run options, asks
            `session` for a callable built from it, and caches that callable
            together with the parameters it was built for.

            Arguments:
              feed_arrays: List of input tensors to be fed Numpy arrays at runtime.
              feed_symbols: List of input tensors to be fed symbolic tensors at runtime.
              symbol_vals: List of symbolic tensors to be fed to `feed_symbols`,
                paired positionally with them (`zip` ignores any surplus items
                in the longer list).
              session: Session to use to generate the callable.

            Returns:
              Function that runs the graph according to the above options. The
              same object is also stored in `self._callable_fn`.
            """
            # Prepare callable options.
            callable_opts = config_pb2.CallableOptions()
            # Handle external-data feed: positional array feeds first, then the
            # keys of the static `feed_dict`.
            # NOTE(review): the feed order presumably has to match the order in
            # which the caller supplies values, and `sorted()` compares the keys
            # directly. If the keys are Tensors with an overloaded `<`, sorting
            # 2+ keys may fail. Confirm this, and keep the ordering in sync with
            # the caller, before changing it.
            for x in feed_arrays:
                callable_opts.feed.append(x.name)
            if self.feed_dict:
                for key in sorted(self.feed_dict.keys()):
                    callable_opts.feed.append(key.name)
            # Handle symbolic feed: connect each symbolic value to its
            # placeholder, casting first when the dtypes differ.
            for x, y in zip(feed_symbols, symbol_vals):
                connection = callable_opts.tensor_connection.add()
                if x.dtype != y.dtype:
                    y = math_ops.cast(y, x.dtype)
                from_tensor = ops._as_graph_element(y)
                if from_tensor is None:
                    # Not convertible to a graph element; use `y` as-is.
                    from_tensor = y
                connection.from_tensor = from_tensor.name  # Data tensor
                connection.to_tensor = x.name  # Placeholder
            # Handle fetches.
            for x in self.outputs + self.fetches:
                callable_opts.fetch.append(x.name)
            # Handle updates.
            if self.updates_op:
                callable_opts.target.append(self.updates_op.name)
            # Handle run_options.
            if self.run_options:
                callable_opts.run_options.CopyFrom(self.run_options)
            # Create callable.
            callable_fn = session._make_callable_from_options(callable_opts)
            # Cache parameters corresponding to the generated callable, so that
            # we can detect future mismatches and refresh the callable.
            self._callable_fn = callable_fn
            self._feed_arrays = feed_arrays
            self._feed_symbols = feed_symbols
            self._symbol_vals = symbol_vals
            self._fetches = list(self.fetches)
            self._session = session
            # Honor the documented contract. Callers that ignore the result
            # are unaffected.
            return callable_fn
Esempio n. 2
0
    def control_dependencies(self, control_inputs):
        """Handles control dependencies, diverting external ones to captures.

        FuncGraph wraps Graph's control_dependencies logic. It first filters
        out any control inputs that do not belong to this graph and records
        them in the graph's `control_captures` member. Any consumers of this
        function graph must then decide how to handle the control captures.

        Before filtering, each input is normalized:
          * `IndexedSlices` and objects exposing both `_handle` and `op` (the
            `_UnreadVariable` case) are replaced by their `op`.
          * The result is passed through `ops._as_graph_element`. If that
            yields `None`, the unconverted object is used.

        A normalized input is captured when it is not `None` and its `graph`
        attribute (missing counts as `None`) is not this graph. Otherwise it
        is forwarded to the parent implementation.

        Args:
          control_inputs: A list of `Operation` or `Tensor` objects which
            must be executed or computed before running the operations
            defined in the context.  Can also be `None` to clear the control
            dependencies.

        Returns:
          A context manager that specifies control dependencies for all
          operations constructed within the context.

        Raises:
          TypeError: If `control_inputs` is not a list of `Operation` or
            `Tensor` objects.
        """
        if control_inputs is None:
            # Clearing dependencies: nothing to filter.
            return super(FuncGraph, self).control_dependencies(control_inputs)

        kept = []
        for candidate in control_inputs:
            # Check for _UnreadVariable (duck-typed via `_handle` + `op`) or
            # IndexedSlices; both are represented by their op.
            unwrap = isinstance(candidate, ops.IndexedSlices) or (
                hasattr(candidate, "_handle") and hasattr(candidate, "op"))
            if unwrap:
                candidate = candidate.op
            converted = ops._as_graph_element(candidate)  # pylint: disable=protected-access
            element = candidate if converted is None else converted
            is_external = (element is not None and
                           getattr(element, "graph", None) is not self)
            if is_external:
                self.control_captures.add(element)
                continue
            kept.append(element)
        return super(FuncGraph, self).control_dependencies(kept)
Esempio n. 3
0
  def control_dependencies(self, control_inputs):
    """Handles control dependencies, diverting external ones to captures.

    FuncGraph wraps Graph's control_dependencies logic. It first filters out
    any external tensors / operations and stores them in the graph's
    `control_captures` member. Any consumers of this function graph must then
    decide how to handle the control captures.

    Each control input is normalized first. `IndexedSlices` and objects
    exposing both `_handle` and `op` (the `_UnreadVariable` case) are
    replaced by their `op`. The result is then converted with
    `ops._as_graph_element`, falling back to the unconverted object when
    that returns `None`. Normalized inputs that are `None`, or whose `graph`
    attribute is this graph, are forwarded to the parent implementation.
    All others are captured.

    Args:
      control_inputs: A list of `Operation` or `Tensor` objects which
        must be executed or computed before running the operations
        defined in the context.  Can also be `None` to clear the control
        dependencies.

    Returns:
     A context manager that specifies control dependencies for all
     operations constructed within the context.

    Raises:
      TypeError: If `control_inputs` is not a list of `Operation` or
        `Tensor` objects.
    """
    if control_inputs is None:
      # Clearing dependencies: nothing to filter.
      return super(FuncGraph, self).control_dependencies(control_inputs)

    def _normalize(item):
      # Check for _UnreadVariable (or IndexedSlices): use the producing op.
      if (isinstance(item, ops.IndexedSlices) or
          (hasattr(item, "_handle") and hasattr(item, "op"))):
        item = item.op
      element = ops._as_graph_element(item)  # pylint: disable=protected-access
      return item if element is None else element

    local_inputs = []
    for element in map(_normalize, control_inputs):
      if element is None or getattr(element, "graph", None) is self:
        local_inputs.append(element)
      else:
        self.control_captures.add(element)
    return super(FuncGraph, self).control_dependencies(local_inputs)
Esempio n. 4
0
    def _make_callable(self, feed_arrays, feed_symbols, symbol_values, all_fetches):
        """Build a session callable for the given feeds/fetches and cache it.

        Args:
          feed_arrays: Input tensors fed values at call time. Their names are
            registered as feeds, followed by the sorted keys of
            `self.feed_dict`.
          feed_symbols: Tensors connected to symbolic values instead of being
            fed. They are paired positionally with `symbol_values`.
          symbol_values: Tensors connected to `feed_symbols`. Each one is cast
            to the dtype of its counterpart when the dtypes differ.
          all_fetches: Fetch structure handed to `_FetchHandler`. The fetches
            it reports become the callable's fetches.

        Returns:
          None. The callable is created via `self.tf_sess` and stored in
          `self._callable_fn`, together with the inputs it was built for.
        """
        callable_opts = config_pb2.CallableOptions()
        # External-data feeds: positional arrays, then static feed_dict keys.
        for x in feed_arrays:
            callable_opts.feed.append(x.name)
        if self.feed_dict:
            for key in sorted(self.feed_dict.keys()):
                # Fixed typo (`appned`) that raised AttributeError whenever
                # feed_dict was non-empty.
                callable_opts.feed.append(key.name)

        # Symbolic feeds: wire each symbolic value into its placeholder.
        for x, y in zip(feed_symbols, symbol_values):
            connection = callable_opts.tensor_connection.add()
            if x.dtype != y.dtype:
                y = tf.cast(y, x.dtype)
            from_tensor = tf_ops._as_graph_element(y)
            if from_tensor is None:
                # Not convertible to a graph element; use `y` as-is.
                from_tensor = y
            connection.from_tensor = from_tensor.name  # Data tensor
            connection.to_tensor = x.name  # Placeholder

        self._all_fetches = all_fetches

        # The fetch handler decides which concrete tensors are fetched.
        self._fetch_handler = _FetchHandler(
            graph=self.graph or tf.get_default_graph(),
            fetches=self._all_fetches, feeds={})
        for x in self._fetch_handler.fetches():
            callable_opts.fetch.append(x.name)

        # Run the update op as a target only when one is set. Previously a
        # `None` here crashed with AttributeError.
        if self.updates_ops is not None:
            callable_opts.target.append(self.updates_ops.name)

        if self.run_options:
            callable_opts.run_options.CopyFrom(self.run_options)
        callable_fn = self.tf_sess._make_callable_from_options(callable_opts)

        # Cache the callable and the parameters it was built for.
        self._callable_fn = callable_fn
        self._feed_arrays = feed_arrays
        self._feed_symbols = feed_symbols
        self._symbol_values = symbol_values