Example No. 1
    def test_shared_variable(self):
        x = gen_resource_variable_ops.var_handle_op(dtype=tf.float32,
                                                    shape=(1, 2),
                                                    shared_name="variable_1")
        gen_resource_variable_ops.assign_variable_op(x,
                                                     tf.constant([[1.0, 2.0]]))
        y = gen_resource_variable_ops.var_handle_op(dtype=tf.float32,
                                                    shape=(1, 2),
                                                    shared_name="variable_1")
        gen_resource_variable_ops.assign_variable_op(y,
                                                     tf.constant([[2.0, 3.0]]))
        read_x = gen_resource_variable_ops.read_variable_op(x,
                                                            dtype=tf.float32)
        read_y = gen_resource_variable_ops.read_variable_op(y,
                                                            dtype=tf.float32)
        self.assertTrue(tensor_equal(read_x, read_y))

        x = gen_resource_variable_ops.var_handle_op(
            dtype=tf.float32, shape=(1, 2), shared_name=context.shared_name())
        gen_resource_variable_ops.assign_variable_op(x,
                                                     tf.constant([[1.0, 2.0]]))
        y = gen_resource_variable_ops.var_handle_op(
            dtype=tf.float32, shape=(1, 2), shared_name=context.shared_name())
        gen_resource_variable_ops.assign_variable_op(y,
                                                     tf.constant([[2.0, 3.0]]))
        read_x = gen_resource_variable_ops.read_variable_op(x,
                                                            dtype=tf.float32)
        read_y = gen_resource_variable_ops.read_variable_op(y,
                                                            dtype=tf.float32)
        self.assertFalse(tensor_equal(read_x, read_y))
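Note: the first half of the test shows that two handles created with the same explicit shared_name resolve to the same underlying resource, while context.shared_name() yields a unique name per call in eager mode, so the second pair of handles stays independent. A minimal sketch of the sharing case using the public tf.raw_ops wrappers (assumes TF 2.x eager execution; the name "demo_var" is illustrative):

import tensorflow as tf

# Two handles created with the same shared_name point at the same resource.
h1 = tf.raw_ops.VarHandleOp(dtype=tf.float32, shape=(1, 2), shared_name="demo_var")
h2 = tf.raw_ops.VarHandleOp(dtype=tf.float32, shape=(1, 2), shared_name="demo_var")

tf.raw_ops.AssignVariableOp(resource=h1, value=tf.constant([[1.0, 2.0]]))
# Reading through the second handle returns the value written through the first.
print(tf.raw_ops.ReadVariableOp(resource=h2, dtype=tf.float32))  # [[1. 2.]]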
Example No. 2
def create_file_writer_v2(logdir,
                          max_queue=None,
                          flush_millis=None,
                          filename_suffix=None,
                          name=None):
    """Creates a summary file writer for the given log directory.

  Args:
    logdir: a string specifying the directory in which to write an event file.
    max_queue: the largest number of summaries to keep in a queue; will
     flush once the queue gets bigger than this. Defaults to 10.
    flush_millis: the largest interval between flushes. Defaults to 120,000.
    filename_suffix: optional suffix for the event file name. Defaults to `.v2`.
    name: a name for the op that creates the writer.

  Returns:
    A SummaryWriter object.
  """
    if logdir is None:
        raise ValueError("logdir cannot be None")
    inside_function = ops.inside_function()
    with ops.name_scope(name,
                        "create_file_writer") as scope, ops.device("cpu:0"):
        # Run init inside an init_scope() to hoist it out of tf.functions.
        with ops.init_scope():
            if context.executing_eagerly():
                _check_create_file_writer_args(inside_function,
                                               logdir=logdir,
                                               max_queue=max_queue,
                                               flush_millis=flush_millis,
                                               filename_suffix=filename_suffix)
            logdir = ops.convert_to_tensor(logdir, dtype=dtypes.string)
            if max_queue is None:
                max_queue = constant_op.constant(10)
            if flush_millis is None:
                flush_millis = constant_op.constant(2 * 60 * 1000)
            if filename_suffix is None:
                filename_suffix = constant_op.constant(".v2")
            # Prepend the PID and a process-local UID to the filename suffix to avoid
            # filename collisions within the machine (the filename already contains
            # the hostname to avoid cross-machine collisions).
            unique_prefix = constant_op.constant(".%s.%s" %
                                                 (os.getpid(), ops.uid()))
            filename_suffix = unique_prefix + filename_suffix
            # Use a unique shared_name to prevent resource sharing.
            if context.executing_eagerly():
                shared_name = context.shared_name()
            else:
                shared_name = ops.name_from_scope_name(scope)  # pylint: disable=protected-access
            return ResourceSummaryWriter(
                shared_name=shared_name,
                init_op_fn=functools.partial(
                    gen_summary_ops.create_summary_file_writer,
                    logdir=logdir,
                    max_queue=max_queue,
                    flush_millis=flush_millis,
                    filename_suffix=filename_suffix),
                name=name,
                v2=True,
                metadata={"logdir": logdir})
Example No. 3
def create_file_writer_v2(logdir,
                          max_queue=None,
                          flush_millis=None,
                          filename_suffix=None,
                          name=None):
  """Creates a summary file writer for the given log directory.

  Args:
    logdir: a string specifying the directory in which to write an event file.
    max_queue: the largest number of summaries to keep in a queue; will
     flush once the queue gets bigger than this. Defaults to 10.
    flush_millis: the largest interval between flushes. Defaults to 120,000.
    filename_suffix: optional suffix for the event file name. Defaults to `.v2`.
    name: a name for the op that creates the writer.

  Returns:
    A SummaryWriter object.
  """
  if logdir is None:
    raise ValueError("logdir cannot be None")
  inside_function = ops.inside_function()
  with ops.name_scope(name, "create_file_writer") as scope, ops.device("cpu:0"):
    # Run init inside an init_scope() to hoist it out of tf.functions.
    with ops.init_scope():
      if context.executing_eagerly():
        _check_create_file_writer_args(
            inside_function,
            logdir=logdir,
            max_queue=max_queue,
            flush_millis=flush_millis,
            filename_suffix=filename_suffix)
      logdir = ops.convert_to_tensor(logdir, dtype=dtypes.string)
      if max_queue is None:
        max_queue = constant_op.constant(10)
      if flush_millis is None:
        flush_millis = constant_op.constant(2 * 60 * 1000)
      if filename_suffix is None:
        filename_suffix = constant_op.constant(".v2")
      # Prepend the PID and a process-local UID to the filename suffix to avoid
      # filename collisions within the machine (the filename already contains
      # the hostname to avoid cross-machine collisions).
      unique_prefix = constant_op.constant(".%s.%s" % (os.getpid(), ops.uid()))
      filename_suffix = unique_prefix + filename_suffix
      # Use a unique shared_name to prevent resource sharing.
      if context.executing_eagerly():
        shared_name = context.shared_name()
      else:
        shared_name = ops.name_from_scope_name(scope)  # pylint: disable=protected-access
      return ResourceSummaryWriter(
          shared_name=shared_name,
          init_op_fn=functools.partial(
              gen_summary_ops.create_summary_file_writer,
              logdir=logdir,
              max_queue=max_queue,
              flush_millis=flush_millis,
              filename_suffix=filename_suffix),
          name=name,
          v2=True)
Example No. 4
 def create_fn():
   # Use unique shared_name to prevent resource sharing in eager mode, but
   # otherwise use a fixed shared_name to allow SavedModel TF 1.x loading.
   if context.executing_eagerly():
     shared_name = context.shared_name()
   else:
     shared_name = ops.name_from_scope_name(scope)  # pylint: disable=protected-access
   return gen_summary_ops.summary_writer(
       shared_name=shared_name, name=name)
Example No. 5
  def __init__(self, server_addresses):
    """Creates and starts the gRPC server.

    Args:
      server_addresses: A list of strings containing one or more server
        addresses.
    """
    if not tf.executing_eagerly():
      raise ValueError("Only eager mode is currently supported.")

    self._handle = gen_grpc_ops.grpc_server_resource_handle_op(
        shared_name=context.shared_name(None))
    # Delete the resource when this object is deleted.
    self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
        handle=self._handle, handle_device=context.context().device_name)
    gen_grpc_ops.create_grpc_server(self._handle, server_addresses)

    # References to tf.Variable's, etc. used in a tf.function to prevent them
    # from being deallocated.
    self._keep_alive = []
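Note: holding the only reference to an EagerResourceDeleter ties the server resource's lifetime to this Python object. A minimal sketch of the same handle-plus-deleter pattern applied to a plain variable handle (assumes the internal resource_variable_ops.EagerResourceDeleter API and TF 2.x eager execution):

import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.ops import resource_variable_ops

handle = tf.raw_ops.VarHandleOp(dtype=tf.float32, shape=(),
                                shared_name=context.shared_name(None))
# When `deleter` is garbage collected, the underlying resource is destroyed,
# so the resource lives exactly as long as the owning Python object.
deleter = resource_variable_ops.EagerResourceDeleter(
    handle=handle, handle_device=context.context().device_name)
tf.raw_ops.AssignVariableOp(resource=handle, value=tf.constant(1.0))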
Example No. 6
  def __init__(self, server_address):
    """Creates and starts the gRPC client.

    Args:
      server_address: A string containing the server address.
    """
    if not tf.executing_eagerly():
      raise ValueError("Only eager mode is currently supported.")

    self._handle = gen_grpc_ops.grpc_client_resource_handle_op(
        shared_name=context.shared_name(None))
    self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
        handle=self._handle, handle_device=context.context().device_name)
    method_signatures = gen_grpc_ops.create_grpc_client(self._handle,
                                                        server_address).numpy()
    m = service_pb2.MethodOutputSignature()
    v = struct_pb2.StructuredValue()
    for sig in method_signatures:
      assert m.ParseFromString(sig)
      decoder = nested_structure_coder.StructureCoder()
      assert v.ParseFromString(m.output_specs)
      decoded_output_specs = decoder.decode_proto(v)
      self._add_method(m.name, decoded_output_specs)
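Note: each method signature is a serialized MethodOutputSignature proto whose output_specs field is a StructuredValue; decode_proto turns it back into nested TensorSpecs. A minimal round-trip sketch (assumes the internal nested_structure_coder.StructureCoder API, which is what the TF version quoted here provides):

import tensorflow as tf
from tensorflow.python.saved_model import nested_structure_coder

coder = nested_structure_coder.StructureCoder()
spec = tf.TensorSpec(shape=[None, 3], dtype=tf.float32)
proto = coder.encode_structure(spec)   # StructuredValue proto
decoded = coder.decode_proto(proto)    # back to a TensorSpec
assert decoded == spec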
Example No. 7
    def __init__(
            self,  # pylint: disable=super-init-not-called
            initial_value=None,
            trainable=None,
            caching_device=None,
            name=None,
            dtype=None,
            constraint=None,
            add_initializers_to=None,
            lifted_initializer_graph=None,
            synchronization=None,
            aggregation=None,
            **unused_kwargs):
        """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, GradientTapes automatically watch uses of this
        Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
       a Tensor) or float32 will be used (if it is a Python object convertible
       to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      add_initializers_to: if not None and not in legacy graph mode, the
        initializer tensor will be added to this map in addition to adding the
        assignment to the function.
      lifted_initializer_graph: FuncGraph to try to lift initializers to.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If called outside of a function definition.
    """
        if not ops.inside_function():
            # If we've been init_scope()d out of the function definition nothing to do
            # here; we can't really do the capturing or conditional logic.
            resource_variable_ops.ResourceVariable.__init__(
                self,
                initial_value=initial_value,
                trainable=trainable,
                caching_device=caching_device,
                name=name,
                dtype=dtype,
                constraint=constraint)
            return
        with ops.init_scope():
            self._in_graph_mode = not context.executing_eagerly()
        if initial_value is None:
            raise ValueError("initial_value must be specified.")
        init_from_fn = callable(initial_value)

        if constraint is not None and not callable(constraint):
            raise ValueError("The `constraint` argument must be a callable.")

        if isinstance(initial_value, trackable.CheckpointInitialValue):
            self._maybe_initialize_trackable()
            self._update_uid = initial_value.checkpoint_position.restore_uid
            initial_value = initial_value.wrapped_value

        synchronization, aggregation, trainable = (
            variables.validate_synchronization_aggregation_trainable(
                synchronization, aggregation, trainable, name))
        self._trainable = trainable
        self._synchronization = synchronization
        self._aggregation = aggregation
        self._save_slice_info = None
        self._initial_value = None
        self._initializer_op = None
        self._is_initialized_op = None
        self._graph_element = None
        self._cached_value = None
        # Store the graph key so optimizers know how to only retrieve variables from
        # this graph. Guaranteed to be the same as the eager graph_key.
        self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
        with ops.name_scope(name, "Variable",
                            [] if init_from_fn else [initial_value]) as name:
            # pylint: disable=protected-access
            with ops.init_scope():
                handle_name = ops.name_from_scope_name(name)
                unique_id = "%s_%d" % (handle_name, ops.uid())
                shared_name = context.shared_name(unique_id)
            with ops.name_scope("Initializer"), ops.device(None):
                initial_value = ops.convert_to_tensor(
                    initial_value() if init_from_fn else initial_value,
                    name="initial_value",
                    dtype=dtype)
            with ops.init_scope():
                self._handle = resource_variable_ops.eager_safe_variable_handle(
                    initial_value=initial_value,
                    shared_name=shared_name,
                    name=name,
                    graph_mode=self._in_graph_mode)
            self._shape = initial_value.shape
            self._unique_id = unique_id
            self._handle_name = handle_name + ":0"
            self._dtype = initial_value.dtype.base_dtype
            self._constraint = constraint
            assert initial_value is not None
            if self._in_graph_mode:
                with ops.init_scope():
                    outer_graph = ops.get_default_graph()
                func_graph = ops.get_default_graph()
                function_placeholders = (func_graph.inputs +
                                         func_graph.internal_captures)
                placeholder_ops = set(
                    [tensor.op for tensor in function_placeholders])
                lifted_initializer = lift_to_graph.lift_to_graph(
                    [initial_value],
                    outer_graph,
                    disallowed_placeholders=placeholder_ops)[initial_value]
                with ops.init_scope():
                    self._initial_value = lifted_initializer
                    with ops.name_scope("IsInitialized"):
                        self._is_initialized_op = (
                            resource_variable_ops.var_is_initialized_op(
                                self._handle))
                    if initial_value is not None:
                        with ops.name_scope("Assign") as n, ops.colocate_with(
                                self._handle):
                            self._initializer_op = resource_variable_ops.assign_variable_op(
                                self._handle, lifted_initializer, name=n)
                    with ops.name_scope("Read"), ops.colocate_with(
                            self._handle):
                        # Manually assign reads to the handle's device to avoid log
                        # messages.
                        with ops.device(self._handle.device):
                            value = self._read_variable_op()
                        self._graph_element = value
                    ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
            else:
                if add_initializers_to is not None:
                    add_initializers_to[self] = initial_value

                def assign_fn():
                    with ops.name_scope("Assign") as n, ops.colocate_with(
                            self._handle):
                        resource_variable_ops.assign_variable_op(self._handle,
                                                                 initial_value,
                                                                 name=n)
                        # Returning values to keep tf.cond happy.
                    return ops.convert_to_tensor(1)

                def not_assign_fn():
                    return ops.convert_to_tensor(0)

                # Note: this cond is always guaranteed to run because we're inside a
                # defun which will insert automatic control dependencies.
                control_flow_ops.cond(
                    resource_variable_ops.var_is_initialized_op(self._handle),
                    not_assign_fn, assign_fn)

        # After the handle has been created, set up a way to clean it up when
        # executing eagerly. We'll hold the only reference to the deleter, so that
        # when this object is garbage collected the deleter will be too. This
        # means ResourceVariables can be part of reference cycles without those
        # cycles being uncollectable.
        if not self._in_graph_mode:
            self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
                handle=self._handle, handle_device=self._handle.device)
        self._cached_shape_as_list = None
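Note: this class is the machinery behind creating a tf.Variable inside a tf.function: the initializer is lifted out of the function graph, or, in eager mode, guarded by the cond so it runs only on the first trace. A minimal usage sketch of the behaviour it enables (assumes TF 2.x; the Counter class is illustrative):

import tensorflow as tf

class Counter(tf.Module):

    def __init__(self):
        self.v = None

    @tf.function
    def increment(self):
        if self.v is None:
            self.v = tf.Variable(0.0)  # created only on the first trace
        self.v.assign_add(1.0)
        return self.v

c = Counter()
print(c.increment().numpy())  # 1.0
print(c.increment().numpy())  # 2.0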
Example No. 8
    def __init__(self,
                 dataset,
                 devices,
                 max_buffer_size=1,
                 prefetch_buffer_size=1,
                 source_device="/cpu:0"):
        """Constructs a MultiDeviceIterator.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 1, then we set up a buffer on each device to
        prefetch into.
      source_device: The host device to place the `dataset` on.  In order to
        prevent deadlocks, if the prefetch_buffer_size is greater than the
        max_buffer_size, we set the max_buffer_size to prefetch_buffer_size.
    """
        options = dataset_ops.Options()
        options.experimental_distribute.num_devices = len(devices)
        dataset = dataset.with_options(options)
        self._dataset = dataset._apply_options()  # pylint: disable=protected-access
        self._experimental_slack = dataset.options().experimental_slack
        self._devices = devices
        self._source_device = source_device
        self._source_device_tensor = ops.convert_to_tensor(source_device)
        self._max_buffer_size = max_buffer_size
        self._prefetch_buffer_size = prefetch_buffer_size

        if self._prefetch_buffer_size > self._max_buffer_size:
            self._max_buffer_size = self._prefetch_buffer_size

        # Create the MultiDeviceIterator.
        with ops.device(self._source_device):
            # TODO(b/121378567): Get rid of this shared_name hack.
            shared_name = ""
            if context.executing_eagerly():
                shared_name = context.shared_name()
            self._multi_device_iterator_resource = (
                gen_dataset_ops.multi_device_iterator(
                    devices=self._devices,
                    shared_name=shared_name,
                    container="",
                    **self._dataset._flat_structure))  # pylint: disable=protected-access
            if context.executing_eagerly():
                # Delete the resource when this object is deleted
                self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
                    handle=self._multi_device_iterator_resource,
                    handle_device=self._source_device)

            # The incarnation ID is used to ensure consistency between the per-device
            # iterators and the multi-device iterator.
            self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
                self._dataset._variant_tensor,  # pylint: disable=protected-access
                self._multi_device_iterator_resource,
                max_buffer_size=self._max_buffer_size)

        self._prototype_device_datasets = []
        for i, device in enumerate(self._devices):
            with ops.device(device):
                ds = _PerDeviceGenerator(i,
                                         self._multi_device_iterator_resource,
                                         self._incarnation_id,
                                         self._source_device_tensor,
                                         self._dataset.element_spec)
                self._prototype_device_datasets.append(ds)

        # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
        # initialize the device side of the pipeline. This would allow the
        # MultiDeviceIterator to choose, for example, to move some transformations
        # into the device side from its input. It might be useful in rewriting.
        # Create the per device iterators.
        self._device_iterators = []
        for i, device in enumerate(self._devices):
            with ops.device(device):
                ds = _create_device_dataset(self._prototype_device_datasets[i],
                                            self._incarnation_id,
                                            self._prefetch_buffer_size,
                                            self._experimental_slack)
                if context.executing_eagerly():
                    self._device_iterators.append(
                        dataset_ops.make_one_shot_iterator(ds))
                else:
                    self._device_iterators.append(
                        dataset_ops.make_initializable_iterator(ds))

        if not context.executing_eagerly():
            device_iterator_initializers = [
                iterator.initializer for iterator in self._device_iterators
            ]
            self._initializer = control_flow_ops.group(
                *device_iterator_initializers)
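Note: in eager mode the iterator resource gets a unique shared_name from context.shared_name() and an EagerResourceDeleter, so it is cleaned up with the Python object; in graph mode the per-device iterators expose an initializer instead. A minimal usage sketch (assumes the internal multi_device_iterator_ops.MultiDeviceIterator class and eager execution; the device list is illustrative and would normally name accelerators):

import tensorflow as tf
from tensorflow.python.data.ops import multi_device_iterator_ops

dataset = tf.data.Dataset.range(8)
mdi = multi_device_iterator_ops.MultiDeviceIterator(
    dataset, devices=["/cpu:0"], max_buffer_size=2, prefetch_buffer_size=2)
# get_next() returns one element per device, in the order devices were listed.
elements = mdi.get_next()
print([int(e) for e in elements])  # [0]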
Example No. 9
  def __init__(self,  # pylint: disable=super-init-not-called
               initial_value=None,
               trainable=None,
               caching_device=None,
               name=None,
               dtype=None,
               constraint=None,
               add_initializers_to=None,
               lifted_initializer_graph=None,
               **unused_kwargs):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, GradientTapes automatically watch uses of this
        Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
       a Tensor) or float32 will be used (if it is a Python object convertible
       to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      add_initializers_to: if not None and not in legacy graph mode, the
        initializer tensor will be added to this map in addition to adding the
        assignment to the function.
      lifted_initializer_graph: FuncGraph to try to lift initializers to.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If called outside of a function definition.
    """
    if not ops.inside_function():
      # If we've been init_scope()d out of the function definition nothing to do
      # here; we can't really do the capturing or conditional logic.
      resource_variable_ops.ResourceVariable.__init__(
          self, initial_value=initial_value, trainable=trainable,
          caching_device=caching_device, name=name, dtype=dtype,
          constraint=constraint)
      return
    with ops.init_scope():
      self._in_graph_mode = not context.executing_eagerly()
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    if isinstance(initial_value, trackable.CheckpointInitialValue):
      self._maybe_initialize_trackable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

    if trainable is None:
      trainable = True
    self._trainable = trainable
    self._save_slice_info = None
    self._initial_value = None
    self._initializer_op = None
    self._is_initialized_op = None
    self._graph_element = None
    self._cached_value = None
    # Store the graph key so optimizers know how to only retrieve variables from
    # this graph. Guaranteed to be the same as the eager graph_key.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    with ops.name_scope(name, "Variable", []
                        if init_from_fn else [initial_value]) as name:
      # pylint: disable=protected-access
      with ops.init_scope():
        handle_name = ops._name_from_scope_name(name)
        unique_id = "%s_%d" % (handle_name, ops.uid())
        shared_name = context.shared_name(unique_id)
      with ops.name_scope("Initializer"), ops.device(None):
        initial_value = ops.convert_to_tensor(
            initial_value() if init_from_fn else initial_value,
            name="initial_value", dtype=dtype)
      with ops.init_scope():
        self._handle = resource_variable_ops.eager_safe_variable_handle(
            initial_value=initial_value,
            shared_name=shared_name,
            name=name,
            graph_mode=self._in_graph_mode)
      self._shape = initial_value.shape
      self._unique_id = unique_id
      self._handle_name = handle_name + ":0"
      self._dtype = initial_value.dtype.base_dtype
      self._constraint = constraint
      assert initial_value is not None
      if self._in_graph_mode:
        with ops.init_scope():
          outer_graph = ops.get_default_graph()
        func_graph = ops.get_default_graph()
        function_placeholders = (
            func_graph.inputs + func_graph.internal_captures)
        placeholder_ops = set(
            [tensor.op for tensor in function_placeholders])
        lifted_initializer = lift_to_graph.lift_to_graph(
            [initial_value], outer_graph,
            disallowed_placeholders=placeholder_ops)[initial_value]
        with ops.init_scope():
          self._initial_value = lifted_initializer
          with ops.name_scope("IsInitialized"):
            self._is_initialized_op = (
                resource_variable_ops.var_is_initialized_op(self._handle))
          if initial_value is not None:
            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
              self._initializer_op = resource_variable_ops.assign_variable_op(
                  self._handle, lifted_initializer, name=n)
          with ops.name_scope("Read"), ops.colocate_with(self._handle):
            # Manually assign reads to the handle's device to avoid log
            # messages.
            with ops.device(self._handle.device):
              value = self._read_variable_op()
            self._graph_element = value
          ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
      else:
        if add_initializers_to is not None:
          add_initializers_to[self] = initial_value
        def assign_fn():
          with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
            resource_variable_ops.assign_variable_op(
                self._handle,
                initial_value,
                name=n)
            # Returning values to keep tf.cond happy.
          return ops.convert_to_tensor(1)
        def not_assign_fn():
          return ops.convert_to_tensor(0)
        # Note: this cond is always guaranteed to run because we're inside a
        # defun which will insert automatic control dependencies.
        control_flow_ops.cond(
            resource_variable_ops.var_is_initialized_op(self._handle),
            not_assign_fn, assign_fn)

    # After the handle has been created, set up a way to clean it up when
    # executing eagerly. We'll hold the only reference to the deleter, so that
    # when this object is garbage collected the deleter will be too. This
    # means ResourceVariables can be part of reference cycles without those
    # cycles being uncollectable.
    if not self._in_graph_mode:
      self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=self._handle, handle_device=self._handle.device)
    self._cached_shape_as_list = None
Example No. 10
    def __init__(self,
                 dataset,
                 devices,
                 max_buffer_size=1,
                 prefetch_buffer_size=1,
                 source_device="/cpu:0"):
        """Constructs a MultiDeviceIterator.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 1, then we set up a buffer on each device
        to prefetch into.
      source_device: The host device to place the `dataset` on.

    Raises:
      RuntimeError: If run in Eager mode.
    """
        self._dataset = dataset._apply_options()  # pylint: disable=protected-access
        self._devices = devices
        self._source_device = source_device
        self._source_device_tensor = ops.convert_to_tensor(source_device)

        # Create the MultiDeviceIterator.
        with ops.device(self._source_device):
            # TODO(b/121378567): Get rid of this shared_name hack.
            shared_name = ""
            if context.executing_eagerly():
                shared_name = context.shared_name()
            self._multi_device_iterator_resource = (
                gen_dataset_ops.multi_device_iterator(
                    devices=self._devices,
                    shared_name=shared_name,
                    container="",
                    **dataset_ops.flat_structure(dataset)))
            if context.executing_eagerly():
                # Delete the resource when this object is deleted
                self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
                    handle=self._multi_device_iterator_resource,
                    handle_device=self._source_device)

            # The incarnation ID is used to ensure consistency between the per-device
            # iterators and the multi-device iterator.
            self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
                self._dataset._variant_tensor,  # pylint: disable=protected-access
                self._multi_device_iterator_resource,
                max_buffer_size=max_buffer_size)

        # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
        # initialize the device side of the pipeline. This would allow the
        # MultiDeviceIterator to choose, for example, to move some transformations
        # into the device side from its input. It might be useful in rewriting.
        # Create the per device iterators.
        self._device_iterators = []
        for i, device in enumerate(self._devices):
            with ops.device(device):
                ds = _PerDeviceGenerator(i,
                                         self._multi_device_iterator_resource,
                                         self._incarnation_id,
                                         self._source_device_tensor,
                                         dataset._element_structure)  # pylint: disable=protected-access
                if prefetch_buffer_size > 0:
                    ds = ds.prefetch(prefetch_buffer_size)
                # TODO(jsimsa): Enable auto-tuning and optimizations when supported for
                # non-CPU devices.
                options = dataset_ops.Options()
                options.experimental_autotune = False
                options.experimental_optimization.apply_default_optimizations = False
                ds = ds.with_options(options)
                if context.executing_eagerly():
                    self._device_iterators.append(
                        dataset_ops.make_one_shot_iterator(ds))
                else:
                    self._device_iterators.append(
                        dataset_ops.make_initializable_iterator(ds))

        if not context.executing_eagerly():
            device_iterator_initializers = [
                iterator.initializer for iterator in self._device_iterators
            ]
            self._initializer = control_flow_ops.group(
                *device_iterator_initializers)
Example No. 11
  def __init__(self,
               dataset,
               devices,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0"):
    """Constructs a MultiDeviceIterator.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 1, then we set up a buffer on each device
        to prefetch into.
      source_device: The host device to place the `dataset` on.

      In order to prevent deadlocks, if the prefetch_buffer_size is greater
      than the max_buffer_size, we set the max_buffer_size to
      prefetch_buffer_size.

    Raises:
      RuntimeError: If run in Eager mode.
    """
    self._dataset = dataset._apply_options()  # pylint: disable=protected-access
    self._devices = devices
    self._source_device = source_device
    self._source_device_tensor = ops.convert_to_tensor(source_device)
    self._max_buffer_size = max_buffer_size
    self._prefetch_buffer_size = prefetch_buffer_size

    if self._prefetch_buffer_size > self._max_buffer_size:
      self._max_buffer_size = self._prefetch_buffer_size

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      # TODO(b/121378567): Get rid of this shared_name hack.
      shared_name = ""
      if context.executing_eagerly():
        shared_name = context.shared_name()
      self._multi_device_iterator_resource = (
          gen_dataset_ops.multi_device_iterator(
              devices=self._devices,
              shared_name=shared_name,
              container="",
              **dataset_ops.flat_structure(self._dataset)))
      if context.executing_eagerly():
        # Delete the resource when this object is deleted
        self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
            handle=self._multi_device_iterator_resource,
            handle_device=self._source_device)

      # The incarnation ID is used to ensure consistency between the per-device
      # iterators and the multi-device iterator.
      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          self._dataset._variant_tensor,  # pylint: disable=protected-access
          self._multi_device_iterator_resource,
          max_buffer_size=self._max_buffer_size)

    self._prototype_device_datasets = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _PerDeviceGenerator(
            i, self._multi_device_iterator_resource, self._incarnation_id,
            self._source_device_tensor, self._dataset._element_structure)  # pylint: disable=protected-access
        self._prototype_device_datasets.append(ds)

    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
    # initialize the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # into the device side from its input. It might be useful in rewriting.
    # Create the per device iterators.
    self._device_iterators = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = self._create_device_dataset(i)
        if context.executing_eagerly():
          self._device_iterators.append(dataset_ops.make_one_shot_iterator(ds))
        else:
          self._device_iterators.append(
              dataset_ops.make_initializable_iterator(ds))

    if not context.executing_eagerly():
      device_iterator_initializers = [
          iterator.initializer for iterator in self._device_iterators
      ]
      self._initializer = control_flow_ops.group(*device_iterator_initializers)
Example No. 12
    def __init__(  # pylint: disable=super-init-not-called
            self,
            trainable=None,
            caching_device=None,
            name=None,
            shape=None,
            dtype=None,
            constraint=None,
            synchronization=None,
            aggregation=None,
            extra_handle_data=None,
            distribute_strategy=None,
            **unused_kwargs):
        """Creates the variable handle.

    Args:
      trainable: If `True`, GradientTapes automatically watch uses of this
        Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      shape: The variable's shape.
      dtype: The variable's dtype.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      extra_handle_data: Optional, another resource handle or Tensor with handle
        data to merge with `shape` and `dtype`.
      distribute_strategy: The tf.distribute.Strategy this variable is being
        created inside of.
    """
        with ops.init_scope():
            self._in_graph_mode = not context.executing_eagerly()
            if not self._in_graph_mode:
                raise ValueError(
                    "Can't create UninitializedValue in non-graph mode")
            with ops.name_scope(name, "Variable") as name:
                handle_name = ops.name_from_scope_name(name)
                if self._in_graph_mode:
                    shared_name = handle_name
                    unique_id = shared_name
                else:
                    unique_id = "%s_%d" % (handle_name, ops.uid())
                    shared_name = context.shared_name(unique_id)
                handle = variable_handle_from_shape_and_dtype(
                    shape=shape,
                    dtype=dtype,
                    shared_name=shared_name,
                    name=name,
                    graph_mode=self._in_graph_mode,
                    extra_handle_data=extra_handle_data)
                if not context.executing_eagerly():
                    with ops.name_scope("Read"):
                        # Manually assign reads to the handle's device to avoid log
                        # messages.
                        with ops.device(handle.device):
                            value = gen_resource_variable_ops.read_variable_op(
                                handle, dtype)
                            _maybe_set_handle_data(dtype, handle, value)
                        graph_element = value
                    ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
                    # Do *not* add to TRAINABLE_VARIABLES here, even if self._trainable,
                    # because retraining or frozen use of imported SavedModels is
                    # controlled at higher levels of model building.
                else:
                    graph_element = None
        super(UninitializedVariable,
              self).__init__(distribute_strategy=distribute_strategy,
                             shape=shape,
                             dtype=dtype,
                             unique_id=unique_id,
                             handle_name=handle_name,
                             constraint=constraint,
                             handle=handle,
                             graph_element=graph_element,
                             trainable=trainable,
                             synchronization=synchronization,
                             aggregation=aggregation)
    def _init_from_args(
        self,
        initial_value=None,
        trainable=None,
        collections=None,
        caching_device=None,
        name=None,
        dtype=None,
        constraint=None,
        synchronization=None,
        aggregation=None,
        distribute_strategy=None,
        shape=None,
    ):
        """Creates a variable.

        Args:
          initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
            which is the initial value for the Variable. The initial value must have
            a shape specified unless `validate_shape` is set to False. Can also be a
            callable with no argument that returns the initial value when called.
            (Note that initializer functions from init_ops.py must first be bound
             to a shape before being used here.)
          trainable: If `True`, the default, also adds the variable to the graph
            collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
            the default list of variables to use by the `Optimizer` classes.
            Defaults to `True`, unless `synchronization` is set to `ON_READ`, in
            which case it defaults to `False`.
          collections: List of graph collections keys. The new variable is added to
            these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
          caching_device: Optional device string or function describing where the
            Variable should be cached for reading.  Defaults to the Variable's
            device.  If not `None`, caches on another device.  Typical use is to
            cache on the device where the Ops using the Variable reside, to
            deduplicate copying through `Switch` and other conditional statements.
          name: Optional name for the variable. Defaults to `'Variable'` and gets
            uniquified automatically.
          dtype: If set, initial_value will be converted to the given type.
            If None, either the datatype will be kept (if initial_value is
           a Tensor) or float32 will be used (if it is a Python object convertible
           to a Tensor).
          constraint: An optional projection function to be applied to the variable
            after being updated by an `Optimizer` (e.g. used to implement norm
            constraints or value constraints for layer weights). The function must
            take as input the unprojected Tensor representing the value of the
            variable and return the Tensor for the projected value
            (which must have the same shape). Constraints are not safe to
            use when doing asynchronous distributed training.
          synchronization: Indicates when a distributed variable will be
            aggregated. Accepted values are constants defined in the class
            `tf.VariableSynchronization`. By default the synchronization is set to
            `AUTO` and the current `DistributionStrategy` chooses
            when to synchronize.
          aggregation: Indicates how a distributed variable will be aggregated.
            Accepted values are constants defined in the class
            `tf.VariableAggregation`.
          distribute_strategy: DistributionStrategy under which this variable
            was created.
          shape: (optional) The shape of this variable. If None, the shape of
            `initial_value` will be used. When setting this argument to
            `tf.TensorShape(None)` (representing an unspecified shape), the variable
            can be assigned with values of different shapes.

        Raises:
          ValueError: If the initial value is not specified, or does not have a
            shape and `validate_shape` is `True`.

        @compatibility(eager)
        When Eager Execution is enabled, variables are never added to collections.
        It is not implicitly added to the `GLOBAL_VARIABLES` or
        `TRAINABLE_VARIABLES` collections, and the `collections` argument is
        ignored.
        @end_compatibility
        """
        (
            synchronization,
            aggregation,
            trainable,
        ) = variables.validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name)
        if initial_value is None:
            raise ValueError("initial_value must be specified.")
        init_from_fn = callable(initial_value)

        if (isinstance(initial_value, ops.Tensor)
                and hasattr(initial_value, "graph")
                and initial_value.graph.building_function):
            raise ValueError(
                "Tensor-typed variable initializers must either be "
                "wrapped in an init_scope or callable "
                "(e.g., `tf.Variable(lambda : "
                "tf.truncated_normal([10, 40]))`) when building "
                "functions. Please file a feature request if this "
                "restriction inconveniences you.")

        if collections is None:
            collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        if not isinstance(collections, (list, tuple, set)):
            raise ValueError(
                "collections argument to Variable constructor must be a list, tuple, "
                "or set. Got %s of type %s" % (collections, type(collections)))
        if constraint is not None and not callable(constraint):
            raise ValueError("The `constraint` argument must be a callable.")

        if isinstance(initial_value, trackable.CheckpointInitialValue):
            self._maybe_initialize_trackable()
            self._update_uid = initial_value.checkpoint_position.restore_uid
            initial_value = initial_value.wrapped_value

        if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
            collections = list(collections) + [
                ops.GraphKeys.TRAINABLE_VARIABLES
            ]
        with ops.init_scope():
            self._in_graph_mode = not context.executing_eagerly()
            with ops.name_scope(
                    name, "TrainableWrapper",
                [] if init_from_fn else [initial_value]) as name:
                # pylint: disable=protected-access
                handle_name = ops.name_from_scope_name(name)
                handle_name = handle_name or "TrainableWrapperHandle"
                if self._in_graph_mode:
                    shared_name = handle_name
                    unique_id = shared_name
                else:
                    # When in eager mode use a uid for the shared_name, to prevent
                    # accidental sharing.
                    unique_id = "%s_%d" % (handle_name, ops.uid())
                    tf_major_version, _, _ = get_tf_version_triple()
                    if int(tf_major_version) >= 2:
                        shared_name = None  # Never shared
                    else:
                        shared_name = context.shared_name()
                # Use attr_scope and device(None) to simulate the behavior of
                # colocate_with when the variable we want to colocate with doesn't
                # yet exist.
                device_context_manager = (ops.device if self._in_graph_mode
                                          else ops.NullContextmanager)
                attr = attr_value_pb2.AttrValue(
                    list=attr_value_pb2.AttrValue.ListValue(
                        s=[compat.as_bytes("loc:@%s" % handle_name)]))
                with ops.get_default_graph()._attr_scope({"_class": attr}):
                    with ops.name_scope("Initializer"), device_context_manager(
                            None):
                        initial_value = ops.convert_to_tensor(
                            initial_value() if init_from_fn else initial_value,
                            name="initial_value",
                            dtype=dtype,
                        )
                    if shape is None:
                        shape = initial_value.shape
                    handle = resource_variable_ops.eager_safe_variable_handle(
                        initial_value=initial_value,
                        shape=None,  # shape,
                        shared_name=shared_name,
                        name=name,
                        graph_mode=self._in_graph_mode,
                    )
                # pylint: disable=protected-access
                if (self._in_graph_mode and initial_value is not None
                        and initial_value.op._get_control_flow_context()
                        is not None):
                    raise ValueError(
                        "Initializer for variable %s is from inside a control-flow "
                        "construct, such as a loop or conditional. When creating a "
                        "variable inside a loop or conditional, use a lambda as the "
                        "initializer." % name)
                # pylint: enable=protected-access
                dtype = initial_value.dtype.base_dtype

                if self._in_graph_mode:
                    with ops.name_scope("IsInitialized"):
                        is_initialized_op = (gen_resource_variable_ops.
                                             var_is_initialized_op(handle))
                    if initial_value is not None:
                        # pylint: disable=g-backslash-continuation
                        with ops.name_scope("Assign") as n, ops.colocate_with(
                                None, ignore_existing=True), ops.device(
                                    handle.device):
                            # pylint: disable=protected-access
                            initializer_op = gen_resource_variable_ops.assign_variable_op(
                                handle,
                                variables.
                                _try_guard_against_uninitialized_dependencies(
                                    name, initial_value),
                                name=n,
                            )
                            # pylint: enable=protected-access
                        # pylint: enable=g-backslash-continuation
                    with ops.name_scope("Read"):
                        # Manually assign reads to the handle's device to avoid log
                        # messages.
                        with ops.device(handle.device):
                            with ops.control_dependencies([
                                    gen_resource_variable_ops.
                                    assign_variable_op(
                                        handle,
                                        self.prefetch_values(),
                                        name="AssignBeforeInitRead",
                                    )
                            ]):
                                value = gen_resource_variable_ops.read_variable_op(
                                    handle, dtype)
                        graph_element = value
                        if caching_device is not None:
                            # Variables may be created in a tf.device() or ops.colocate_with()
                            # context. At the same time, users would expect caching device to
                            # be independent of this context, and/or would not expect the
                            # current device context to be merged with the caching device
                            # spec.  Therefore we reset the colocation stack before creating
                            # the cached value. Note that resetting the colocation stack will
                            # also reset the device stack.
                            with ops.colocate_with(None, ignore_existing=True):
                                with ops.device(caching_device):
                                    cached_value = array_ops.identity(value)
                        else:
                            cached_value = None
                else:
                    gen_resource_variable_ops.assign_variable_op(
                        handle, initial_value)
                    is_initialized_op = None
                    initializer_op = None
                    graph_element = None
                    if caching_device:
                        with ops.device(caching_device):
                            with ops.control_dependencies([
                                    gen_resource_variable_ops.
                                    assign_variable_op(
                                        handle,
                                        self.prefetch_values(),
                                        name="AssignBeforeInitRead",
                                    )
                            ]):
                                cached_value = (
                                    gen_resource_variable_ops.read_variable_op(
                                        handle, dtype))
                    else:
                        cached_value = None
                if not context.executing_eagerly():
                    # Eager variables are only added to collections if they are part of an
                    # eager variable store (otherwise in an interactive session they would
                    # hog memory and cause OOM). This is done in ops/variable_scope.py.
                    ops.add_to_collections(collections, self)
                elif ops.GraphKeys.GLOBAL_STEP in collections:
                    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
            initial_value = initial_value if self._in_graph_mode else None
            super(resource_variable_ops.ResourceVariable, self).__init__(
                trainable=trainable,
                shape=shape,
                dtype=dtype,
                handle=handle,
                synchronization=synchronization,
                constraint=constraint,
                aggregation=aggregation,
                distribute_strategy=distribute_strategy,
                name=name,
                unique_id=unique_id,
                handle_name=handle_name,
                graph_element=graph_element,
                initial_value=initial_value,
                initializer_op=initializer_op,
                is_initialized_op=is_initialized_op,
                cached_value=cached_value,
            )