Code example #1
    def __init__(self, target, parent_graph=None):
        """Initializes an ImperativeMode.

    Args:
      target: The TensorFlow execution engine to connect to.
      parent_graph: (Optional) An ImperativeGraph.

    Raises:
      UnimplementedError: if non-None parent_graph is not an ImperativeGraph.
    """
        self._target = target
        self._parent_graph = parent_graph
        # Create a new ImperativeGraph, nested inside `parent_graph` if one
        # was given.
        self._graph = imperative_graph.ImperativeGraph(
            parent_graph=self._parent_graph)
        self._default_graph = self._graph.as_default()
        # Context manager to record variable inits
        self._record_variable_inits = self._graph.record_variable_inits()
        if self._parent_graph:
            if not isinstance(self._parent_graph,
                              imperative_graph.ImperativeGraph):
                raise errors.UnimplementedError(
                    None, None, 'ImperativeMode needs an '
                    'ImperativeGraph')
            # Clone `_parent_graph` into the current graph so that operations
            # created in the enclosing ImperativeMode context are available in
            # the current context.
            with self._graph.as_default(), self._graph.return_as_is():
                importer.import_graph_def(self._parent_graph.as_graph_def(),
                                          name='')
        self._session = session.Session(graph=self._graph, target=self._target)
        # Override the session's `run` so that pending variable initializers
        # can be run before the actual call.
        self._old_run = self._session.run
        self._session.run = self.run
        self._context_managers = [
            self._session.as_default(), self._default_graph,
            self._record_variable_inits,
            imperative_graph.add_session_attr(ops.Tensor, self._session)
        ]
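
For context, here is a minimal usage sketch. It is an assumption-laden illustration: it presumes ImperativeMode enters and exits the context managers collected in `_context_managers` above via `__enter__`/`__exit__` (not shown in this excerpt), and the import path and the empty `target` (in-process execution) are illustrative rather than confirmed by the source.

    import tensorflow as tf
    from tensorflow.contrib.imperative import imperative_mode

    # Entering the mode installs the session, the imperative graph, and the
    # variable-init recorder assembled in __init__ above.
    with imperative_mode.ImperativeMode(target=''):
        x = tf.constant([1.0, 2.0]) + tf.constant([3.0, 4.0])
        # Inside the mode, the overridden session.run() first runs any
        # recorded variable initializers, then evaluates requested tensors.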
Code example #2
    def _eval(self, tensor):
        """Returns the value in the tensor. Must be implemented in sub-classes."""
        raise errors.UnimplementedError(
            None, None,
            "The evaluation method should be implemented in sub-classes.")
Code example #3
    def create_op(self, *args, **kwargs):
        """Creates an `Operation`.

    For operations of the following form

      orig_value = op(*args, **kwargs)

    this function constructs the following subgraph :

      v = Variable()
      if v is not initialized:
        orig_value = op(*args, **kwargs)
        v.assign(orig_value) # Initializes v
        return orig_value
      else:
        return v

    The above transformation is not performed and the original op is returned
    as is if any of the following is true:
    * `_return_as_is` flag is set to true.
    * op_type is listed in _PASS_THROUGH_OPS
    * op has no outputs.
    * One of the op's return value has a ref type.

    Args:
      *args: Arguments for create_op()
      **kwargs: Keyword arguments for create_op(). Refer to
        tensorflow.python.framework.ops.Graph.create_op() for the mandatory
        and optional arguments.

    Returns:
      An Operation.

    Raises:
      UnimplementedError: if output type is a reference and the op's type
        is not one of the supported types in `_REF_OPS_WHITELIST`.
    """
        op_type = kwargs['op_type'] if 'op_type' in kwargs else args[0]
        output_dtypes = kwargs['dtypes'] if 'dtypes' in kwargs else args[2]
        output_dtypes = [dtypes.as_dtype(d) for d in output_dtypes]

        # Pass the op through unchanged if caching is disabled or the op
        # type is whitelisted.
        if self._return_as_is or op_type in _PASS_THROUGH_OPS:
            return self._wrap(
                super(ImperativeGraph, self).create_op(*args, **kwargs))

        # Ops with no outputs have nothing to cache.
        if not output_dtypes:
            return self._wrap(
                super(ImperativeGraph, self).create_op(*args, **kwargs))

        output_has_ref = any([dtype._is_ref_dtype for dtype in output_dtypes])  # pylint: disable=protected-access

        if output_has_ref:
            # Ref-typed outputs cannot be cached in variables; only the
            # whitelisted ref ops are allowed through.
            if op_type not in _REF_OPS_WHITELIST:
                raise errors.UnimplementedError(
                    None, None, op_type + ' op not supported in '
                    'imperative graph')

            ret = super(ImperativeGraph, self).create_op(*args, **kwargs)

            # Record assignments made while a variable is being created, so
            # that they can be run before the next session.run() (see the
            # variable-init recording in code example #1).
            if self._in_variable_creation and op_type == 'Assign':
                self.add_pending_init(ret)

            return self._wrap(ret)

        with self.return_as_is():
            # Declares the variables to hold the output values of this op.
            op_output_var = [
                state_ops.variable_op_v2(tensor_shape.TensorShape(None),
                                         dtype,
                                         container=self._name)
                for dtype in output_dtypes
            ]
            # Ops to free the resources used by the temporary cache variables.
            # The following two ops are created for each cache variable,
            # having no control dependencies on any other ops:
            # var_handle_op ----> destroy_resource_op
            for dtype, v in zip(output_dtypes, op_output_var):
                with ops.control_dependencies(None):
                    self._variable_cleanup_ops += [
                        gen_resource_variable_ops.destroy_resource_op(
                            gen_resource_variable_ops.var_handle_op(
                                dtype,
                                tensor_shape.TensorShape(None),
                                container=self._name,
                                shared_name=v.op.name),
                            ignore_lookup_error=True)
                    ]

            # Create the conditional to run the original op only when the variable
            # corresponding to the first output is not initialized.
            inited = state_ops.is_variable_initialized(op_output_var[0])
            v_f, v_t = control_flow_ops.ref_switch(op_output_var[0], inited)
            # pylint: disable=protected-access
            v_f_op = gen_array_ops._ref_identity(v_f)
            v_t_op = gen_array_ops._ref_identity(v_t)
            # pylint: enable=protected-access

            with ops.control_dependencies([v_f_op.op]):
                # Create the original op
                orig_op = self._wrap(
                    super(ImperativeGraph, self).create_op(*args, **kwargs))
            shapes = [val.get_shape() for val in orig_op.outputs]

            controls = []
            for var, val in zip(op_output_var, orig_op.outputs):
                if (not val.get_shape().is_fully_defined()
                        or val.get_shape().num_elements() > 0):
                    assign_op = state_ops.assign(var,
                                                 val,
                                                 validate_shape=False)
                    assign_op.set_shape(val.get_shape())
                    controls.append(assign_op)

            values = []
            if len(controls) > 1:
                if control_flow_ops.IsSwitch(orig_op):
                    # pylint: disable=protected-access
                    controls = gen_control_flow_ops._ref_merge(controls)
                    # pylint: enable=protected-access
                else:
                    controls = control_flow_ops.tuple(controls)

            for var, val in zip(op_output_var, orig_op.outputs):
                # First run: take the freshly computed value (the assigns in
                # `controls` have just populated the cache variable).
                with ops.control_dependencies(controls):
                    with self.colocate_with(v_f_op):
                        real_val = array_ops.identity(val)
                # Later runs: read the cached value back out of the variable
                # instead of recomputing the op.
                with ops.control_dependencies([v_t_op.op]):
                    with self.colocate_with(v_t_op):
                        stored_val = array_ops.identity(var)
                    stored_val.set_shape(val.get_shape())
                    # Merge forwards whichever branch actually ran.
                    real_val, _ = control_flow_ops.merge(
                        [real_val, stored_val])
                real_val.op.node_def.attr['_gradient_op_type'].CopyFrom(
                    attr_value_pb2.AttrValue(
                        s=compat.as_bytes(self._merge_op_type)))
                values.append(real_val)

            for i, _ in enumerate(shapes):
                values[i].set_shape(shapes[i])
            self._outputs_map[orig_op.name] = values
            try:
                self._gradient_function_map[
                    orig_op.name] = ops.get_gradient_function(orig_op)
            except (KeyError, LookupError):
                pass
            else:
                orig_op.node_def.attr['_gradient_op_type'].CopyFrom(
                    attr_value_pb2.AttrValue(
                        s=compat.as_bytes(self._imperative_op_type)))

            return MultiOutputOperation(values)
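
To make the subgraph in the docstring concrete, the same compute-once pattern can be written in plain Python, with an instance attribute playing the role of the cache variable. This is an illustrative analogue, not part of the TensorFlow code above; all names are made up.

    class CachedOp(object):
        """Plain-Python analogue of the variable-backed caching subgraph."""

        _UNSET = object()

        def __init__(self, op):
            self._op = op              # the original op
            self._value = self._UNSET  # plays the role of `v = Variable()`

        def __call__(self, *args, **kwargs):
            # The `if v is not initialized:` branch from the docstring.
            if self._value is self._UNSET:
                # orig_value = op(*args, **kwargs); v.assign(orig_value)
                self._value = self._op(*args, **kwargs)
            # The `else: return v` branch: later calls reuse the cache.
            return self._value

    # The op runs only on the first call; afterwards the cached value is
    # returned, mirroring how the subgraph short-circuits to the variable.
    square = CachedOp(lambda x: x * x)
    assert square(3) == 9  # computes and caches
    assert square(7) == 9  # cached value; the op is not re-run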