Example #1
    def testCaptureOrdering(self):
        v1 = resource_variable_ops.ResourceVariable(1.0)
        v2 = resource_variable_ops.ResourceVariable(2.0)
        v3 = resource_variable_ops.ResourceVariable(3.0)

        @def_function.function
        def fn():
            return v1 + v2 + v3

        concrete_fn = fn.get_concrete_function()
        original_captures = concrete_fn.graph.captures
        outputs = concrete_fn.graph.outputs

        for _ in range(100):
            g = func_graph.FuncGraph('lifted')

            lift_to_graph.lift_to_graph(outputs,
                                        g,
                                        add_sources=True,
                                        handle_captures=True)
            lifted_captures = g.captures
            self.assertLen(lifted_captures, 3)
            for original_capture, lifted_capture in zip(
                    original_captures.values(), lifted_captures.values()):
                self.assertEqual(original_capture.name, lifted_capture.name)
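
The test above checks that lift_to_graph reproduces a concrete function's captures in a stable order across repeated lifts. For orientation, a minimal sketch of the core call it exercises, assuming TensorFlow's internal modules are importable at the paths below (they are not a stable public API):

from tensorflow.python.eager import def_function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import func_graph


@def_function.function
def add_constants():
  return constant_op.constant(1.0) + constant_op.constant(2.0)

concrete = add_constants.get_concrete_function()
target = func_graph.FuncGraph('lifted')
# lift_to_graph copies the ops feeding the given tensors into `target` and
# returns a map from the original ops/tensors to their lifted counterparts.
lift_map = lift_to_graph.lift_to_graph(concrete.graph.outputs, target)
lifted_output = lift_map[concrete.graph.outputs[0]]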
Example #2
  def testClassAttrsRemoved(self):
    """Tests that _class attrs (from colocate_with()) are removed."""
    @def_function.function
    def fn():
      two = constant_op.constant(2.0, name='two')
      ten = constant_op.constant(10.0, name='ten')
      twenty = math_ops.multiply(two, ten, name='twenty')
      three = constant_op.constant(3.0, name='three')
      with framework_ops.colocate_with(twenty):
        thirty = math_ops.multiply(three, ten, name='thirty')
      return ten, twenty, thirty

    concrete_fn = fn.get_concrete_function()
    self.assertItemsEqual(  # Before lifting, 'fn' has colocation attrs.
        concrete_fn.graph.get_operation_by_name('thirty').colocation_groups(),
        [compat.as_bytes('loc:@twenty')])
    thirty_out = concrete_fn.graph.outputs[2]

    g = func_graph.FuncGraph('lifted')
    lift_to_graph.lift_to_graph([thirty_out], g)

    # After lifting, colocation attrs are gone.
    ops = g.get_operations()
    self.assertItemsEqual([op.name for op in ops],
                          ['three', 'ten', 'thirty',  # Lifted from `fn` body.
                           thirty_out.op.name])  # Wrapper for output.
    for op in ops:
      with self.assertRaises(ValueError):
        class_attr = op.get_attr('_class')  # Expected not to exist.
        print('Unexpected class_attr', class_attr, 'on', op.name)
      self.assertItemsEqual(op.colocation_groups(),  # Expect default self-ref.
                            [compat.as_bytes('loc:@%s' % op.name)])
Example #3
    def initialize_variables():
      op_map = object_identity.ObjectIdentityDictionary()
      # Stack all the var_is_initialized values into one tensor and interpret the
      # numpy value. This will reduce the number of RPCs between client and
      # worker in the remote case.
      with ops.init_scope():
        var_is_initialized = []
        for v, _ in initializers:
          var_is_initialized.append(
              resource_variable_ops.var_is_initialized_op(v.handle))
        var_is_initialized = array_ops.stack(var_is_initialized).numpy()

      inits = []
      for (v, init), is_initialized in zip(initializers, var_is_initialized):
        with ops.init_scope():
          if is_initialized:
            continue
        inits.append(init)

      if inits:
        op_map = lift_to_graph.lift_to_graph(
            inits, ops.get_default_graph(), op_map=op_map)
      for (v, init), is_initialized in zip(initializers, var_is_initialized):
        with ops.init_scope():
          if is_initialized:
            continue
        v.assign(op_map[init], read_value=False)
Example #4
 def prune(self, feeds, fetches):
     flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
     for f in flat_feeds + flat_fetches:
         if not isinstance(f, ops.Tensor):
             raise ValueError("Feeds and fetches must be tensors.")
         if f.graph is not self._func_graph:
             raise ValueError(
                 "Can only prune function whose feeds and fetches "
                 "are from this graph (%s). Tensor %s from graph %s" %
                 (self._func_graph, f, f.graph))
     with self._func_graph.as_default():
         pruned_graph = func_graph.FuncGraph("pruned")
         sink_tensor = array_ops.identity_n(flat_fetches)[0]
     lift_map = lift_to_graph.lift_to_graph(sink_tensor,
                                            pruned_graph,
                                            sources=flat_feeds +
                                            self.graph.internal_captures)
     pruned_graph.outputs.extend(lift_map[x] for x in flat_fetches)
     for external_capture, internal_capture in self.graph.captures.items():
         pruned_graph.captures[external_capture] = lift_map[
             internal_capture]
     pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
     pruned_graph.inputs.extend(pruned_graph.captures.values())
     pruned_graph.structured_outputs = nest.map_structure(
         lambda node: lift_map[node], fetches)
     pruned_fn = WrappedFunction(pruned_graph,
                                 variable_holder=self._variable_holder)
     pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
     pruned_fn._arg_keywords = []  # pylint: disable=protected-access
     return pruned_fn
Example #5
    def prune(self, feeds, fetches, name=None, input_signature=None):
        # TODO(b/129646028): Add support for CompositeTensors.
        name = name or "pruned"
        feeds = nest.map_structure(self.graph.as_graph_element, feeds)
        fetches = nest.map_structure(self.graph.as_graph_element, fetches)
        flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
        for f in flat_feeds:
            if not isinstance(f, ops.Tensor):
                raise ValueError("Feeds must be tensors.")

        # Ignoring all feeds that are captures allows prune to be called
        # using wrapped_func.inputs even when it uses variables
        internal_captures = self.graph.internal_captures
        flat_feeds = [f for f in flat_feeds if f not in internal_captures]

        operation_fetches = []
        for f in flat_fetches:
            if isinstance(f, ops.Operation):
                operation_fetches.append(f)
            elif not isinstance(f, ops.Tensor):
                raise ValueError("Fetches must be tensors or operations.")
        for f in flat_feeds + flat_fetches:
            if f.graph is not self._func_graph:
                raise ValueError(
                    "Can only prune function whose feeds and fetches "
                    "are from this graph (%s). Tensor %s from graph %s" %
                    (self._func_graph, f, f.graph))
        with self._func_graph.as_default():
            pruned_graph = func_graph.FuncGraph(name)
        lift_map = lift_to_graph.lift_to_graph(flat_fetches,
                                               pruned_graph,
                                               sources=flat_feeds +
                                               internal_captures)
        pruned_graph.outputs.extend(lift_map[x] for x in flat_fetches
                                    if isinstance(x, ops.Tensor))
        pruned_graph.control_outputs.extend(
            [lift_map[operation] for operation in operation_fetches])
        for external_capture, internal_capture in self.graph.captures.items():
            pruned_graph.captures[external_capture] = lift_map[
                internal_capture]
        pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
        pruned_graph.inputs.extend(pruned_graph.captures.values())

        pruned_graph.variables = self.graph.variables

        def _structured_output_mapping(fetched):
            lifted = lift_map[fetched]
            if isinstance(lifted, ops.Operation):
                return None
            return lifted

        pruned_graph.structured_outputs = nest.map_structure(
            _structured_output_mapping, fetches)
        pruned_graph.structured_input_signature = input_signature
        pruned_fn = WrappedFunction(pruned_graph,
                                    variable_holder=self._variable_holder)
        pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
        # TODO(kathywu): Enable keyword arguments if an input signature is specified
        pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds]  # pylint: disable=protected-access
        return pruned_fn
Example #6
 def initialize_variables():
   for v, init in initializer_map.items():
     with ops.init_scope():
       if resource_variable_ops.var_is_initialized_op(v.handle):
         # Ignore variables which are already initialized at trace time.
         continue
     v.assign(lift_to_graph.lift_to_graph(
         [init], ops.get_default_graph())[init])
Example #8
    def prune(self, feeds, fetches):
        flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
        for f in flat_feeds:
            if not isinstance(f, ops.Tensor):
                raise ValueError("Feeds must be tensors.")
        tensor_fetches = []
        operation_fetches = []
        for f in flat_fetches:
            if isinstance(f, ops.Tensor):
                tensor_fetches.append(f)
            elif isinstance(f, ops.Operation):
                operation_fetches.append(f)
            else:
                raise ValueError("Fetches must be tensors or operations.")
        for f in flat_feeds + flat_fetches:
            if f.graph is not self._func_graph:
                raise ValueError(
                    "Can only prune function whose feeds and fetches "
                    "are from this graph (%s). Tensor %s from graph %s" %
                    (self._func_graph, f, f.graph))
        with self._func_graph.as_default():
            pruned_graph = func_graph.FuncGraph("pruned")
            with ops.control_dependencies(operation_fetches):
                if tensor_fetches:
                    identity_fetches = array_ops.identity_n(tensor_fetches)
                    sink_tensor = identity_fetches[0]
                else:
                    identity_fetches = []
                    sink_tensor = control_flow_ops.no_op()
        lift_map = lift_to_graph.lift_to_graph(sink_tensor,
                                               pruned_graph,
                                               sources=flat_feeds +
                                               self.graph.internal_captures)
        for original_fetch, identity_fetch in zip(tensor_fetches,
                                                  identity_fetches):
            lift_map[original_fetch] = lift_map[identity_fetch]
        pruned_graph.outputs.extend(lift_map[x] for x in flat_fetches
                                    if isinstance(x, ops.Tensor))
        for external_capture, internal_capture in self.graph.captures.items():
            pruned_graph.captures[external_capture] = lift_map[
                internal_capture]
        pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
        pruned_graph.inputs.extend(pruned_graph.captures.values())

        def _structured_output_mapping(fetched):
            lifted = lift_map[fetched]
            if isinstance(lifted, ops.Operation):
                return None
            return lifted

        pruned_graph.structured_outputs = nest.map_structure(
            _structured_output_mapping, fetches)
        pruned_fn = WrappedFunction(pruned_graph,
                                    variable_holder=self._variable_holder)
        pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
        pruned_fn._arg_keywords = []  # pylint: disable=protected-access
        return pruned_fn
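
The `control_dependencies` plus `identity_n` construction above exists because this revision of lift_to_graph walks backward from a single sink tensor: routing the tensor fetches through identities that carry control edges to the operation fetches makes everything reachable from that one sink. A rough sketch of the trick in isolation, with hypothetical names and the `sources` argument omitted for brevity:

with source_graph.as_default():
  with ops.control_dependencies([update_op]):     # hypothetical operation fetch
    sink = array_ops.identity_n([result_t])[0]    # hypothetical tensor fetch
# Lifting the sink drags in `result_t` and, via the control edge, `update_op`.
lifted_sink = lift_to_graph.lift_to_graph(sink, target_graph)[sink]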
Example #9
 def initialize_variables():
   op_map = object_identity.ObjectIdentityDictionary()
   for v, init in initializer_map.items():
     with ops.init_scope():
       if resource_variable_ops.var_is_initialized_op(v.handle):
         # Ignore variables which are already initialized at trace time.
         continue
     op_map = lift_to_graph.lift_to_graph(
         [init], ops.get_default_graph(), op_map=op_map)
     v.assign(op_map[init])
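
The op_map threading above is what keeps `op_map[init]` valid on every iteration: lift_to_graph records each op it lifts in the map it is given, reuses entries that are already present instead of copying them again, and returns the updated map. A hedged sketch of that contract with hypothetical tensors `a` and `b` from one source graph:

m = lift_to_graph.lift_to_graph([a], ops.get_default_graph())
m = lift_to_graph.lift_to_graph([b], ops.get_default_graph(), op_map=m)
# `m` now maps both `a` and `b` (and their lifted dependencies) to their copies.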
Example #10
  def prune(self, feeds, fetches, name=None, input_signature=None):
    # TODO(b/129646028): Add support for CompositeTensors.
    name = name or "pruned"
    flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
    for f in flat_feeds:
      if not isinstance(f, ops.Tensor):
        raise ValueError("Feeds must be tensors.")

    # Ignoring all feeds that are captures allows prune to be called
    # using wrapped_func.inputs even when it uses variables
    internal_captures = self.graph.internal_captures
    flat_feeds = [f for f in flat_feeds if f not in internal_captures]

    operation_fetches = []
    for f in flat_fetches:
      if isinstance(f, ops.Operation):
        operation_fetches.append(f)
      elif not isinstance(f, ops.Tensor):
        raise ValueError("Fetches must be tensors or operations.")
    for f in flat_feeds + flat_fetches:
      if f.graph is not self._func_graph:
        raise ValueError("Can only prune function whose feeds and fetches "
                         "are from this graph (%s). Tensor %s from graph %s" %
                         (self._func_graph, f, f.graph))
    with self._func_graph.as_default():
      pruned_graph = func_graph.FuncGraph(name)
    lift_map = lift_to_graph.lift_to_graph(
        flat_fetches, pruned_graph, sources=flat_feeds + internal_captures)
    pruned_graph.outputs.extend(
        lift_map[x] for x in flat_fetches if isinstance(x, ops.Tensor))
    pruned_graph.control_outputs.extend(
        [lift_map[operation] for operation in operation_fetches])
    for external_capture, internal_capture in self.graph.captures.items():
      pruned_graph.captures[external_capture] = lift_map[internal_capture]
    pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
    pruned_graph.inputs.extend(pruned_graph.captures.values())

    pruned_graph.variables = self.graph.variables

    def _structured_output_mapping(fetched):
      lifted = lift_map[fetched]
      if isinstance(lifted, ops.Operation):
        return None
      return lifted

    pruned_graph.structured_outputs = nest.map_structure(
        _structured_output_mapping, fetches)
    pruned_graph.structured_input_signature = input_signature
    pruned_fn = WrappedFunction(
        pruned_graph, variable_holder=self._variable_holder)
    pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
    # TODO(kathywu): Enable keyword arguments if an input signature is specified
    pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds]  # pylint: disable=protected-access
    return pruned_fn
Example #11
    def testClassAttrsRemoved(self):
        """Tests that _class attrs (from colocate_with()) are removed."""
        @def_function.function
        def fn():
            two = constant_op.constant(2.0, name='two')
            ten = constant_op.constant(10.0, name='ten')
            twenty = math_ops.multiply(two, ten, name='twenty')
            three = constant_op.constant(3.0, name='three')
            with framework_ops.colocate_with(twenty):
                thirty = math_ops.multiply(three, ten, name='thirty')
            return ten, twenty, thirty

        concrete_fn = fn.get_concrete_function()
        self.assertItemsEqual(  # Before lifting, 'fn' has colocation attrs.
            concrete_fn.graph.get_operation_by_name(
                'thirty').colocation_groups(),
            [compat.as_bytes('loc:@twenty')])
        thirty_out = concrete_fn.graph.outputs[2]

        g = func_graph.FuncGraph('lifted')
        lift_to_graph.lift_to_graph([thirty_out], g)

        # After lifting, colocation attrs are gone.
        ops = g.get_operations()
        self.assertItemsEqual(
            [op.name for op in ops],
            [
                'three',
                'ten',
                'thirty',  # Lifted from `fn` body.
                thirty_out.op.name
            ])  # Wrapper for output.
        for op in ops:
            with self.assertRaises(ValueError):
                class_attr = op.get_attr('_class')  # Expected not to exist.
                print('Unexpected class_attr', class_attr, 'on', op.name)
            self.assertItemsEqual(
                op.colocation_groups(),  # Expect default self-ref.
                [compat.as_bytes('loc:@%s' % op.name)])
Example #12
  def testCaptureOrdering(self):
    v1 = resource_variable_ops.ResourceVariable(1.0)
    v2 = resource_variable_ops.ResourceVariable(2.0)
    v3 = resource_variable_ops.ResourceVariable(3.0)

    @def_function.function
    def fn():
      return v1 + v2 + v3

    concrete_fn = fn.get_concrete_function()
    original_captures = concrete_fn.graph.captures
    outputs = concrete_fn.graph.outputs

    for _ in range(100):
      g = func_graph.FuncGraph('lifted')

      lift_to_graph.lift_to_graph(
          outputs, g, add_sources=True, handle_captures=True)
      lifted_captures = g.captures
      self.assertLen(lifted_captures, 3)
      for original_capture, lifted_capture in zip(original_captures.values(),
                                                  lifted_captures.values()):
        self.assertEqual(original_capture.name, lifted_capture.name)
Example #13
 def prune(self, feeds, fetches):
   flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
   for f in flat_feeds + flat_fetches:
     if not isinstance(f, ops.Tensor):
       raise ValueError("Feeds and fetches must be tensors.")
     if f.graph is not self._func_graph:
       raise ValueError(
           "Can only prune function whose feeds and fetches "
           "are from this graph (%s). Tensor %s from graph %s" % (
               self._func_graph, f, f.graph))
   with self._func_graph.as_default():
     pruned_graph = func_graph.FuncGraph("pruned")
     sink_tensor = array_ops.identity_n(flat_fetches)[0]
   lift_map = lift_to_graph.lift_to_graph(
       sink_tensor, pruned_graph, sources=flat_feeds)
   pruned_graph.outputs.extend(lift_map[x] for x in flat_fetches)
   pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
   pruned_fn = WrappedFunction(
       pruned_graph, variable_holder=self._variable_holder)
   pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
   pruned_fn._arg_keywords = []  # pylint: disable=protected-access
   return pruned_fn
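
In the prune above, `sources=flat_feeds` tells lift_to_graph to stop walking at the feed tensors and recreate them as placeholders in the destination graph; the returned map then links each feed to its new placeholder and each fetch to its lifted value. A rough sketch of that contract with hypothetical tensors (this early revision passes a single sink tensor rather than a list):

# `feed_t` and `result_t` are hypothetical tensors in some source FuncGraph,
# where `result_t` depends on `feed_t`.
lift_map = lift_to_graph.lift_to_graph(
    result_t, target_graph, sources=[feed_t])
new_input = lift_map[feed_t]     # placeholder created in `target_graph`
new_output = lift_map[result_t]  # lifted computation, wired to `new_input`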
Example #14
    def __init__(
            self,  # pylint: disable=super-init-not-called
            initial_value=None,
            trainable=None,
            caching_device=None,
            name=None,
            dtype=None,
            constraint=None,
            add_initializers_to=None,
            lifted_initializer_graph=None,
            synchronization=None,
            aggregation=None,
            **unused_kwargs):
        """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, GradientTapes automatically watch uses of this
        Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
       a Tensor) or float32 will be used (if it is a Python object convertible
       to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      add_initializers_to: if not None and not in legacy graph mode, the
        initializer tensor will be added to this map in addition to adding the
        assignment to the function.
      lifted_initializer_graph: FuncGraph to try to lift initializers to.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If called outside of a function definition.
    """
        if not ops.inside_function():
            # If we've been init_scope()d out of the function definition,
            # there's nothing to do here; we can't really do the capturing or
            # conditional logic.
            resource_variable_ops.ResourceVariable.__init__(
                self,
                initial_value=initial_value,
                trainable=trainable,
                caching_device=caching_device,
                name=name,
                dtype=dtype,
                constraint=constraint)
            return
        with ops.init_scope():
            self._in_graph_mode = not context.executing_eagerly()
        if initial_value is None:
            raise ValueError("initial_value must be specified.")
        init_from_fn = callable(initial_value)

        if constraint is not None and not callable(constraint):
            raise ValueError("The `constraint` argument must be a callable.")

        if isinstance(initial_value, trackable.CheckpointInitialValue):
            self._maybe_initialize_trackable()
            self._update_uid = initial_value.checkpoint_position.restore_uid
            initial_value = initial_value.wrapped_value

        synchronization, aggregation, trainable = (
            variables.validate_synchronization_aggregation_trainable(
                synchronization, aggregation, trainable, name))
        self._trainable = trainable
        self._synchronization = synchronization
        self._aggregation = aggregation
        self._save_slice_info = None
        self._initial_value = None
        self._initializer_op = None
        self._is_initialized_op = None
        self._graph_element = None
        self._cached_value = None
        # Store the graph key so optimizers know how to only retrieve variables from
        # this graph. Guaranteed to be the same as the eager graph_key.
        self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
        with ops.name_scope(name, "Variable",
                            [] if init_from_fn else [initial_value]) as name:
            # pylint: disable=protected-access
            with ops.init_scope():
                handle_name = ops.name_from_scope_name(name)
                unique_id = "%s_%d" % (handle_name, ops.uid())
                shared_name = context.shared_name(unique_id)
            with ops.name_scope("Initializer"), ops.device(None):
                initial_value = ops.convert_to_tensor(
                    initial_value() if init_from_fn else initial_value,
                    name="initial_value",
                    dtype=dtype)
            with ops.init_scope():
                self._handle = resource_variable_ops.eager_safe_variable_handle(
                    initial_value=initial_value,
                    shared_name=shared_name,
                    name=name,
                    graph_mode=self._in_graph_mode)
            self._shape = initial_value.shape
            self._unique_id = unique_id
            self._handle_name = handle_name + ":0"
            self._dtype = initial_value.dtype.base_dtype
            self._constraint = constraint
            assert initial_value is not None
            if self._in_graph_mode:
                with ops.init_scope():
                    outer_graph = ops.get_default_graph()
                func_graph = ops.get_default_graph()
                function_placeholders = (func_graph.inputs +
                                         func_graph.internal_captures)
                placeholder_ops = set(
                    [tensor.op for tensor in function_placeholders])
                lifted_initializer = lift_to_graph.lift_to_graph(
                    [initial_value],
                    outer_graph,
                    disallowed_placeholders=placeholder_ops)[initial_value]
                with ops.init_scope():
                    self._initial_value = lifted_initializer
                    with ops.name_scope("IsInitialized"):
                        self._is_initialized_op = (
                            resource_variable_ops.var_is_initialized_op(
                                self._handle))
                    if initial_value is not None:
                        with ops.name_scope("Assign") as n, ops.colocate_with(
                                self._handle):
                            self._initializer_op = resource_variable_ops.assign_variable_op(
                                self._handle, lifted_initializer, name=n)
                    with ops.name_scope("Read"), ops.colocate_with(
                            self._handle):
                        # Manually assign reads to the handle's device to avoid log
                        # messages.
                        with ops.device(self._handle.device):
                            value = self._read_variable_op()
                        self._graph_element = value
                    ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
            else:
                if add_initializers_to is not None:
                    add_initializers_to[self] = initial_value

                def assign_fn():
                    with ops.name_scope("Assign") as n, ops.colocate_with(
                            self._handle):
                        resource_variable_ops.assign_variable_op(self._handle,
                                                                 initial_value,
                                                                 name=n)
                        # Returning values to keep tf.cond happy.
                    return ops.convert_to_tensor(1)

                def not_assign_fn():
                    return ops.convert_to_tensor(0)

                # Note: this cond is always guaranteed to run because we're inside a
                # defun which will insert automatic control dependencies.
                control_flow_ops.cond(
                    resource_variable_ops.var_is_initialized_op(self._handle),
                    not_assign_fn, assign_fn)

        # After the handle has been created, set up a way to clean it up when
        # executing eagerly. We'll hold the only reference to the deleter, so that
        # when this object is garbage collected the deleter will be too. This
        # means ResourceVariables can be part of reference cycles without those
        # cycles being uncollectable.
        if not self._in_graph_mode:
            self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
                handle=self._handle, handle_device=self._handle.device)
        self._cached_shape_as_list = None
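
For context, a hedged illustration of where this class comes into play: in the TensorFlow revisions these snippets come from, variables created during the first trace of a `tf.function` are built through a variable-creator scope that substitutes this class, so their initializers can be lifted or deferred as above. The public-API pattern that reaches this path looks roughly like the following; `UnliftedInitializerVariable` itself is an internal detail:

import tensorflow as tf

v = None

@tf.function
def bump():
  global v
  if v is None:            # create the variable only on the first trace
    v = tf.Variable(1.0)   # goes through the initializer-lifting path above
  return v.assign_add(1.0)

bump()  # the first call traces `bump`, creating and initializing the variable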
Example #15
    def prune(self, feeds, fetches, name=None):
        name = name or "pruned"
        flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
        for f in flat_feeds:
            if not isinstance(f, ops.Tensor):
                raise ValueError("Feeds must be tensors.")

        # Ignoring all feeds that are captures allows prune to be called
        # using wrapped_func.inputs even when it uses variables
        internal_captures = self.graph.internal_captures
        flat_feeds = [f for f in flat_feeds if f not in internal_captures]

        tensor_fetches = []
        operation_fetches = []
        for f in flat_fetches:
            if isinstance(f, ops.Tensor):
                tensor_fetches.append(f)
            elif isinstance(f, ops.Operation):
                operation_fetches.append(f)
            else:
                raise ValueError("Fetches must be tensors or operations.")
        for f in flat_feeds + flat_fetches:
            if f.graph is not self._func_graph:
                raise ValueError(
                    "Can only prune function whose feeds and fetches "
                    "are from this graph (%s). Tensor %s from graph %s" %
                    (self._func_graph, f, f.graph))
        with self._func_graph.as_default():
            pruned_graph = func_graph.FuncGraph(name)
            with ops.control_dependencies(operation_fetches):
                if tensor_fetches:
                    identity_fetches = array_ops.identity_n(tensor_fetches)
                    sink_tensor = identity_fetches[0]
                else:
                    identity_fetches = []
                    sink_tensor = array_ops.zeros([])
        lift_map = lift_to_graph.lift_to_graph([sink_tensor],
                                               pruned_graph,
                                               sources=flat_feeds +
                                               internal_captures)
        for original_fetch, identity_fetch in zip(tensor_fetches,
                                                  identity_fetches):
            lift_map[original_fetch] = lift_map[identity_fetch]
        pruned_graph.outputs.extend(lift_map[x] for x in flat_fetches
                                    if isinstance(x, ops.Tensor))
        pruned_graph.control_outputs.extend(
            [lift_map[operation] for operation in operation_fetches])
        if not tensor_fetches:
            pruned_graph.outputs.append(lift_map[sink_tensor])
        for external_capture, internal_capture in self.graph.captures.items():
            pruned_graph.captures[external_capture] = lift_map[
                internal_capture]
        pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
        pruned_graph.inputs.extend(pruned_graph.captures.values())

        pruned_graph.variables = self.graph.variables

        def _structured_output_mapping(fetched):
            lifted = lift_map[fetched]
            if isinstance(lifted, ops.Operation):
                return None
            return lifted

        pruned_graph.structured_outputs = nest.map_structure(
            _structured_output_mapping, fetches)
        pruned_fn = WrappedFunction(pruned_graph,
                                    variable_holder=self._variable_holder)
        pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
        pruned_fn._arg_keywords = []  # pylint: disable=protected-access
        return pruned_fn
Example #16
 def initialize_variables():
   for v, init in initializer_map.items():
     v.assign(lift_to_graph.lift_to_graph(
         [init], ops.get_default_graph())[init])
Example #17
  def __init__(self,  # pylint: disable=super-init-not-called
               initial_value=None,
               trainable=None,
               caching_device=None,
               name=None,
               dtype=None,
               constraint=None,
               add_initializers_to=None,
               lifted_initializer_graph=None,
               **unused_kwargs):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, GradientTapes automatically watch uses of this
        Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
       a Tensor) or float32 will be used (if it is a Python object convertible
       to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      add_initializers_to: if not None and not in legacy graph mode, the
        initializer tensor will be added to this map in addition to adding the
        assignment to the function.
      lifted_initializer_graph: FuncGraph to try to lift initializers to.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If called outside of a function definition.
    """
    if not ops.inside_function():
      # If we've been init_scope()d out of the function definition, there's
      # nothing to do here; we can't really do the capturing or conditional
      # logic.
      resource_variable_ops.ResourceVariable.__init__(
          self, initial_value=initial_value, trainable=trainable,
          caching_device=caching_device, name=name, dtype=dtype,
          constraint=constraint)
      return
    with ops.init_scope():
      self._in_graph_mode = not context.executing_eagerly()
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    if isinstance(initial_value, trackable.CheckpointInitialValue):
      self._maybe_initialize_trackable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

    if trainable is None:
      trainable = True
    self._trainable = trainable
    self._save_slice_info = None
    self._initial_value = None
    self._initializer_op = None
    self._is_initialized_op = None
    self._graph_element = None
    self._cached_value = None
    # Store the graph key so optimizers know how to only retrieve variables from
    # this graph. Guaranteed to be the same as the eager graph_key.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    with ops.name_scope(name, "Variable", []
                        if init_from_fn else [initial_value]) as name:
      # pylint: disable=protected-access
      with ops.init_scope():
        handle_name = ops._name_from_scope_name(name)
        unique_id = "%s_%d" % (handle_name, ops.uid())
        shared_name = context.shared_name(unique_id)
      with ops.name_scope("Initializer"), ops.device(None):
        initial_value = ops.convert_to_tensor(
            initial_value() if init_from_fn else initial_value,
            name="initial_value", dtype=dtype)
      with ops.init_scope():
        self._handle = resource_variable_ops.eager_safe_variable_handle(
            initial_value=initial_value,
            shared_name=shared_name,
            name=name,
            graph_mode=self._in_graph_mode)
      self._shape = initial_value.shape
      self._unique_id = unique_id
      self._handle_name = handle_name + ":0"
      self._dtype = initial_value.dtype.base_dtype
      self._constraint = constraint
      assert initial_value is not None
      if self._in_graph_mode:
        with ops.init_scope():
          outer_graph = ops.get_default_graph()
        func_graph = ops.get_default_graph()
        function_placeholders = (
            func_graph.inputs + func_graph.internal_captures)
        placeholder_ops = set(
            [tensor.op for tensor in function_placeholders])
        lifted_initializer = lift_to_graph.lift_to_graph(
            [initial_value], outer_graph,
            disallowed_placeholders=placeholder_ops)[initial_value]
        with ops.init_scope():
          self._initial_value = lifted_initializer
          with ops.name_scope("IsInitialized"):
            self._is_initialized_op = (
                resource_variable_ops.var_is_initialized_op(self._handle))
          if initial_value is not None:
            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
              self._initializer_op = resource_variable_ops.assign_variable_op(
                  self._handle, lifted_initializer, name=n)
          with ops.name_scope("Read"), ops.colocate_with(self._handle):
            # Manually assign reads to the handle's device to avoid log
            # messages.
            with ops.device(self._handle.device):
              value = self._read_variable_op()
            self._graph_element = value
          ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
      else:
        if add_initializers_to is not None:
          add_initializers_to[self] = initial_value
        def assign_fn():
          with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
            resource_variable_ops.assign_variable_op(
                self._handle,
                initial_value,
                name=n)
            # Returning values to keep tf.cond happy.
          return ops.convert_to_tensor(1)
        def not_assign_fn():
          return ops.convert_to_tensor(0)
        # Note: this cond is always guaranteed to run because we're inside a
        # defun which will insert automatic control dependencies.
        control_flow_ops.cond(
            resource_variable_ops.var_is_initialized_op(self._handle),
            not_assign_fn, assign_fn)

    # After the handle has been created, set up a way to clean it up when
    # executing eagerly. We'll hold the only reference to the deleter, so that
    # when this object is garbage collected the deleter will be too. This
    # means ResourceVariables can be part of reference cycles without those
    # cycles being uncollectable.
    if not self._in_graph_mode:
      self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=self._handle, handle_device=self._handle.device)
    self._cached_shape_as_list = None
Example #18
  def prune(self, feeds, fetches):
    flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
    for f in flat_feeds:
      if not isinstance(f, ops.Tensor):
        raise ValueError("Feeds must be tensors.")

    # Ignoring all feeds that are captures allows prune to be called
    # using wrapped_func.inputs even when it uses variables
    internal_captures = self.graph.internal_captures
    flat_feeds = [f for f in flat_feeds
                  if f not in internal_captures]

    tensor_fetches = []
    operation_fetches = []
    for f in flat_fetches:
      if isinstance(f, ops.Tensor):
        tensor_fetches.append(f)
      elif isinstance(f, ops.Operation):
        operation_fetches.append(f)
      else:
        raise ValueError("Fetches must be tensors or operations.")
    for f in flat_feeds + flat_fetches:
      if f.graph is not self._func_graph:
        raise ValueError(
            "Can only prune function whose feeds and fetches "
            "are from this graph (%s). Tensor %s from graph %s" % (
                self._func_graph, f, f.graph))
    with self._func_graph.as_default():
      pruned_graph = func_graph.FuncGraph("pruned")
      with ops.control_dependencies(operation_fetches):
        if tensor_fetches:
          identity_fetches = array_ops.identity_n(tensor_fetches)
          sink_tensor = identity_fetches[0]
        else:
          identity_fetches = []
          sink_tensor = array_ops.zeros([])
    lift_map = lift_to_graph.lift_to_graph(
        [sink_tensor], pruned_graph, sources=flat_feeds + internal_captures)
    for original_fetch, identity_fetch in zip(
        tensor_fetches, identity_fetches):
      lift_map[original_fetch] = lift_map[identity_fetch]
    pruned_graph.outputs.extend(
        lift_map[x] for x in flat_fetches if isinstance(x, ops.Tensor))
    if not tensor_fetches:
      pruned_graph.outputs.append(lift_map[sink_tensor])
    for external_capture, internal_capture in self.graph.captures.items():
      pruned_graph.captures[external_capture] = lift_map[internal_capture]
    pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
    pruned_graph.inputs.extend(pruned_graph.captures.values())

    pruned_graph.variables = self.graph.variables

    def _structured_output_mapping(fetched):
      lifted = lift_map[fetched]
      if isinstance(lifted, ops.Operation):
        return None
      return lifted

    pruned_graph.structured_outputs = nest.map_structure(
        _structured_output_mapping, fetches)
    pruned_fn = WrappedFunction(
        pruned_graph, variable_holder=self._variable_holder)
    pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
    pruned_fn._arg_keywords = []  # pylint: disable=protected-access
    return pruned_fn
Example #19
  def __init__(self,
               initial_value=None,
               trainable=None,
               caching_device=None,
               name=None,
               dtype=None,
               constraint=None,
               add_initializers_to=None,
               lifted_initializer_graph=None,
               synchronization=None,
               aggregation=None,
               shape=None,
               **unused_kwargs):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, GradientTapes automatically watch uses of this
        Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
       a Tensor) or float32 will be used (if it is a Python object convertible
       to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      add_initializers_to: if not None and not in legacy graph mode, the
        initializer tensor will be added to this map in addition to adding the
        assignment to the function.
      lifted_initializer_graph: FuncGraph to try to lift initializers to.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the variable
        can be assigned with values of different shapes.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If called outside of a function definition.
    """
    if not ops.inside_function():
      # If we've been init_scope()d out of the function definition, there's
      # nothing to do here; we can't really do the capturing or conditional
      # logic.
      resource_variable_ops.ResourceVariable.__init__(
          self, initial_value=initial_value, trainable=trainable,
          caching_device=caching_device, name=name, dtype=dtype,
          constraint=constraint)
      return
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    if isinstance(initial_value, trackable.CheckpointInitialValue):
      self._maybe_initialize_trackable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

    with ops.name_scope(name, "Variable", []
                        if init_from_fn else [initial_value]) as name:
      with ops.name_scope("Initializer"), ops.device(None):
        initial_value = ops.convert_to_tensor(
            initial_value() if init_from_fn else initial_value,
            name="initial_value", dtype=dtype)
      assert initial_value is not None

      # Don't use `shape or initial_value.shape` since TensorShape has
      # overridden `__bool__`.
      if shape is None:
        shape = initial_value.shape

      # Use the constructor for UninitializedVariable to start.
      super(UnliftedInitializerVariable, self).__init__(
          trainable=trainable,
          caching_device=caching_device,
          name=name,
          shape=shape,
          dtype=initial_value.dtype,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation,
          extra_handle_data=initial_value,
          **unused_kwargs)

      if self._in_graph_mode:
        with ops.init_scope():
          outer_graph = ops.get_default_graph()
        func_graph = ops.get_default_graph()
        function_placeholders = (
            func_graph.inputs + func_graph.internal_captures)
        placeholder_ops = set(
            [tensor.op for tensor in function_placeholders])
        lifted_initializer = lift_to_graph.lift_to_graph(
            [initial_value], outer_graph,
            disallowed_placeholders=placeholder_ops)[initial_value]
        with ops.init_scope():
          self._initial_value = lifted_initializer
          with ops.name_scope("IsInitialized"):
            self._is_initialized_op = (
                resource_variable_ops.var_is_initialized_op(self._handle))
          if initial_value is not None:
            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
              self._initializer_op = resource_variable_ops.assign_variable_op(
                  self._handle, lifted_initializer, name=n)
      else:
        if add_initializers_to is not None:
          add_initializers_to[self] = initial_value
        def assign_fn():
          with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
            resource_variable_ops.assign_variable_op(
                self._handle,
                initial_value,
                name=n)
            # Returning values to keep tf.cond happy.
          return ops.convert_to_tensor(1)
        def not_assign_fn():
          return ops.convert_to_tensor(0)
        # Note: this cond is always guaranteed to run because we're inside a
        # defun which will insert automatic control dependencies.
        control_flow_ops.cond(
            resource_variable_ops.var_is_initialized_op(self._handle),
            not_assign_fn, assign_fn)
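
A hedged illustration of the `shape` argument documented above, using the public `tf.Variable` API (the internal class mirrors these semantics): passing `tf.TensorShape(None)` leaves the shape unspecified, so later assignments may change it.

import tensorflow as tf

v = tf.Variable(tf.zeros([2]), shape=tf.TensorShape(None))
v.assign(tf.ones([3, 4]))  # allowed: the declared shape is unspecified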
Example #20
    def prune(self, feeds, fetches, name=None, input_signature=None):
        """Extract a subgraph of this function's underlying graph.

    Wraps the subgraph in a new `WrappedFunction` object.

    Args:
      feeds: Input tensors to the subgraph to extract, as `Tensor` objects.
      fetches: Possibly-nested Python data structure containing information
        about outputs of the target subgraph. Each entry can either be a
        `Tensor` object (for data outputs), an `Operation` object (for control
        outputs), or a `TensorInfo` proto. Any additional shape/dtype
        information provided in a `TensorInfo` and not present in the original
        graph will be added to the returned subgraph.
      name: (optional) Name to give to the underlying `FuncGraph` of the
        returned object. If no name is provided, the graph's name will be
        `"pruned"`.
      input_signature: (optional) possibly-nested Python data structure
        containing `TensorSpec` objects, with which to populate the returned
        functions's `FuncGraph`'s `structured_input_signature` field.

    Returns:
      A new `WrappedFunction` object containing a copy of the portion of this
        object's graph that goes from `feeds` to `fetches`.
    """
        # TODO(b/129646028): Add support for CompositeTensors.
        name = name or "pruned"
        flat_feeds = nest.flatten(feeds, expand_composites=True)
        flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds]
        for f in flat_feeds:
            if not isinstance(f, ops.Tensor):
                raise ValueError("Feeds must be tensors.")

        # Ignoring all feeds that are captures allows prune to be called
        # using wrapped_func.inputs even when it uses variables
        internal_captures = object_identity.ObjectIdentitySet(
            self.graph.internal_captures)
        flat_feeds = [f for f in flat_feeds if f not in internal_captures]

        operation_fetches = []
        tensor_fetches = []
        tensor_infos = []

        def _fetch_preprocesing_callback(fetch):
            """Extract out lists of ops, tensors, and tensor type info.

      Turns TensorInfos into Tensors in the original `fetches` structure.
      Also extracts ops from `fetches`.

      Args:
        fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or
          string identifying a Tensor or Operation.

      Returns:
        `fetch` converted to a Tensor.
      """
            if isinstance(fetch, ops.Operation):
                operation_fetches.append(fetch)
                return fetch
            elif isinstance(fetch, meta_graph_pb2.TensorInfo):
                tensor_infos.append(fetch)
                decoded = _get_element_from_tensor_info(
                    fetch, self._func_graph)
                if (tensor_util.is_tensor(decoded) or isinstance(
                        decoded, composite_tensor.CompositeTensor)):
                    tensor_fetches.append(decoded)
                else:
                    operation_fetches.append(decoded)
                return decoded
            elif isinstance(fetch,
                            (ops.Tensor, composite_tensor.CompositeTensor)):
                tensor_fetches.append(fetch)
                return fetch
            else:
                graph_element = self.graph.as_graph_element(fetch)
                return _fetch_preprocesing_callback(graph_element)

        fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)

        # Expand composite tensors into their component dense Tensors.
        tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)

        for f in (flat_feeds + tensor_fetches + operation_fetches):
            if f.graph is not self._func_graph:
                raise ValueError(
                    "Can only prune function whose feeds and fetches "
                    "are from this graph (%s). Input %s is from graph %s" %
                    (self._func_graph, f, f.graph))
        with self._func_graph.as_default():
            pruned_graph = func_graph.FuncGraph(name)
        lift_map = lift_to_graph.lift_to_graph(
            operation_fetches + tensor_fetches,
            pruned_graph,
            sources=flat_feeds + self.graph.internal_captures)

        # Note that we add the component tensors of any composite tensors to the
        # returned function's outputs list; the list must contain these component
        # tensors, or the function's sparse outputs won't work properly.
        pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
        pruned_graph.control_outputs.extend(
            [lift_map[operation] for operation in operation_fetches])
        pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
        for external_capture, internal_capture in self.graph.captures:
            pruned_graph.add_capture(external_capture,
                                     lift_map[internal_capture])
        for ti in tensor_infos:
            if ti.WhichOneof("encoding") == "name":  # Dense tensors only
                t = pruned_graph.as_graph_element(ti.name)
                if tensor_util.is_tensor(t):
                    t.set_shape(tensor_shape.TensorShape(ti.tensor_shape))
        # pylint: disable=protected-access
        for f in self.graph._functions.values():
            pruned_graph._add_function(f)
        # pylint: enable=protected-access

        pruned_graph.variables = self.graph.variables

        def _structured_output_mapping(fetched):
            """callback for `nest.map_structure()`"""
            lifted = lift_map[fetched]
            if isinstance(lifted, ops.Operation):
                return None
            return lifted

        # expand_composites=True here causes composite tensors to be expanded
        # into their component dense Tensors, mapped to the new graph, and then
        # reconstituted into their original composite form.
        pruned_graph.structured_outputs = nest.map_structure(
            _structured_output_mapping, fetches, expand_composites=True)
        pruned_graph.structured_input_signature = input_signature
        pruned_fn = WrappedFunction(pruned_graph,
                                    variable_holder=self._variable_holder)
        pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
        # TODO(kathywu): Enable keyword arguments if an input signature is specified
        pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds]  # pylint: disable=protected-access
        return pruned_fn
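A minimal usage sketch for the prune() variant above, exercised through the public tf.compat.v1.wrap_function API (this sketch is not part of the original example; the op names "mul:0" and "add:0" are assumptions about TensorFlow's default operator naming):

import tensorflow as tf


def model(x):
  y = x * 2.0  # default op name "mul" (assumed)
  z = y + 1.0  # default op name "add" (assumed)
  return z


wrapped = tf.compat.v1.wrap_function(
    model, [tf.TensorSpec(shape=[], dtype=tf.float32)])

# Feeds and fetches may be given as tensor names; the preprocessing callback
# above resolves strings through graph.as_graph_element(). This cuts out the
# y -> z subgraph, so y is fed directly.
pruned = wrapped.prune(feeds="mul:0", fetches="add:0")
print(pruned(tf.constant(3.0)).numpy())  # 3.0 + 1.0 == 4.0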
Example #21
  def __init__(self,
               initial_value=None,
               trainable=None,
               caching_device=None,
               name=None,
               dtype=None,
               constraint=None,
               add_initializers_to=None,
               lifted_initializer_graph=None,
               synchronization=None,
               aggregation=None,
               **unused_kwargs):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, GradientTapes automatically watch uses of this
        Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object convertible
        to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      add_initializers_to: if not None and not in legacy graph mode, the
        initializer tensor will be added to this map in addition to adding the
        assignment to the function.
      lifted_initializer_graph: FuncGraph to try to lift initializers to.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If called outside of a function definition.
    """
    if not ops.inside_function():
      # If we've been init_scope()d out of the function definition, there is
      # nothing to do here; we can't really do the capturing or conditional
      # logic.
      resource_variable_ops.ResourceVariable.__init__(
          self, initial_value=initial_value, trainable=trainable,
          caching_device=caching_device, name=name, dtype=dtype,
          constraint=constraint)
      return
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    if isinstance(initial_value, trackable.CheckpointInitialValue):
      self._maybe_initialize_trackable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

    with ops.name_scope(name, "Variable", []
                        if init_from_fn else [initial_value]) as name:
      with ops.name_scope("Initializer"), ops.device(None):
        initial_value = ops.convert_to_tensor(
            initial_value() if init_from_fn else initial_value,
            name="initial_value", dtype=dtype)
      assert initial_value is not None

      # Use the constructor for UninitializedVariable to start.
      super(UnliftedInitializerVariable, self).__init__(
          trainable=trainable,
          caching_device=caching_device,
          name=name,
          shape=initial_value.shape,
          dtype=initial_value.dtype,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation,
          extra_handle_data=initial_value,
          **unused_kwargs)

      if self._in_graph_mode:
        with ops.init_scope():
          outer_graph = ops.get_default_graph()
        func_graph = ops.get_default_graph()
        function_placeholders = (
            func_graph.inputs + func_graph.internal_captures)
        placeholder_ops = set(
            [tensor.op for tensor in function_placeholders])
        lifted_initializer = lift_to_graph.lift_to_graph(
            [initial_value], outer_graph,
            disallowed_placeholders=placeholder_ops)[initial_value]
        with ops.init_scope():
          self._initial_value = lifted_initializer
          with ops.name_scope("IsInitialized"):
            self._is_initialized_op = (
                resource_variable_ops.var_is_initialized_op(self._handle))
          if initial_value is not None:
            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
              self._initializer_op = resource_variable_ops.assign_variable_op(
                  self._handle, lifted_initializer, name=n)
      else:
        if add_initializers_to is not None:
          add_initializers_to[self] = initial_value
        def assign_fn():
          with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
            resource_variable_ops.assign_variable_op(
                self._handle,
                initial_value,
                name=n)
            # Returning values to keep tf.cond happy.
          return ops.convert_to_tensor(1)
        def not_assign_fn():
          return ops.convert_to_tensor(0)
        # Note: this cond is always guaranteed to run because we're inside a
        # defun which will insert automatic control dependencies.
        control_flow_ops.cond(
            resource_variable_ops.var_is_initialized_op(self._handle),
            not_assign_fn, assign_fn)
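For context, this class is the variable type that tf.function (via def_function.py) substitutes in for variables created while tracing; in eager mode the cond at the end guards the one-time assignment of the initial value. A minimal sketch of that situation, assuming only public TF2 APIs (the Scaler class and its names are mine, not from the original source):

import tensorflow as tf


class Scaler(tf.Module):

  def __init__(self):
    super().__init__()
    self.v = None

  @tf.function
  def __call__(self, x):
    if self.v is None:
      # Created during the first trace of __call__: inside the FuncGraph this
      # becomes an unlifted-initializer variable, and the var_is_initialized
      # cond above ensures the assignment of 2.0 runs exactly once.
      self.v = tf.Variable(2.0)
    return self.v * x


s = Scaler()
print(s(tf.constant(3.0)).numpy())  # 6.0 -- the initializer runs on the first call.
print(s(tf.constant(3.0)).numpy())  # 6.0 -- later calls skip the assign.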
Example #22
  def prune(self, feeds, fetches, name=None, input_signature=None):
    """Extract a subgraph of this function's underlying graph.

    Wraps the subgraph in a new `WrappedFunction` object.

    Args:
      feeds: Input tensors to the subgraph to extract, as `Tensor` objects.
      fetches: Possibly-nested Python data structure containing information
        about outputs of the target subgraph. Each entry can either be a
        `Tensor` object (for data outputs), an `Operation` object (for control
        outputs), or a `TensorInfo` proto. Any additional shape/dtype
        information provided in a `TensorInfo` and not present in the original
        graph will be added to the returned subgraph.
      name: (optional) Name to give to the underlying `FuncGraph` of the
        returned object. If no name is provided, the graph's name will be
        `"pruned"`.
      input_signature: (optional) possibly-nested Python data structure
        containing `TensorSpec` objects, with which to populate the returned
        function's `FuncGraph`'s `structured_input_signature` field.

    Returns:
      A new `WrappedFunction` object containing a copy of the portion of this
        object's graph that goes from `feeds` to `fetches`.
    """
    # TODO(b/129646028): Add support for CompositeTensors.
    name = name or "pruned"
    feeds = nest.map_structure(self.graph.as_graph_element, feeds)
    flat_feeds = nest.flatten(feeds)
    for f in flat_feeds:
      if not isinstance(f, ops.Tensor):
        raise ValueError("Feeds must be tensors.")

    # Ignoring all feeds that are captures allows prune to be called
    # using wrapped_func.inputs even when it uses variables
    internal_captures = self.graph.internal_captures
    flat_feeds = [f for f in flat_feeds if f not in internal_captures]

    operation_fetches = []
    tensor_fetches = []
    tensor_infos = []

    def _fetch_preprocesing_callback(f):
      """Extract out lists of ops, tensors, and tensor type info.

      Turns TensorInfos into Tensors in the original fetches structure.

      Args:
        f: The fetch to preprocess: Tensor, TensorInfo, or Operation, or a
          string identifying a Tensor or Operation.

      Returns:
        `f` converted to a Tensor.
      """
      if isinstance(f, ops.Operation):
        operation_fetches.append(f)
        return f
      elif isinstance(f, meta_graph_pb2.TensorInfo):
        tensor_infos.append(f)
        decoded = _get_element_from_tensor_info(f, self._func_graph)
        if tensor_util.is_tensor(decoded):
          tensor_fetches.append(decoded)
        else:
          operation_fetches.append(decoded)
        return decoded
      elif isinstance(f, ops.Tensor):
        tensor_fetches.append(f)
        return f
      else:
        graph_element = self.graph.as_graph_element(f)
        return _fetch_preprocesing_callback(graph_element)

    fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)

    for f in flat_feeds + tensor_fetches + operation_fetches:
      if f.graph is not self._func_graph:
        raise ValueError("Can only prune function whose feeds and fetches "
                         "are from this graph (%s). Input %s is from graph %s" %
                         (self._func_graph, f, f.graph))
    with self._func_graph.as_default():
      pruned_graph = func_graph.FuncGraph(name)
    lift_map = lift_to_graph.lift_to_graph(
        operation_fetches + tensor_fetches,
        pruned_graph,
        sources=flat_feeds + internal_captures)
    pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
    pruned_graph.control_outputs.extend(
        [lift_map[operation] for operation in operation_fetches])
    for external_capture, internal_capture in self.graph.captures.items():
      pruned_graph.captures[external_capture] = lift_map[internal_capture]
    pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
    pruned_graph.inputs.extend(pruned_graph.captures.values())
    for ti in tensor_infos:
      if ti.WhichOneof("encoding") == "name":  # Dense tensors only
        t = pruned_graph.as_graph_element(ti.name)
        if tensor_util.is_tensor(t):
          t.set_shape(tensor_shape.TensorShape(ti.tensor_shape))
    # pylint: disable=protected-access
    for f in self.graph._functions.values():
      pruned_graph._add_function(f)
    # pylint: enable=protected-access

    pruned_graph.variables = self.graph.variables

    def _structured_output_mapping(fetched):
      lifted = lift_map[fetched]
      if isinstance(lifted, ops.Operation):
        return None
      return lifted

    pruned_graph.structured_outputs = nest.map_structure(
        _structured_output_mapping, fetches)
    pruned_graph.structured_input_signature = input_signature
    pruned_fn = WrappedFunction(
        pruned_graph, variable_holder=self._variable_holder)
    pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
    # TODO(kathywu): Enable keyword arguments if an input signature is specified
    pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds]  # pylint: disable=protected-access
    return pruned_fn
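A common application of this prune() method is wrapping a frozen TF1 GraphDef as a callable; the sketch below follows that pattern. graph_def, the input name "x:0", and the output name "Identity:0" are placeholders, not values from the original source.

import tensorflow as tf


def wrap_frozen_graph(graph_def, inputs, outputs):
  """Returns a pruned WrappedFunction mapping `inputs` to `outputs`."""
  def _imports_graph_def():
    tf.compat.v1.import_graph_def(graph_def, name="")

  wrapped = tf.compat.v1.wrap_function(_imports_graph_def, [])
  graph = wrapped.graph
  return wrapped.prune(
      feeds=tf.nest.map_structure(graph.as_graph_element, inputs),
      fetches=tf.nest.map_structure(graph.as_graph_element, outputs))


# Hypothetical usage with a previously loaded GraphDef:
# frozen_fn = wrap_frozen_graph(graph_def, inputs="x:0", outputs="Identity:0")
# y = frozen_fn(tf.constant(1.0))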
Example #23
 def initialize_variables():
   for v, init in initializer_map.items():
     v.assign(lift_to_graph.lift_to_graph(
         [init], ops.get_default_graph())[init])
Example #24
 def initialize_variables():
   for v, init in initializers:
     v.assign(
         lift_to_graph.lift_to_graph([init], ops.get_default_graph())[init],
         read_value=False)
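Both of the short snippets above rely on lift_to_graph() returning a map from elements of the source graph to their copies in the target graph, which is why the result is immediately indexed with init. A standalone sketch of that return value using the same internal modules as the examples (the graph names "source" and "target" are mine):

from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import func_graph

source = func_graph.FuncGraph("source")
with source.as_default():
  init = constant_op.constant(42.0, name="init")

target = func_graph.FuncGraph("target")
op_map = lift_to_graph.lift_to_graph([init], target)
lifted_init = op_map[init]  # The copy of `init` that now lives in `target`.
assert lifted_init.graph is target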