Example #1
 def _real_mirrored_creator(devices, *args, **kwargs):  # pylint: disable=g-missing-docstring
   index = {}
   for i, d in enumerate(devices):
     with ops.device(d):
       if i > 0:
         # Give replicas meaningful distinct names:
         var0name = index[devices[0]].name.split(":")[0]
         # We append a / to variable names created on towers with id > 0 to
         # ensure that we ignore the name scope and instead use the given
         # name as the absolute name of the variable.
         kwargs["name"] = "%s/replica_%d/" % (var0name, i)
         # Initialize replicas with the same value:
         def initial_value_fn(device=d):
           if context.executing_eagerly():
             init_value = index[devices[0]].value()
             return array_ops.identity(init_value)
           else:
             with ops.device(device):
               init_value = index[devices[0]].initial_value
               return array_ops.identity(init_value)
         kwargs["initial_value"] = initial_value_fn
       with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
         # Don't record operations (e.g. other variable reads) during
         # variable creation.
         with tape.stop_recording():
           v = next_creator(*args, **kwargs)
       assert not isinstance(v, values.DistributedVariable)
       index[d] = v
   return index
Example #2
  def _create_variable(self, next_creator, *args, **kwargs):
    """Create a mirrored variable. See `DistributionStrategy.scope`."""
    # Figure out what collections this variable should be added to.
    # We'll add the MirroredVariable to those collections instead.
    collections = kwargs.pop("collections", None)
    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    kwargs["collections"] = []

    colocate_with = kwargs.pop("colocate_with", None)
    devices = self._get_devices_from(colocate_with)

    tower_local = kwargs.pop("tower_local_reduce_method", None)
    if tower_local is not None:
      kwargs["trainable"] = False

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
      index = {}
      for i, d in enumerate(devices):
        with ops.device(d):
          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = index[devices[0]].name.split(":")[0]
            kwargs["name"] = "%s/replica_%d" % (var0name, i)
            # Initialize replicas with the same value:
            if context.executing_eagerly():
              initial_value = index[devices[0]].value()
            else:
              initial_value = index[devices[0]].initial_value
            kwargs["initial_value"] = array_ops.identity(initial_value)
          with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
            v = next_creator(*args, **kwargs)
          assert not isinstance(v, values.DistributedVariable)
          index[d] = v

      if tower_local is None:
        result = values.MirroredVariable(index, index[devices[0]])
      else:
        result = values.TowerLocalVariable(
            index, index[devices[0]], tower_local)

    if not context.executing_eagerly():
      g = ops.get_default_graph()
      # If "trainable" is True, next_creator() will add the member variables
      # to the TRAINABLE_VARIABLES collection, so we manually remove
      # them and replace with the MirroredVariable. We can't set
      # "trainable" to False for next_creator() since that causes functions
      # like implicit_gradients to skip those variables.
      if kwargs.get("trainable", True):
        collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
        l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
        for v in index.values():
          l.remove(v)
      g.add_to_collections(collections, result)
    return result
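A note on how creators like `_create_variable(self, next_creator, *args, **kwargs)` are wired in: they follow the same protocol as `tf.variable_creator_scope` creators, where each creator receives the next creator in the chain and may adjust the keyword arguments before delegating. A minimal sketch with a hypothetical logging creator (not part of the code above), assuming TensorFlow 2.x:

import tensorflow as tf

def logging_creator(next_creator, **kwargs):
  # Inspect or tweak the arguments, then delegate to the next creator in the
  # chain, just as _create_variable above delegates once per device.
  print("creating variable:", kwargs.get("name"))
  return next_creator(**kwargs)

with tf.variable_creator_scope(logging_creator):
  w = tf.Variable(tf.zeros([3, 3]), name="w")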
Example #3
def _create_tpu_mirrored_variable(  # pylint: disable=missing-docstring
    strategy, device_map, logical_device, real_mirrored_creator,
    *args, **kwargs):
  # Figure out what collections this variable should be added to.
  # We'll add the TPUMirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  # TODO(jhseu): Should we have different behavior for different
  # synchronization settings?

  # Get aggregation value
  # TODO(jhseu): Support aggregation in a replica context.
  aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
  if aggregation not in [
      vs.VariableAggregation.NONE,
      vs.VariableAggregation.SUM,
      vs.VariableAggregation.MEAN,
      vs.VariableAggregation.ONLY_FIRST_REPLICA,
  ]:
    raise ValueError("Invalid variable aggregation mode: {} for variable: {}"
                     .format(aggregation, kwargs["name"]))

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    devices = device_map.logical_to_actual_devices(logical_device)
    value_list = real_mirrored_creator(devices, *args, **kwargs)
    result = values.TPUMirroredVariable(
        strategy, device_map, value_list, aggregation,
        logical_device=logical_device)

  if not (context.executing_eagerly() or ops.inside_function()):
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for v in value_list:
        l.remove(v)
    g.add_to_collections(var_collections, result)
  return result
Example #4
def compute_gradients(model, images, labels, num_replicas=1):
  with tf.GradientTape() as grad_tape:
    logits = model(images, training=True)
    loss = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=labels)
    tf.contrib.summary.scalar(name='loss', tensor=loss)
    if num_replicas != 1:
      loss /= num_replicas

  # TODO(b/110991947): We can mistakenly trace the gradient call in
  # multi-threaded environment. Explicitly disable recording until
  # this is fixed.
  with tape.stop_recording():
    grads = grad_tape.gradient(loss, model.variables)
  return grads
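The `tape.stop_recording()` used above is the internal `tensorflow.python.eager.tape` helper; the same pattern is exposed on the public API as the `stop_recording()` method of `tf.GradientTape`. A minimal sketch of keeping side computations off the tape, assuming TensorFlow 2.x and a placeholder Keras `model`:

import tensorflow as tf

# Placeholder model; any callable Keras model would do.
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])

def compute_gradients(images, labels):
  with tf.GradientTape() as grad_tape:
    logits = model(images, training=True)
    loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits, from_logits=True))
    # Work done inside stop_recording() is not traced onto grad_tape, so
    # logging and metrics do not end up in the gradient computation.
    with grad_tape.stop_recording():
      tf.print("current loss:", loss)
  return grad_tape.gradient(loss, model.trainable_variables)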
Example #5
  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    if context.in_graph_mode():
      if kwargs:
        raise ValueError(
            "custom_gradient in graph mode doesn't support keyword arguments.")
      name = "CustomGradient-%s" % tf_ops.uid()
      args = [tf_ops.convert_to_tensor(x) for x in args]
      result, grad_fn = f(*args)
      flat_result = nest.flatten(result)
      all_tensors = flat_result + args

      @tf_ops.RegisterGradient(name)
      def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
        gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)]))
        # Need to return one value per input to the IdentityN, so pad the
        # gradients of the inputs of the custom_gradient function with the
        # gradients of the outputs as well.
        return ([None] * len(flat_result)) + gradients

      with tf_ops.get_default_graph().gradient_override_map(
          {"IdentityN": name}):
        all_tensors = array_ops.identity_n(all_tensors)
      return nest.pack_sequence_as(
          structure=result, flat_sequence=all_tensors[:len(flat_result)])

    input_tensors = [x for x in args
                     if isinstance(x, tf_ops.Tensor)]

    with tape.stop_recording():
      result, grad_fn = f(*args, **kwargs)

    # TODO(apassos): naive uses of custom_gradient will not get the correct
    # second derivative this way if they capture any output tensors. Change the
    # signature of custom_gradient.
    def actual_grad_fn(*outputs):
      return grad_fn(*outputs)

    flat_result = nest.flatten(result)
    tape.record_operation(
        f.__name__,
        flat_result,
        input_tensors,
        [],
        actual_grad_fn)
    flat_result = list(flat_result)
    return result
Example #6
  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    input_tensors = [x for x in args
                     if isinstance(x, tf_ops.Tensor)]

    with tape.stop_recording():
      result, grad_fn = f(*args, **kwargs)

    # TODO(apassos): naive uses of custom_gradient will not get the correct
    # second derivative this way if they capture any output tensors. Change the
    # signature of custom_gradient.
    def actual_grad_fn(*outputs):
      return grad_fn(*outputs)

    flat_result = nest.flatten(result)
    tape.record_operation(
        flat_result,
        input_tensors,
        [],
        actual_grad_fn)
    flat_result = list(flat_result)
    return result
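Examples 5 and 6 are two revisions of the machinery behind `tf.custom_gradient`, which records the hand-written `grad_fn` on the tape in place of the traced operations. From user code the public decorator looks like the sketch below (the `log1pexp` function is the standard illustration from the TensorFlow documentation, not taken from the code above):

import tensorflow as tf

@tf.custom_gradient
def log1pexp(x):
  """Numerically stable log(1 + exp(x)) with a hand-written gradient."""
  e = tf.exp(x)

  def grad(upstream):
    # d/dx log(1 + exp(x)) = 1 - 1 / (1 + exp(x)), scaled by the upstream grad.
    return upstream * (1 - 1 / (1 + e))

  return tf.math.log(1 + e), grad

x = tf.constant(100.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = log1pexp(x)
print(tape.gradient(y, x))  # 1.0, where the naive formula would give NaN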
Example #7
  def __call__(self, *args, **kwds):
    """Calls the graph function."""
    if self._created_variables:
      # In this case we have created variables on the first call, so we run the
      # defunned version which is guaranteed to never create variables.
      return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    elif self._stateful_fn is not None:
      # In this case we have not created variables on the first call. So we can
      # run the first trace but we should fail if variables are created.
      results = self._stateful_fn(*args, **kwds)
      if self._created_variables:
        raise ValueError("Creating variables on a non-first call to a function"
                         " decorated with tf.function.")
      return results

    # This is the first call of __call__, so we have to initialize.
    self._initialize(args, kwds)
    if self._lifted_all_initializers and self._lifted_placeholders:
      with ops.init_scope():
        handles, placeholders = zip(*self._lifted_placeholders)
        if context.executing_eagerly():
          lifted_fn = function_lib._EagerDefinedFunction(  # pylint: disable=protected-access
              "initializer" + str(ops.uid()),
              self._lifted_initializer_graph,
              placeholders, [], {})
          with tape.stop_recording():
            lifted_fn.call(context.context(), list(handles))
      return self._stateless_fn(*args, **kwds)
    canon_args, canon_kwds = self._canonicalize_function_inputs(args, kwds)

    if not self._created_variables:
      # If we did not create any variables the trace we have is good enough.
      return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access

    def fn_with_cond(*inner_args, **inner_kwds):
      """Conditionally runs initialization if it's needed."""
      condition = True
      for wr in self._created_variables:
        variable = wr()
        if variable is None:
          raise ValueError(
              "A tf.Variable created inside your tf.function has been"
              " garbage-collected. Your code needs to keep Python references"
              " to variables created inside `tf.function`s.\n"
              "\n"
              "A common way to raise this error is to create and return a"
              " variable only referenced inside your function:\n"
              "\n"
              "@tf.function\n"
              "def f():\n"
              "  v = tf.Variable(1.0)\n"
              "  return v\n"
              "\n"
              "v = f()  # Crashes with this error message!\n"
              "\n"
              "The reason this crashes is that @tf.function annotated"
              " function returns a **`tf.Tensor`** with the **value** of the"
              " variable when the function is called rather than the"
              " variable instance itself. As such there is no code holding a"
              " reference to the `v` created inside the function and Python"
              " garbage collects it.\n"
              "\n"
              "The simplest way to fix this issue is to create variables"
              " outside the function and capture them:\n"
              "\n"
              "v = tf.Variable(1.0)\n"
              "\n"
              "@tf.function\n"
              "def f():\n"
              "  return v\n"
              "\n"
              "f()  # <tf.Tensor: ... numpy=1.>\n"
              "v.assign_add(1.)\n"
              "f()  # <tf.Tensor: ... numpy=2.>")
        condition = math_ops.logical_and(
            condition, resource_variable_ops.var_is_initialized_op(
                variable.handle))
      # We want to call stateless_fn if possible because it avoids recomputing
      # potentially expensive initializers.
      return control_flow_ops.cond(
          condition,
          lambda: self._stateless_fn(*inner_args, **inner_kwds),
          functools.partial(self._concrete_stateful_fn._filtered_call,  # pylint: disable=protected-access
                            inner_args, inner_kwds))

    return function_lib.defun(fn_with_cond)(*canon_args, **canon_kwds)
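The long error message embedded above describes a concrete failure mode of `tf.function`. A minimal sketch of the broken pattern and the fix, following the snippets quoted in that message (assumes TensorFlow 2.x):

import tensorflow as tf

@tf.function
def broken():
  # The variable is only referenced inside the function; once the call
  # returns a Tensor value, Python garbage-collects the variable itself.
  v = tf.Variable(1.0)
  return v

# broken()  # Raises the ValueError quoted in the message above.

# Fix: create the variable once, outside the function, and capture it.
v = tf.Variable(1.0)

@tf.function
def works():
  return v

print(works())   # tf.Tensor(1.0, ...)
v.assign_add(1.0)
print(works())   # tf.Tensor(2.0, ...)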
Example #8
    def _real_mirrored_creator(devices, *args, **kwargs):
      """Creates one MirroredVariable on the current worker."""
      unique_var_name = ops.get_default_graph().unique_name(
          kwargs["name"], mark_as_used=False).rstrip("/")
      # pylint: disable=protected-access
      collective_instance_key = self._collective_keys.get_instance_key(
          key_id=unique_var_name)
      # Only the first device participates in the broadcast of initial values.
      group_key = self._collective_keys.get_group_key([devices[0]])
      group_size = self._num_workers
      if "initial_value" not in kwargs:
        raise ValueError("Initial value must be specified.")
      initial_value = kwargs["initial_value"]
      if callable(initial_value):
        initial_value_fn = initial_value
      else:
        initial_value_fn = lambda: initial_value

      value_list = []
      for i, d in enumerate(devices):
        with ops.init_scope(), ops.device(d):
          if i == 0:
            # The initial value fn makes sure all variables are initialized to
            # the same values. The first device of the chief worker will send
            # its variable values to the other workers.
            def _overridden_initial_value_fn(device=d, index=i):  # pylint: disable=g-missing-docstring
              with ops.device(device):
                initial_value = initial_value_fn()
                assert not callable(initial_value)
                initial_value = ops.convert_to_tensor(initial_value)

                assert index == 0, index
                if self._num_workers > 1:
                  if self._is_chief:
                    bcast_send = collective_ops.broadcast_send(
                        initial_value, initial_value.shape, initial_value.dtype,
                        group_size, group_key, collective_instance_key)
                    with ops.control_dependencies([bcast_send]):
                      return array_ops.identity(initial_value)
                  else:
                    return collective_ops.broadcast_recv(
                        initial_value.shape, initial_value.dtype, group_size,
                        group_key, collective_instance_key)
                return initial_value
          else:
            # Give replicas meaningful distinct names:
            var0name = value_list[0].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0 to
            # ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)

            # Variables on non-first replica get initial values from the
            # variables created on the first device of each worker.
            def _overridden_initial_value_fn(device=d, index=i):
              assert index > 0
              with ops.device(device):
                if context.executing_eagerly():
                  return array_ops.identity(value_list[0].value())
                else:
                  return array_ops.identity(value_list[0].initial_value)

          kwargs["initial_value"] = _overridden_initial_value_fn
          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
            # Don't record operations (e.g. other variable reads) during
            # variable creation.
            with tape.stop_recording():
              v = next_creator(*args, **kwargs)

          if i == 0:
            actual_var_name = v.name.split(":")[0]
            assert unique_var_name == actual_var_name, "%r vs %r" % (
                unique_var_name, actual_var_name)
          assert not isinstance(v, values.DistributedVariable)
          value_list.append(v)
      return value_list
Example #9
    def inner(*args, **kwargs):
        """Inner function closure for calculating gradients."""
        current_var_scope = variable_scope.get_variable_scope()
        with tape_lib.stop_recording():
            result = f(*args, **kwargs)

        def grad_wrapper(*wrapper_args, variables=None):
            """Wrapper function to accomodate lack of kwargs in graph mode custom_gradient."""
            @custom_gradient
            def inner_recompute_grad(*dresult):
                """Nested custom gradient function for computing grads in reverse and forward mode autodiff."""
                # Gradient calculation for reverse mode autodiff.
                with backprop.GradientTape() as t:
                    id_args = nest.map_structure(gen_array_ops.identity, args)
                    # Tuple `dresult` should contain at least one tensor.
                    assert len(dresult) >= 1

                    if not context.executing_eagerly():
                        # XLA doesn't respect `tf.control_dependencies`. The code block
                        # below manually adds a data dependency to `dresult` to ensure
                        # recomputation of `f(*args, **kwargs)` happens after `dresult`.

                        # This works even if `dresult[0]` is a size 0 tensor as reduce_max
                        # of a size 0 tensor returns -inf. Use reshape here to avoid reading
                        # the entire `dresult[0]`.
                        elem = math_ops.reduce_max(
                            array_ops.reshape(dresult[0], [-1])[:1])
                        # Cast elem to bool in case elem is NaN.
                        elem_bool = math_ops.cast(elem, dtypes.bool)
                        dresult_dep = array_ops.where_v2(
                            elem_bool == elem_bool, 0., float("nan"))  # pylint: disable=comparison-with-itself
                        id_args = nest.map_structure(
                            lambda x: x + math_ops.cast(dresult_dep, x.dtype),
                            id_args)

                    t.watch(id_args)
                    if variables is not None:
                        t.watch(variables)
                    with variable_scope.variable_scope(current_var_scope):
                        recomputed_result = f(*id_args, **kwargs)
                kw_vars = []
                if variables is not None:
                    kw_vars = list(variables)
                grads = t.gradient(
                    recomputed_result,
                    list(id_args) + kw_vars,
                    output_gradients=dresult,
                    unconnected_gradients=UnconnectedGradients.ZERO)

                def transpose(*t_args, **t_kwargs):
                    """Gradient function calculation for forward mode autodiff."""
                    # Just throw an error since gradients / activations are not stored on
                    # tape for recompute.
                    raise NotImplementedError(
                        "recompute_grad tried to transpose grad of {}. "
                        "Consider not using recompute_grad in forward mode"
                        "autodiff".format(f.__name__))

                return (grads[:len(id_args)], grads[len(id_args):]), transpose

            return inner_recompute_grad(*wrapper_args)

        return result, grad_wrapper
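Example 9 is the inner closure of `tf.recompute_grad`: the forward pass runs under `stop_recording()` so nothing is taped, and the nested `custom_gradient` re-runs `f` inside a fresh `GradientTape` when gradients are requested. A minimal sketch of the public wrapper, with a placeholder dense block and assuming TensorFlow 2.x:

import tensorflow as tf

dense = tf.keras.layers.Dense(256, activation="relu")
dense.build((None, 256))  # Build up front so the wrapped call creates no variables.

# Activations inside the block are not kept for backprop; they are
# recomputed from the inputs when gradients are needed.
block = tf.recompute_grad(lambda x: dense(dense(x)))

x = tf.random.normal([32, 256])
with tf.GradientTape() as tape:
  y = tf.reduce_sum(block(x))
grads = tape.gradient(y, dense.trainable_variables)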
Example #10
def _create_mirrored_variable(
        strategy,
        device_map,
        logical_device,  # pylint: disable=missing-docstring
        real_mirrored_creator,
        *args,
        **kwargs):
    # Figure out what collections this variable should be added to.
    # We'll add the MirroredVariable to those collections instead.
    collections = kwargs.pop("collections", None)
    if collections is None:
        collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    kwargs["collections"] = []

    # Get synchronization value
    synchronization = kwargs.get(
        "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
    if synchronization == variable_scope.VariableSynchronization.NONE:
        raise ValueError(
            "`NONE` variable synchronization mode is not "
            "supported with `Mirrored` distribution strategy. Please"
            " change the `synchronization` for variable: " + kwargs["name"])
    elif synchronization == variable_scope.VariableSynchronization.ON_READ:
        # Variables that are to be synced on read are replica local.
        is_replica_local = True
        kwargs["trainable"] = False
    elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE
          or synchronization == variable_scope.VariableSynchronization.AUTO):
        # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
        is_replica_local = False
    else:
        raise ValueError("Invalid variable synchronization mode: " +
                         synchronization + " for variable: " + kwargs["name"])

    # Get aggregation value
    aggregation = kwargs.pop("aggregation",
                             variable_scope.VariableAggregation.NONE)
    if aggregation not in (
            variable_scope.VariableAggregation.NONE,
            variable_scope.VariableAggregation.SUM,
            variable_scope.VariableAggregation.MEAN,
            variable_scope.VariableAggregation.ONLY_FIRST_REPLICA):
        raise ValueError("Invalid variable aggregation mode: " + aggregation +
                         " for variable: " + kwargs["name"])

    # Ignore user-specified caching device, not needed for mirrored variables.
    kwargs.pop("caching_device", None)

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
        devices = device_map.logical_to_actual_devices(logical_device)
        value_list = real_mirrored_creator(devices, *args, **kwargs)

        if is_replica_local:
            result = values.ReplicaLocalVariable(strategy,
                                                 device_map,
                                                 value_list,
                                                 aggregation,
                                                 logical_device=logical_device)
        else:
            result = values.MirroredVariable(strategy,
                                             device_map,
                                             value_list,
                                             aggregation,
                                             logical_device=logical_device)

    # Add the wrapped variable to the requested collections.
    # The handling of eager mode and the global step matches
    # ResourceVariable._init_from_args().
    if not context.executing_eagerly():
        g = ops.get_default_graph()
        # If "trainable" is True, next_creator() will add the member variables
        # to the TRAINABLE_VARIABLES collection, so we manually remove
        # them and replace with the MirroredVariable. We can't set
        # "trainable" to False for next_creator() since that causes functions
        # like implicit_gradients to skip those variables.
        if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            for v in value_list:
                if v in l:
                    l.remove(v)
        g.add_to_collections(collections, result)
    elif ops.GraphKeys.GLOBAL_STEP in collections:
        ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

    return result
Example #11
  def _create_variable(self, next_creator, *args, **kwargs):
    """Create a mirrored variable. See `DistributionStrategy.scope`."""
    # Figure out what collections this variable should be added to.
    # We'll add the MirroredVariable to those collections instead.
    collections = kwargs.pop("collections", None)
    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    kwargs["collections"] = []

    colocate_with = kwargs.pop("colocate_with", None)
    devices = self._get_devices_from(colocate_with)

    tower_local = kwargs.pop("tower_local_reduce_method", None)
    if tower_local is not None:
      kwargs["trainable"] = False

    # Ignore user-specified caching device, not needed for mirrored variables.
    kwargs.pop("caching_device", None)

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
      index = {}
      for i, d in enumerate(devices):
        with ops.device(d):
          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = index[devices[0]].name.split(":")[0]
            # We append a / to variable names created on towers with id > 0 to
            # ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
            # Initialize replicas with the same value:
            if context.executing_eagerly():
              kwargs["initial_value"] = array_ops.identity(
                  index[devices[0]].value())
            else:
              def initial_value_fn(device=d):
                with ops.device(device):
                  return array_ops.identity(index[devices[0]].initial_value)
              kwargs["initial_value"] = initial_value_fn
          with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
            v = next_creator(*args, **kwargs)
          assert not isinstance(v, values.DistributedVariable)
          index[d] = v

      if tower_local is None:
        result = values.MirroredVariable(index, index[devices[0]])
      else:
        result = values.TowerLocalVariable(
            index, index[devices[0]], tower_local)

    if not context.executing_eagerly():
      g = ops.get_default_graph()
      # If "trainable" is True, next_creator() will add the member variables
      # to the TRAINABLE_VARIABLES collection, so we manually remove
      # them and replace with the MirroredVariable. We can't set
      # "trainable" to False for next_creator() since that causes functions
      # like implicit_gradients to skip those variables.
      if kwargs.get("trainable", True):
        collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
        l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
        for v in index.values():
          l.remove(v)
      g.add_to_collections(collections, result)
    return result
Example #12
def create_mirrored_variable(  # pylint: disable=missing-docstring
        strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls,
        **kwargs):
    # Figure out what collections this variable should be added to.
    # We'll add the MirroredVariable to those collections instead.
    var_collections = kwargs.pop("collections", None)
    if var_collections is None:
        var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    kwargs["collections"] = []

    synchronization = kwargs.get("synchronization",
                                 vs.VariableSynchronization.ON_WRITE)

    if synchronization == vs.VariableSynchronization.NONE:
        raise ValueError(
            "`NONE` variable synchronization mode is not supported with `Mirrored` "
            "distribution strategy. Please change the `synchronization` for "
            "variable: " + str(kwargs["name"]))
    elif synchronization == vs.VariableSynchronization.ON_READ:
        is_sync_on_read = True
    elif synchronization in (vs.VariableSynchronization.ON_WRITE,
                             vs.VariableSynchronization.AUTO):
        # `AUTO` synchronization defaults to `ON_WRITE`.
        is_sync_on_read = False
    else:
        raise ValueError(
            "Invalid variable synchronization mode: %s for variable: %s" %
            (synchronization, kwargs["name"]))

    aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)

    if aggregation not in (vs.VariableAggregation.NONE,
                           vs.VariableAggregation.SUM,
                           vs.VariableAggregation.MEAN,
                           vs.VariableAggregation.ONLY_FIRST_REPLICA):
        raise ValueError(
            "Invalid variable aggregation mode: %s for variable: %s" %
            (aggregation, kwargs["name"]))

    # Ignore user-specified caching device, not needed for mirrored variables.
    kwargs.pop("caching_device", None)

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
        value_list = real_mirrored_creator(**kwargs)
        var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
        result = var_cls(strategy, value_list, aggregation)
        # Install the created DistributedVariable as _distributed_container property
        # of the underlying variables, to make it easy to map back to the container.
        for v in result.values:
            # Hold a strong reference to keep the container from being GC-ed. After
            # v = v.assign(), the user code may no longer hold references to the
            # original container, since v.assign() returns a new DistributedVariable.
            v._distributed_container = result  # pylint: disable=protected-access

    # Add the wrapped variable to the requested collections.
    # The handling of eager mode and the global step matches
    # ResourceVariable._init_from_args().
    if not context.executing_eagerly():
        g = ops.get_default_graph()
        # If "trainable" is True, next_creator() will add the member variables
        # to the TRAINABLE_VARIABLES collection, so we manually remove
        # them and replace with the MirroredVariable. We can't set
        # "trainable" to False for next_creator() since that causes functions
        # like implicit_gradients to skip those variables.
        if kwargs.get("trainable", True):
            var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            for value in value_list:
                for i, trainable_variable in enumerate(l):
                    if value is trainable_variable:
                        del l[i]
                        break

        g.add_to_collections(var_collections, result)
    elif ops.GraphKeys.GLOBAL_STEP in var_collections:
        ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

    return result
Example #13
        def _real_mirrored_creator(devices, *args, **kwargs):
            """Creates one MirroredVariable on the current worker."""
            unique_var_name = ops.get_default_graph().unique_name(
                kwargs["name"], mark_as_used=False).rstrip("/")
            # pylint: disable=protected-access
            collective_instance_key = self._collective_keys.get_instance_key(
                key_id=unique_var_name)
            # Only the first device participates in the broadcast of initial values.
            group_key = self._collective_keys.get_group_key([devices[0]])
            group_size = self._num_workers
            if "initial_value" not in kwargs:
                raise ValueError("Initial value must be specified.")
            initial_value = kwargs["initial_value"]
            if callable(initial_value):
                initial_value_fn = initial_value
            else:
                initial_value_fn = lambda: initial_value

            value_list = []
            for i, d in enumerate(devices):
                with ops.init_scope(), ops.device(d):
                    if i == 0:
                        # The initial value fn makes sure all variables are initialized to
                        # the same values. The first device of the chief worker will send
                        # its variable values to the other workers.
                        def _overridden_initial_value_fn(device=d, index=i):  # pylint: disable=g-missing-docstring
                            with ops.device(device):
                                initial_value = initial_value_fn()
                                assert not callable(initial_value)
                                initial_value = ops.convert_to_tensor(
                                    initial_value)

                                assert index == 0, index
                                if self._num_workers > 1:
                                    if self._is_chief:
                                        bcast_send = collective_ops.broadcast_send(
                                            initial_value, initial_value.shape,
                                            initial_value.dtype, group_size,
                                            group_key, collective_instance_key)
                                        with ops.control_dependencies(
                                            [bcast_send]):
                                            return array_ops.identity(
                                                initial_value)
                                    else:
                                        return collective_ops.broadcast_recv(
                                            initial_value.shape,
                                            initial_value.dtype, group_size,
                                            group_key, collective_instance_key)
                                return initial_value
                    else:
                        # Give replicas meaningful distinct names:
                        var0name = value_list[0].name.split(":")[0]
                        # We append a / to variable names created on replicas with id > 0 to
                        # ensure that we ignore the name scope and instead use the given
                        # name as the absolute name of the variable.
                        kwargs["name"] = "%s/replica_%d/" % (var0name, i)

                        # Variables on non-first replica get initial values from the
                        # variables created on the first device of each worker.
                        def _overridden_initial_value_fn(device=d, index=i):
                            assert index > 0
                            with ops.device(device):
                                if context.executing_eagerly():
                                    return array_ops.identity(
                                        value_list[0].value())
                                else:
                                    return array_ops.identity(
                                        value_list[0].initial_value)

                    kwargs["initial_value"] = _overridden_initial_value_fn
                    with context.context().device_policy(
                            context.DEVICE_PLACEMENT_SILENT):
                        # Don't record operations (e.g. other variable reads) during
                        # variable creation.
                        with tape.stop_recording():
                            v = next_creator(*args, **kwargs)

                    if i == 0:
                        actual_var_name = v.name.split(":")[0]
                        assert unique_var_name == actual_var_name, "%r vs %r" % (
                            unique_var_name, actual_var_name)
                    assert not isinstance(v, values.DistributedVariable)
                    value_list.append(v)
            return value_list
Example #14
def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
                             policy_mapping, **kwargs):
    """Create distributed variables with given synchronization and aggregation."""
    # Figure out what collections this variable should be added to.
    # We'll add the MirroredVariable to those collections instead.
    var_collections = kwargs.pop("collections", None)
    if var_collections is None:
        var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    kwargs["collections"] = []

    synchronization = _validate_synchronization(kwargs)
    # Update synchronization in kwargs in case it's AUTO, which is converted to
    # ON_WRITE.
    kwargs["synchronization"] = synchronization
    aggregation = _validate_aggregation(kwargs)
    use_var_policy = getattr(strategy.extended, "_use_var_policy", False)

    # Ignore user-specified caching device, not needed for mirrored variables.
    kwargs.pop("caching_device", None)

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
        value_list = real_mirrored_creator(**kwargs)
        if use_var_policy:
            var_policy_cls = policy_mapping.get(synchronization)
            var_policy = var_policy_cls(aggregation=aggregation)
            var_cls = class_mapping.get("VariableClass")
            result = var_cls(strategy,
                             value_list,
                             aggregation,
                             var_policy=var_policy)
        else:
            var_cls = class_mapping.get(synchronization)
            result = var_cls(strategy, value_list, aggregation)

    # Add the wrapped variable to the requested collections.
    # The handling of eager mode and the global step matches
    # ResourceVariable._init_from_args().
    if not context.executing_eagerly():
        g = ops.get_default_graph()
        # If "trainable" is True, next_creator() will add the member variables
        # to the TRAINABLE_VARIABLES collection, so we manually remove
        # them and replace with the MirroredVariable. We can't set
        # "trainable" to False for next_creator() since that causes functions
        # like implicit_gradients to skip those variables.
        if kwargs.get("trainable", True):
            var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            for value in value_list:
                for i, trainable_variable in enumerate(l):
                    if value is trainable_variable:
                        del l[i]
                        break

        g.add_to_collections(var_collections, result)
    elif ops.GraphKeys.GLOBAL_STEP in var_collections:
        ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

    return result
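All of the `create_mirrored_variable` variants in these examples are reached indirectly: the strategy's scope installs a variable creator, so a plain `tf.Variable` call ends up here. A minimal sketch of the user-facing path, assuming TensorFlow 2.x with at least one visible device:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  # tf.Variable is intercepted by the strategy's variable creator and comes
  # back as a distributed (mirrored) variable with one component per replica.
  v = tf.Variable(1.0, name="mirrored_counter")

print(type(v).__name__)              # e.g. MirroredVariable
print(strategy.num_replicas_in_sync)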
Example #15
  def _create_variable(self, next_creator, *args, **kwargs):
    """Create a mirrored variable. See `DistributionStrategy.scope`."""
    # Figure out what collections this variable should be added to.
    # We'll add the MirroredVariable to those collections instead.
    collections = kwargs.pop("collections", None)
    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    kwargs["collections"] = []

    colocate_with = kwargs.pop("colocate_with", None)
    devices = self._get_devices_from(colocate_with)

    # Get synchronization value
    synchronization = kwargs.get(
        "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
    if synchronization == variable_scope.VariableSynchronization.NONE:
      raise ValueError("`NONE` variable synchronization mode is not "
                       "supported with `Mirrored` distribution strategy. Please"
                       " change the `synchronization` for variable: " +
                       kwargs["name"])
    elif synchronization == variable_scope.VariableSynchronization.ON_READ:
      # Variables that are to be synced on read are tower local.
      is_tower_local = True
      kwargs["trainable"] = False
    elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
          synchronization == variable_scope.VariableSynchronization.AUTO):
      # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
      is_tower_local = False
    else:
      raise ValueError("Invalid variable synchronization mode: " +
                       synchronization + " for variable: " + kwargs["name"])

    # Get aggregation value
    aggregation = kwargs.pop("aggregation",
                             variable_scope.VariableAggregation.NONE)
    if aggregation not in [
        variable_scope.VariableAggregation.NONE,
        variable_scope.VariableAggregation.SUM,
        variable_scope.VariableAggregation.MEAN
    ]:
      raise ValueError("Invalid variable aggregation mode: " + aggregation +
                       " for variable: " + kwargs["name"])

    # Ignore user-specified caching device, not needed for mirrored variables.
    kwargs.pop("caching_device", None)

    # TODO(josh11b,apassos): It would be better if variable initialization
    # was never recorded on the tape instead of having to do this manually
    # here.
    with tape.stop_recording():
      index = {}
      for i, d in enumerate(devices):
        with ops.device(d):
          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = index[devices[0]].name.split(":")[0]
            # We append a / to variable names created on towers with id > 0 to
            # ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
            # Initialize replicas with the same value:
            if context.executing_eagerly():
              kwargs["initial_value"] = array_ops.identity(
                  index[devices[0]].value())
            else:
              def initial_value_fn(device=d):
                with ops.device(device):
                  return array_ops.identity(index[devices[0]].initial_value)
              kwargs["initial_value"] = initial_value_fn
          with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
            v = next_creator(*args, **kwargs)
          assert not isinstance(v, values.DistributedVariable)
          index[d] = v

      if is_tower_local:
        result = values.TowerLocalVariable(index, index[devices[0]],
                                           aggregation)
      else:
        result = values.MirroredVariable(index, index[devices[0]], aggregation)

    if not context.executing_eagerly():
      g = ops.get_default_graph()
      # If "trainable" is True, next_creator() will add the member variables
      # to the TRAINABLE_VARIABLES collection, so we manually remove
      # them and replace with the MirroredVariable. We can't set
      # "trainable" to False for next_creator() since that causes functions
      # like implicit_gradients to skip those variables.
      if kwargs.get("trainable", True):
        collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
        l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
        for v in index.values():
          l.remove(v)
      g.add_to_collections(collections, result)
    return result
Example #16
    def __call__(self, *args, **kwds):
        """Calls the graph function."""
        if self._created_variables:
            # In this case we have created variables on the first call, so we run the
            # defunned version which is guaranteed to never create variables.
            return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
        elif self._stateful_fn is not None:
            # In this case we have not created variables on the first call. So we can
            # run the first trace but we should fail if variables are created.
            results = self._stateful_fn(*args, **kwds)
            if self._created_variables:
                raise ValueError(
                    "Creating variables on a non-first call to a function"
                    " decorated with tf.function.")
            return results

        # This is the first call of __call__, so we have to initialize.
        self._initialize(args, kwds)
        if self._lifted_all_initializers and self._lifted_placeholders:
            with ops.init_scope():
                handles, placeholders = zip(*self._lifted_placeholders)
                if context.executing_eagerly():
                    lifted_fn = function_lib._EagerDefinedFunction(  # pylint: disable=protected-access
                        "initializer" + str(ops.uid()),
                        self._lifted_initializer_graph, placeholders, [], {})
                    with tape.stop_recording():
                        lifted_fn.call(context.context(), list(handles))
            return self._stateless_fn(*args, **kwds)
        canon_args, canon_kwds = self._canonicalize_function_inputs(args, kwds)

        if not self._created_variables:
            # If we did not create any variables the trace we have is good enough.
            return self._concrete_stateful_fn._filtered_call(
                canon_args, canon_kwds)  # pylint: disable=protected-access

        def fn_with_cond(*inner_args, **inner_kwds):
            """Conditionally runs initialization if it's needed."""
            condition = True
            for wr in self._created_variables:
                variable = wr()
                if variable is None:
                    raise ValueError(
                        "A tf.Variable created inside your tf.function has been"
                        " garbage-collected. Your code needs to keep Python references"
                        " to variables created inside `tf.function`s.\n"
                        "\n"
                        "A common way to raise this error is to create and return a"
                        " variable only referenced inside your function:\n"
                        "\n"
                        "@tf.function\n"
                        "def f():\n"
                        "  v = tf.Variable(1.0)\n"
                        "  return v\n"
                        "\n"
                        "v = f()  # Crashes with this error message!\n"
                        "\n"
                        "The reason this crashes is that @tf.function annotated"
                        " function returns a **`tf.Tensor`** with the **value** of the"
                        " variable when the function is called rather than the"
                        " variable instance itself. As such there is no code holding a"
                        " reference to the `v` created inside the function and Python"
                        " garbage collects it.\n"
                        "\n"
                        "The simplest way to fix this issue is to create variables"
                        " outside the function and capture them:\n"
                        "\n"
                        "v = tf.Variable(1.0)\n"
                        "\n"
                        "@tf.function\n"
                        "def f():\n"
                        "  return v\n"
                        "\n"
                        "f()  # <tf.Tensor: ... numpy=1.>\n"
                        "v.assign_add(1.)\n"
                        "f()  # <tf.Tensor: ... numpy=2.>")
                condition = math_ops.logical_and(
                    condition,
                    resource_variable_ops.var_is_initialized_op(
                        variable.handle))
            # We want to call stateless_fn if possible because it avoids recomputing
            # potentially expensive initializers.
            return control_flow_ops.cond(
                condition,
                lambda: self._stateless_fn(*inner_args, **inner_kwds),
                functools.partial(
                    self._concrete_stateful_fn._filtered_call,  # pylint: disable=protected-access
                    inner_args,
                    inner_kwds))

        return function_lib.defun(fn_with_cond)(*canon_args, **canon_kwds)
Example #17
def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):  # pylint: disable=g-missing-docstring
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  collections = kwargs.pop("collections", None)
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  # Get synchronization value
  synchronization = kwargs.get("synchronization",
                               variable_scope.VariableSynchronization.ON_WRITE)
  if synchronization == variable_scope.VariableSynchronization.NONE:
    raise ValueError("`NONE` variable synchronization mode is not "
                     "supported with `Mirrored` distribution strategy. Please"
                     " change the `synchronization` for variable: " +
                     kwargs["name"])
  elif synchronization == variable_scope.VariableSynchronization.ON_READ:
    # Variables that are to be synced on read are tower local.
    is_tower_local = True
    kwargs["trainable"] = False
  elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
        synchronization == variable_scope.VariableSynchronization.AUTO):
    # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
    is_tower_local = False
  else:
    raise ValueError("Invalid variable synchronization mode: " +
                     synchronization + " for variable: " + kwargs["name"])

  # Get aggregation value
  aggregation = kwargs.pop("aggregation",
                           variable_scope.VariableAggregation.NONE)
  if aggregation not in (
      variable_scope.VariableAggregation.NONE,
      variable_scope.VariableAggregation.SUM,
      variable_scope.VariableAggregation.MEAN,
      variable_scope.VariableAggregation.ONLY_FIRST_TOWER
  ):
    raise ValueError("Invalid variable aggregation mode: " + aggregation +
                     " for variable: " + kwargs["name"])

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    index = real_mirrored_creator(devices, *args, **kwargs)

    if is_tower_local:
      result = values.TowerLocalVariable(index, index[devices[0]], aggregation)
    else:
      result = values.MirroredVariable(index, index[devices[0]], aggregation)

  # Add the wrapped variable to the requested collections.
  # The handling of eager mode and the global step matches
  # ResourceVariable._init_from_args().
  if not context.executing_eagerly():
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for v in index.values():
        l.remove(v)
    g.add_to_collections(collections, result)
  elif ops.GraphKeys.GLOBAL_STEP in collections:
    ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)

  return result
Example #18
    def _create_variable(self, next_creator, *args, **kwargs):
        """Create a mirrored variable. See `DistributionStrategy.scope`."""
        # Figure out what collections this variable should be added to.
        # We'll add the MirroredVariable to those collections instead.
        collections = kwargs.pop("collections", None)
        if collections is None:
            collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        kwargs["collections"] = []

        colocate_with = kwargs.pop("colocate_with", None)
        devices = self._get_devices_from(colocate_with)

        # Get synchronization value
        synchronization = kwargs.get(
            "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
        if synchronization == variable_scope.VariableSynchronization.NONE:
            raise ValueError(
                "`NONE` variable synchronization mode is not "
                "supported with `Mirrored` distribution strategy. Please"
                " change the `synchronization` for variable: " +
                kwargs["name"])
        elif synchronization == variable_scope.VariableSynchronization.ON_READ:
            # Variables that are to be synced on read are tower local.
            is_tower_local = True
            kwargs["trainable"] = False
        elif (synchronization
              == variable_scope.VariableSynchronization.ON_WRITE or
              synchronization == variable_scope.VariableSynchronization.AUTO):
            # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
            is_tower_local = False
        else:
            raise ValueError("Invalid variable synchronization mode: " +
                             synchronization + " for variable: " +
                             kwargs["name"])

        # Get aggregation value
        aggregation = kwargs.pop("aggregation",
                                 variable_scope.VariableAggregation.NONE)
        if aggregation not in [
                variable_scope.VariableAggregation.NONE,
                variable_scope.VariableAggregation.SUM,
                variable_scope.VariableAggregation.MEAN
        ]:
            raise ValueError("Invalid variable aggregation mode: " +
                             aggregation + " for variable: " + kwargs["name"])

        # Ignore user-specified caching device, not needed for mirrored variables.
        kwargs.pop("caching_device", None)

        # TODO(josh11b,apassos): It would be better if variable initialization
        # was never recorded on the tape instead of having to do this manually
        # here.
        with tape.stop_recording():
            index = {}
            for i, d in enumerate(devices):
                with ops.device(d):
                    if i > 0:
                        # Give replicas meaningful distinct names:
                        var0name = index[devices[0]].name.split(":")[0]
                        # We append a / to variable names created on towers with id > 0 to
                        # ensure that we ignore the name scope and instead use the given
                        # name as the absolute name of the variable.
                        kwargs["name"] = "%s/replica_%d/" % (var0name, i)
                        # Initialize replicas with the same value:
                        if context.executing_eagerly():
                            kwargs["initial_value"] = array_ops.identity(
                                index[devices[0]].value())
                        else:

                            def initial_value_fn(device=d):
                                with ops.device(device):
                                    return array_ops.identity(
                                        index[devices[0]].initial_value)

                            kwargs["initial_value"] = initial_value_fn
                    with context.context().device_policy(
                            context.DEVICE_PLACEMENT_SILENT):
                        v = next_creator(*args, **kwargs)
                    assert not isinstance(v, values.DistributedVariable)
                    index[d] = v

            if is_tower_local:
                result = values.TowerLocalVariable(index, index[devices[0]],
                                                   aggregation)
            else:
                result = values.MirroredVariable(index, index[devices[0]],
                                                 aggregation)

        if not context.executing_eagerly():
            g = ops.get_default_graph()
            # If "trainable" is True, next_creator() will add the member variables
            # to the TRAINABLE_VARIABLES collection, so we manually remove
            # them and replace with the MirroredVariable. We can't set
            # "trainable" to False for next_creator() since that causes functions
            # like implicit_gradients to skip those variables.
            if kwargs.get("trainable", True):
                collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
                l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
                for v in index.values():
                    l.remove(v)
            g.add_to_collections(collections, result)
        return result