Example #1
 def run(self):
   self.should_run.wait()
   self.should_run.clear()
   try:
     if self.coord.should_stop():
       return
     self.restore_thread_local_context_fields()
     # TODO(josh11b): Use current logical device instead of 0 here.
     with self.coord.stop_on_exception(), \
         _enter_graph(self._init_graph, self._init_in_eager), \
         _enter_graph(self.graph, self.in_eager,
                      self._variable_creator_stack), \
         context.device_policy(self.context_device_policy), \
         MirroredReplicaContext(self.distribution, constant_op.constant(
             self.replica_id, dtypes.int32)), \
         ops.device(self.device_map.logical_to_actual_devices(0)[
             self.replica_id]), \
         ops.name_scope(self._name_scope), \
         variable_scope.variable_scope(
             self._var_scope, reuse=self.replica_id > 0), \
         variable_scope.variable_creator_scope(self.variable_creator_fn):
       self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
       self.done = True
   finally:
     self.has_paused.set()
Example #2
def notify_about_variables(callback):
  """Calls `callback(var)` for all `tf.{Variable,get_variable}` results.

  Callback should not modify the variable passed in. Use cases that require
  variables to be modified should use `variable_creator_scope` directly and sit
  within the variable creator stack.

  >>> variables = []
  >>> with notify_about_variables(variables.append):
  ...   v = tf.Variable(1.0, name='v')
  ...   w = tf.get_variable('w', [])
  >>> assert variables == [v, w]

  Args:
    callback: a callable taking a single argument which is a tf.Variable.

  Yields:
    `None` - used for contextmanager API.
  """
  def _tracking_creator(getter, **kwargs):
    v = getter(**kwargs)
    callback(v)
    return v

  with variable_scope_ops.variable_creator_scope(_tracking_creator):
    yield
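A minimal standalone sketch of the same tracking-creator pattern, written against the public tf.variable_creator_scope API (assumes TensorFlow 2.x with eager execution; the names used here are illustrative, not part of the source above):

import tensorflow as tf

created = []

def tracking_creator(next_creator, **kwargs):
  v = next_creator(**kwargs)  # delegate the actual creation
  created.append(v)           # observe only; do not modify the variable
  return v

with tf.variable_creator_scope(tracking_creator):
  a = tf.Variable(1.0, name="a")
  b = tf.Variable([2.0, 3.0], name="b")

assert created[0] is a and created[1] is b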
Example #3
  def testOptimizerInsideModelFn(self, distribution, optimizer_fn):
    created_variables = []
    trainable_variables = []

    def appending_creator(next_creator, *args, **kwargs):
      v = next_creator(*args, **kwargs)
      created_variables.append(v.name)
      if "trainable" in kwargs and kwargs["trainable"]:
        trainable_variables.append(v.name)
      return v

    # Creator scope needs to be set before it's used inside
    # `distribution.scope`.
    with variable_scope.variable_creator_scope(
        appending_creator), distribution.scope():
      model_fn, dataset_fn, _ = minimize_loss_example(
          optimizer_fn,
          use_bias=True,
          use_callable_loss=True,
          create_optimizer_inside_model_fn=True)

      def step_fn(ctx, inputs):
        del ctx  # Unused
        return distribution.group(
            distribution.extended.call_for_each_replica(
                model_fn, args=(inputs,)))

      iterator = self._get_iterator(distribution, dataset_fn)

      def run_step():
        return distribution.extended.experimental_run_steps_on_iterator(
            step_fn, iterator, iterations=1).run_op

      if not context.executing_eagerly():
        with self.cached_session() as sess:
          run_step = sess.make_callable(run_step())
      self.evaluate(variables_lib.global_variables_initializer())
      run_step()

      def get_expected_variables(optimizer_fn, num_parameter_devices):
        optimizer = optimizer_fn()
        name = optimizer._name

        if isinstance(optimizer, optimizer_v2.OptimizerV2):
          variables = VAR_MAP_V2[name]
        else:
          variables = VAR_MAP_V1[name]

        extended_variables = [
            v + "/replica_{}".format(replica)
            for v in variables
            for replica in range(1, num_parameter_devices)
        ]
        variables = list(variables) + extended_variables
        return set([v + ":0" for v in variables])

      self.assertEqual(
          get_expected_variables(optimizer_fn,
                                 len(distribution.extended.parameter_devices)),
          set(created_variables))
Example #4
  def scope(self):
    """Returns a context manager selecting this DistributionStrategy as current.

    Inside a `with distribution_strategy.scope():` code block, this thread
    will use a variable creator set by `distribution_strategy`, and will
    enter its "cross-tower context".

    Returns:
      A context manager.
    """
    if has_distribution_strategy():
      _require_cross_tower_context(self)
      return _SameScopeAgainContext(self)

    def creator_with_resource_vars(*args, **kwargs):
      _require_distribution_strategy_scope(self)
      kwargs["use_resource"] = True
      return self._create_variable(*args, **kwargs)

    def disable_partitioned_variables(getter, *args, **kwargs):
      if kwargs.pop("partitioner", None) is not None:
        tf_logging.log_first_n(
            tf_logging.WARN, "Partitioned variables are disabled when using "
            "DistributionStrategy.", 1)
      return getter(*args, **kwargs)

    return _CurrentDistributionContext(
        self, variable_scope.variable_creator_scope(creator_with_resource_vars),
        variable_scope.variable_scope(
            variable_scope.get_variable_scope(),
            custom_getter=disable_partitioned_variables),
        self._default_device)
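For context, a hedged TensorFlow 2.x sketch of what entering such a scope does in practice: inside strategy.scope() the active variable creator produces distributed variables rather than plain ones. The single-CPU device list is a stand-in for a real multi-device setup:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
with strategy.scope():
  v = tf.Variable(1.0)  # built through the strategy's variable creator
print(type(v).__name__)  # typically a MirroredVariable, not a plain Variable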
Example #5
  def testCreatorStacksAreThreadLocal(self):
    devices = ["/device:CPU:0", "/device:GPU:0"]
    dist = mirrored_strategy.MirroredStrategy(devices)

    def model_fn(device_id):
      assert isinstance(device_id, int)
      def thread_creator_fn(next_creator, *args, **kwargs):
        return next_creator(*args, **kwargs) + ":thread_" + str(device_id)

      with variable_scope.variable_creator_scope(thread_creator_fn):
        # Create a variable in this scope.
        v = variable_scope.variable(1.0)

        # This will pause the current thread, and execute the other thread.
        distribute_lib.get_tower_context().merge_call(lambda _: _)
      return v

    def main_thread_creator(next_creator, *args, **kwargs):
      # We are not using the underlying next_creator for test purposes.
      del next_creator, args, kwargs
      return "main_thread"

    with context.graph_mode(), \
        dist.scope(), \
        variable_scope.variable_creator_scope(main_thread_creator):
      result = dist.call_for_each_tower(model_fn, dist.worker_device_index)
      result = dist.unwrap(result)
      expected = ["main_thread:thread_0", "main_thread:thread_1"]
      self.assertEquals(expected, result)
Example #6
 def call(self, inputs, mask=None, training=None):
   arguments = self.arguments
   if self._fn_expects_mask_arg:
     arguments['mask'] = mask
   if self._fn_expects_training_arg:
     arguments['training'] = training
   with variable_scope.variable_creator_scope(self._variable_creator):
     return self.function(inputs, **arguments)
Example #7
  def tower_local_var_scope(self, reduce_method):
    """Does not set to resource variables."""
    def create_tower_local_variable(next_creator, *args, **kwargs):
      _require_distribution_strategy_scope(self)
      kwargs["trainable"] = False
      return next_creator(*args, **kwargs)

    _require_distribution_strategy_scope(self)
    return variable_scope.variable_creator_scope(create_tower_local_variable)
Example #8
  def testOptimizerInsideModelFn(self, distribution, optimizer_fn):
    created_variables = []
    trainable_variables = []

    def appending_creator(next_creator, *args, **kwargs):
      v = next_creator(*args, **kwargs)
      created_variables.append(v.name)
      if "trainable" in kwargs and kwargs["trainable"]:
        trainable_variables.append(v.name)
      return v

    # Creator scope needs to be set before it's used inside
    # `distribution.scope`.
    with variable_scope.variable_creator_scope(
        appending_creator), distribution.scope():
      model_fn, dataset, layer = minimize_loss_example(
          optimizer_fn,
          use_bias=True,
          use_callable_loss=True,
          create_optimizer_inside_model_fn=True)

      iterator = distribution.distribute_dataset(dataset)

      def run_step():
        return distribution.group(
            distribution.call_for_each_tower(
                model_fn, iterator.get_next(), run_concurrently=layer.built))

      if not context.executing_eagerly():
        with self.test_session() as sess:
          run_step = sess.make_callable(run_step())
        self.evaluate(variables_lib.global_variables_initializer())

      run_step()

      def get_expected_variables(optimizer_fn, num_parameter_devices):
        variables_map = {
            "GradientDescent": ["dense/kernel", "dense/bias"],
            "Adam": [
                "dense/kernel", "dense/bias", "beta1_power", "beta2_power",
                "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam",
                "dense/bias/Adam_1"
            ]
        }
        variables = variables_map[optimizer_fn().get_name()]
        variables.extend([
            v + "/replica_{}".format(replica)
            for v in variables
            for replica in range(1, num_parameter_devices)
        ])
        return set([v + ":0" for v in variables])

      self.assertEqual(
          get_expected_variables(optimizer_fn,
                                 len(distribution.parameter_devices)),
          set(created_variables))
Example #9
  def _call_func(self, args, kwargs):
    try:
      vars_at_start = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
      trainable_at_start = len(
          ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
      if self._variables_created:
        result = self._func(*args, **kwargs)
      else:
        # The first time we run, restore variables if necessary (via
        # Checkpointable).
        with variable_scope.variable_creator_scope(
            self._checkpointable_custom_creator):
          result = self._func(*args, **kwargs)

      if self._variables_created:
        # Variables were previously created, implying this is not the first
        # time the template has been called. Check to make sure that no new
        # trainable variables were created this time around.
        trainable_variables = ops.get_collection(
            ops.GraphKeys.TRAINABLE_VARIABLES)
        # If a variable that we intend to train is created as a side effect
        # of creating a template, then that is almost certainly an error.
        if trainable_at_start != len(trainable_variables):
          raise ValueError("Trainable variable created when calling a template "
                           "after the first time, perhaps you used tf.Variable "
                           "when you meant tf.get_variable: %s" %
                           (trainable_variables[trainable_at_start:],))

        # Non-trainable tracking variables are a legitimate reason why a new
        # variable would be created, but it is a relatively advanced use-case,
        # so log it.
        variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        if vars_at_start != len(variables):
          logging.info("New variables created when calling a template after "
                       "the first time, perhaps you used tf.Variable when you "
                       "meant tf.get_variable: %s",
                       variables[vars_at_start:])
      else:
        self._variables_created = True
      return result
    except Exception as exc:
      # Reraise the exception, but append the original definition to the
      # trace.
      args = exc.args
      if not args:
        arg0 = ""
      else:
        arg0 = args[0]
      trace = "".join(_skip_common_stack_elements(self._stacktrace,
                                                  traceback.format_stack()))
      arg0 = "%s\n\noriginally defined at:\n%s" % (arg0, trace)
      new_args = [arg0]
      new_args.extend(args[1:])
      exc.args = tuple(new_args)
      raise
Example #10
  def save(self, session=None, checkpoint_number=None):
    """Creates a new checkpoint and manages it.

    Args:
      session: The session to evaluate variables in. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.
      checkpoint_number: An optional integer, or an integer-dtype `Variable` or
        `Tensor`, used to number the checkpoint. If `None` (default),
        checkpoints are numbered using `checkpoint.save_counter`. Even if
        `checkpoint_number` is provided, `save_counter` is still incremented. A
        user-provided `checkpoint_number` is not incremented even if it is a
        `Variable`.

    Returns:
      The path to the new checkpoint. It is also recorded in the `checkpoints`
      and `latest_checkpoint` properties.
    """
    # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
    # slightly with a custom numbering option.
    if context.executing_eagerly():
      save_counter = self._checkpoint.save_counter
      save_counter.assign_add(1)
    else:
      if session is None:
        session = ops.get_default_session()

      def _initializing_creator(next_creator, **kwargs):
        """Initialize the save counter if it has been newly created."""
        v = next_creator(**kwargs)
        session.run(v.initializer)
        return v

      with variable_scope.variable_creator_scope(_initializing_creator):
        save_counter = self._checkpoint.save_counter
      if self._save_counter_assign is None:
        self._save_counter_assign = save_counter.assign_add(1, read_value=False)
      session.run(self._save_counter_assign)
    if checkpoint_number is None:
      checkpoint_number = save_counter
    if not isinstance(checkpoint_number, compat.integral_types):
      checkpoint_number = training_util.global_step(
          sess=session, global_step_tensor=checkpoint_number)
    prefix = "%s-%d" % (self._prefix, checkpoint_number)
    save_path = self._checkpoint.write(prefix)
    timestamp = time.time()
    # If this is an overwritten checkpoint we were previously tracking, delete
    # and reinsert it to make sure it goes to the end of the queue.
    if save_path in self._maybe_delete:
      del self._maybe_delete[save_path]
    self._maybe_delete[save_path] = timestamp
    self._latest_checkpoint = save_path
    self._sweep()
    self._record_state()
    return save_path
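A hedged graph-mode sketch of the _initializing_creator idea above, using tf.compat.v1 (assumes TensorFlow 2.x with v1 compatibility; the variable name is illustrative): each variable's initializer is run in the session the moment it is created, so no separate global initializer run is needed for it.

import tensorflow as tf

tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.Session()

def initializing_creator(next_creator, **kwargs):
  """Run the initializer of each newly created variable immediately."""
  v = next_creator(**kwargs)
  sess.run(v.initializer)
  return v

with tf.variable_creator_scope(initializing_creator):
  save_counter = tf.Variable(0, dtype=tf.int64, name="save_counter")

print(sess.run(save_counter.read_value()))  # 0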
Example #11
  def scope(self):
    """Context manager setting a variable creator and `self` as current."""
    if distribution_strategy_context.has_distribution_strategy():
      raise RuntimeError("Must not nest DistributionStrategy scopes.")

    def creator(next_creator, *args, **kwargs):
      _require_distribution_strategy_scope(self)
      return next_creator(*args, **kwargs)

    return _CurrentDistributionContext(
        self, variable_scope.variable_creator_scope(creator))
Example #12
    def model_fn(device_id):
      assert isinstance(device_id, int)
      def thread_creator_fn(next_creator, *args, **kwargs):
        return next_creator(*args, **kwargs) + ":thread_" + str(device_id)

      with variable_scope.variable_creator_scope(thread_creator_fn):
        # Create a variable in this scope.
        v = variable_scope.variable(1.0)

        # This will pause the current thread, and execute the other thread.
        distribute_lib.get_tower_context().merge_call(lambda _: _)
      return v
Example #13
  def testSharedVariable(self):

    shared_variable_store = {}
    num_devices = 3
    creator_fns = []
    for i in range(num_devices):
      creator_fn = shared_variable_creator.make_fn(shared_variable_store, i)
      creator_fns.append(creator_fn)

    with variable_scope.variable_creator_scope(creator_fns[0]):
      v0 = variable_scope.variable(1.0, name="foo")

    with variable_scope.variable_creator_scope(creator_fns[1]):
      v1 = variable_scope.variable(1.0, name="foo")

    with variable_scope.variable_creator_scope(creator_fns[2]):
      v2 = variable_scope.variable(1.0, name="foo")

    # v1 and v2 should be same as v0
    self.assertIs(v1, v0)
    self.assertIs(v2, v0)
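A hedged standalone sketch of the shared-variable-store idea tested above (the real shared_variable_creator.make_fn also canonicalizes uniquified names; this simplified version keys the store on the requested name and assumes TensorFlow 2.x eager execution):

import tensorflow as tf

def make_sharing_creator(store, device_id):
  def creator(next_creator, **kwargs):
    name = kwargs.get("name") or "Variable"
    if device_id == 0:
      store[name] = next_creator(**kwargs)  # device 0 actually creates it
    return store[name]                      # other devices reuse it
  return creator

store = {}
creators = [make_sharing_creator(store, i) for i in range(3)]

created = []
for creator_fn in creators:
  with tf.variable_creator_scope(creator_fn):
    created.append(tf.Variable(1.0, name="foo"))

assert created[1] is created[0] and created[2] is created[0]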
Example #14
  def scope(self):
    """Context manager setting a variable creator and `self` as current."""
    if has_distribution_strategy():
      raise RuntimeError("Must not nest DistributionStrategy scopes.")

    def creator(next_creator, *args, **kwargs):
      _require_distribution_strategy_scope(self)
      if kwargs.pop("tower_local_reduce_method", None) is not None:
        kwargs["trainable"] = False
      return next_creator(*args, **kwargs)

    return _CurrentDistributionContext(
        self, variable_scope.variable_creator_scope(creator))
Example #15
    def testScopeVarCreatorNestingError(self):
        def creator(next_creator, **kwargs):
            return next_creator(**kwargs)

        _assert_in_default_state(self)
        dist = _TestStrategy()
        scope = dist.scope()
        scope.__enter__()
        self.assertIs(dist, ds_context.get_strategy())
        with variable_scope.variable_creator_scope(creator):
            with self.assertRaisesRegex(
                    RuntimeError, "Variable creator scope nesting error"):
                scope.__exit__(None, None, None)
        scope.__exit__(None, None, None)
        _assert_in_default_state(self)
Example #16
def one_host_numpy_dataset(numpy_input, colocate_with, session):
  """Create a dataset on `colocate_with` from `numpy_input`."""
  def create_colocated_variable(next_creator, *args, **kwargs):
    kwargs["colocate_with"] = colocate_with
    return next_creator(*args, **kwargs)

  numpy_flat = nest.flatten(numpy_input)
  with variable_scope.variable_creator_scope(create_colocated_variable):
    vars_flat = tuple(variable_scope.variable(array_ops.zeros(i.shape, i.dtype),
                                              trainable=False)
                      for i in numpy_flat)
  for v, i in zip(vars_flat, numpy_flat):
    init_var_from_numpy(v, i, session)
  vars_nested = nest.pack_sequence_as(numpy_input, vars_flat)
  return dataset_ops.Dataset.from_tensor_slices(vars_nested)
Example #17
    def __init__(self, dataset_fn, coordinator):
        """Makes an iterable from datasets created by the given function.

    Args:
      dataset_fn: A function that returns a `Dataset`.
      coordinator: a `ClusterCoordinator` object, used to create dataset
        resources.
    """
        def disallow_variable_creation(next_creator, **kwargs):
            raise ValueError(
                "Creating variables in `dataset_fn` is not allowed.")

        if isinstance(dataset_fn, def_function.Function):
            with variable_scope.variable_creator_scope(
                    disallow_variable_creation):
                dataset_fn = dataset_fn.get_concrete_function()
        elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
            with variable_scope.variable_creator_scope(
                    disallow_variable_creation):
                dataset_fn = def_function.function(
                    dataset_fn).get_concrete_function()
        self._dataset_fn = dataset_fn
        self._coordinator = coordinator
        self._element_spec = None
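The guard above generalizes to a small standalone sketch (assumes TensorFlow 2.x): a creator that raises makes any variable creation inside the scope fail immediately.

import tensorflow as tf

def disallow_variable_creation(next_creator, **kwargs):
  del next_creator, kwargs
  raise ValueError("Creating variables in `dataset_fn` is not allowed.")

with tf.variable_creator_scope(disallow_variable_creation):
  try:
    tf.Variable(1.0)
  except ValueError as e:
    print(e)  # Creating variables in `dataset_fn` is not allowed.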
Example #18
def one_host_numpy_dataset(numpy_input, colocate_with, session):
  """Create a dataset on `colocate_with` from `numpy_input`."""

  def create_colocated_variable(next_creator, **kwargs):
    kwargs["colocate_with"] = colocate_with
    return next_creator(**kwargs)

  numpy_flat = nest.flatten(numpy_input)
  with variable_scope.variable_creator_scope(create_colocated_variable):
    vars_flat = tuple(variable_scope.variable(array_ops.zeros(i.shape, i.dtype),
                                              trainable=False)
                      for i in numpy_flat)
  for v, i in zip(vars_flat, numpy_flat):
    init_var_from_numpy(v, i, session)
  vars_nested = nest.pack_sequence_as(numpy_input, vars_flat)
  return dataset_ops.Dataset.from_tensor_slices(vars_nested)
Example #19
  def testScopeVarCreatorNestingError(self):

    def creator(next_creator, **kwargs):
      return next_creator(**kwargs)

    _assert_in_default_state(self)
    dist = _TestStrategy()
    scope = dist.scope()
    scope.__enter__()
    self.assertIs(dist, ds_context.get_strategy())
    with variable_scope.variable_creator_scope(creator):
      with self.assertRaisesRegexp(RuntimeError,
                                   "Variable creator scope nesting error"):
        scope.__exit__(None, None, None)
    scope.__exit__(None, None, None)
    _assert_in_default_state(self)
Example #20
def independent_buffers(parallel_device):
    """Context manager which saves parallel buffers independently.

  Creates a ParallelDevice-aware variable subclass which saves buffers for each
  device separately.

  Args:
    parallel_device: A ParallelDevice object on which variables are placed.

  Yields:
    Nothing.
  """
    with variable_scope.variable_creator_scope(
            functools.partial(_variable_creator,
                              parallel_device=parallel_device)):
        yield
Example #21
  def global_step(self):
    if self._global_step is None:
      # Get the default create_global_step utility to actually call
      # self.add_variable, by setting a custom creator.
      def _owned_variable_as_creator(
          next_creator, initial_value, **kwargs):
        def _creator_as_getter(initializer, **kwargs):
          return next_creator(initial_value=initializer, **kwargs)
        return self.add_variable(
            getter=_creator_as_getter, initializer=initial_value, shape=[],
            **kwargs)

      with variable_scope.variable_creator_scope(
          _owned_variable_as_creator):
        self._global_step = training_util.create_global_step()
    return self._global_step
Example #22
def parameter_server_scope(is_chief, ps_job_name, num_ps_tasks):
    """Strategy to use parameter servers in eager.

  Creates SharedVariable objects for variables created in this scope. These
  SharedVariable objects will be placed round-robin on the parameter servers
  specified by the ps_job_name and num_ps_tasks arguments.

  To use parameter servers you need only to wrap your model initialization in
  this scope:

  ```
  with tf.contrib.eager.parameter_server_scope(
      is_chief, ps_job_name, num_ps_tasks):
    my_model = tf.keras.Sequential([...])  # Or
    input = tf.keras.Input(...)
    ....
    my_model = tf.keras.Model(input, output)
  my_model.compile(...)
  # or other usages of the model.
  ```

  Args:
    is_chief: Boolean. Whether this worker is responsible for initializing
      variables.
    ps_job_name: The name of the ps job in this cluster.
    num_ps_tasks: The number of ps tasks to use.

  Yields:
    a context manager.
  """
    # Note: capturing in a list to allow assignment.
    ps_index = [0]

    def variable_creator_scope(unused_next_creator, **kwargs):
        kwargs["initialize"] = is_chief
        with ops.device("/job:%s/task:%s" %
                        (ps_job_name, ps_index[0] % num_ps_tasks)):
            ps_index[0] += 1
            v = SharedVariable(**kwargs)
            if not is_chief:
                while not resource_variable_ops.var_is_initialized_op(
                        v.handle):
                    time.sleep(10)
            return v

    with variable_scope.variable_creator_scope(variable_creator_scope):
        yield
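A hedged sketch of the round-robin placement idea above using only public TensorFlow 2.x APIs (no SharedVariable); the repeated "/cpu:0" entries stand in for real "/job:<ps_job_name>/task:<i>" parameter-server devices:

import itertools
import tensorflow as tf

devices = itertools.cycle(["/cpu:0", "/cpu:0"])  # stand-ins for ps tasks

def round_robin_creator(next_creator, **kwargs):
  with tf.device(next(devices)):  # pin each new variable to the next device
    return next_creator(**kwargs)

with tf.variable_creator_scope(round_robin_creator):
  w = tf.Variable(tf.zeros([2, 2]), name="w")
  b = tf.Variable(tf.zeros([2]), name="b")

print(w.device, b.device)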
Example #23
    def save(self, session=None):
        """Creates a new checkpoint and manages it.

    Args:
      session: The session to evaluate variables in. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.

    Returns:
      The path to the new checkpoint. It is also recorded in the `checkpoints`
      and `latest_checkpoint` properties.
    """
        # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
        # slightly with a custom numbering option.
        if context.executing_eagerly():
            save_counter = self._checkpoint.save_counter
            save_counter.assign_add(1)
            checkpoint_number = save_counter.numpy()
        else:
            if session is None:
                session = ops.get_default_session()

            def _initializing_creator(next_creator, **kwargs):
                """Initialize the save counter if it has been newly created."""
                v = next_creator(**kwargs)
                session.run(v.initializer)
                return v

            with variable_scope.variable_creator_scope(_initializing_creator):
                save_counter = self._checkpoint.save_counter
            if self._save_counter_assign is None:
                self._save_counter_assign = save_counter.assign_add(
                    1, read_value=True)
            checkpoint_number = session.run(self._save_counter_assign)
        prefix = "%s-%d" % (self._prefix, checkpoint_number)
        save_path = self._checkpoint.write(prefix)
        timestamp = time.time()
        # If this is an overwritten checkpoint we were previously tracking, delete
        # and reinsert it to make sure it goes to the end of the queue.
        if save_path in self._maybe_delete:
            del self._maybe_delete[save_path]
        self._maybe_delete[save_path] = timestamp
        self._latest_checkpoint = save_path
        self._sweep()
        self._record_state()
        return save_path
Example #24
def parameter_server_scope(is_chief, ps_job_name, num_ps_tasks):
  """Strategy to use parameter servers in eager.

  Creates SharedVariable objects for variables created in this scope. These
  SharedVariable objects will be placed round-robin on the parameter servers
  specified by the ps_job_name and num_ps_tasks arguments.

  To use parameter servers you need only to wrap your model initialization in
  this scope:

  ```
  with tf.contrib.eager.parameter_server_scope(
      is_chief, ps_job_name, num_ps_tasks):
    my_model = tf.keras.Sequential([...])  # Or
    input = tf.keras.Input(...)
    ....
    my_model = tf.keras.Model(input, output)
  my_model.compile(...)
  # or other usages of the model.
  ```

  Args:
    is_chief: Boolean. Whether this worker is responsible for initializing
      variables.
    ps_job_name: The name of the ps job in this cluster.
    num_ps_tasks: The number of ps tasks to use.

  Yields:
    a context manager.
  """
  # Note: capturing in a list to allow assignment.
  ps_index = [0]

  def variable_creator_scope(unused_next_creator, **kwargs):
    kwargs["initialize"] = is_chief
    with ops.device(
        "/job:%s/task:%s" % (ps_job_name, ps_index[0] % num_ps_tasks)):
      ps_index[0] += 1
      v = SharedVariable(**kwargs)
      if not is_chief:
        while not resource_variable_ops.var_is_initialized_op(v.handle):
          time.sleep(10)
      return v

  with variable_scope.variable_creator_scope(variable_creator_scope):
    yield
Example #25
    def test_keras_layer_add_weight(self):
        class Layer(base_layer.Layer):
            def __init__(self):
                super().__init__()
                self.w = self.add_weight(
                    shape=(2, ),
                    initializer=lambda shape, dtype: [0, 1],
                    trainable=True)
                self.b = self.add_weight(
                    shape=(2, ),
                    initializer=lambda shape, dtype: [2, 3],
                    trainable=False)

        def sharded_variable_creator(next_creator, **kwargs):
            v1_value = kwargs['initial_value']()[0:1]
            v2_value = kwargs['initial_value']()[1:]

            kwargs['initial_value'] = v1_value
            kwargs['shape'] = (1, )
            v1 = next_creator(**kwargs)

            kwargs['initial_value'] = v2_value
            kwargs['shape'] = (1, )
            v2 = next_creator(**kwargs)

            return sharded_variable.ShardedVariable([v1, v2])

        with variable_scope.variable_creator_scope(sharded_variable_creator):
            layer = Layer()

        self.assertLen(layer.trainable_weights, 2)
        self.assertEqual(layer.trainable_weights[0], [0])
        self.assertEqual(layer.trainable_weights[1], [1])
        self.assertLen(layer.non_trainable_weights, 2)
        self.assertEqual(layer.non_trainable_weights[0], [2])
        self.assertEqual(layer.non_trainable_weights[1], [3])
        self.assertAllEqual(
            layer.weights,
            layer.trainable_weights + layer.non_trainable_weights)
        self.assertAllEqual(layer.trainable_weights, layer.trainable_variables)
        self.assertAllEqual(layer.weights, layer.variables)

        checkpoint_deps = set(dep.ref
                              for dep in layer._checkpoint_dependencies)
        self.assertEqual(checkpoint_deps, set([layer.w, layer.b]))
Example #26
    def testVariableCreatorScope(self):
        created_variables = []
        captured_variables = []

        @def_function.function
        def f():
            if not created_variables:
                created_variables.append(variables.Variable(1.))
            return created_variables[0] + 1.

        def capture_creator(next_creator, **kwargs):
            created = next_creator(**kwargs)
            captured_variables.append(created)
            return created

        with variable_scope.variable_creator_scope(capture_creator):
            f()
        self.assertEqual(created_variables, captured_variables)
Example #27
  def testVariableCreatorScope(self):
    created_variables = []
    captured_variables = []

    @def_function.function
    def f():
      if not created_variables:
        created_variables.append(variables.Variable(1.))
      return created_variables[0] + 1.

    def capture_creator(next_creator, **kwargs):
      created = next_creator(**kwargs)
      captured_variables.append(created)
      return created

    with variable_scope.variable_creator_scope(capture_creator):
      f()
    self.assertEqual(created_variables, captured_variables)
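The same behavior can be reproduced with public APIs; a short sketch (assumes TensorFlow 2.x): a creator scope entered outside a tf.function still observes the variable created during the function's first trace.

import tensorflow as tf

captured = []

def capture_creator(next_creator, **kwargs):
  v = next_creator(**kwargs)
  captured.append(v)
  return v

state = []

@tf.function
def f():
  if not state:
    state.append(tf.Variable(1.0))  # created on the first trace only
  return state[0] + 1.0

with tf.variable_creator_scope(capture_creator):
  f()

assert len(captured) == 1 and captured[0] is state[0]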
Example #28
    def global_step(self):
        if self._global_step is None:
            # Get the default create_global_step utility to actually call
            # self.add_variable, by setting a custom creator.
            def _owned_variable_as_creator(next_creator, initial_value,
                                           **kwargs):
                def _creator_as_getter(initializer, **kwargs):
                    return next_creator(initial_value=initializer, **kwargs)

                return self.add_variable(getter=_creator_as_getter,
                                         initializer=initial_value,
                                         shape=[],
                                         **kwargs)

            with variable_scope.variable_creator_scope(
                    _owned_variable_as_creator):
                self._global_step = training_util.create_global_step()
        return self._global_step
Example #29
    def testVariableCreatingCustomGetter(self, variable_type, stack_entries):
        use_resource = variable_type == "ResourceVariable"

        if tf.executing_eagerly() and not use_resource:
            self.skipTest("Ref variables not supported in eager mode.")

        def my_custom_getter(getter, **kwargs):
            var = getter(**kwargs)
            # Create an additional variable in the getter which is not returned.
            kwargs["name"] += "_additional"
            getter(**kwargs)
            return var

        variables = []

        with contextlib2.ExitStack() as stack:
            stack.enter_context(
                tf.variable_scope("", use_resource=use_resource))
            for stack_entry in stack_entries:
                if stack_entry == "notify":
                    stack.enter_context(
                        util.notify_about_variables(variables.append))
                elif stack_entry == "custom_getter":
                    stack.enter_context(
                        tf.variable_scope("", custom_getter=my_custom_getter))
                elif stack_entry == "variable_creator":
                    stack.enter_context(
                        variable_scope_ops.variable_creator_scope(
                            my_custom_getter))
                else:
                    raise AssertionError

            v = tf.get_variable("v", [])

        self.assertVariableType(v, use_resource)
        if stack_entries == ["variable_creator", "notify"]:
            # When a variable creator is entered before `notify_about_variables` there
            # is no way for us to identify what additional variables that creator
            # created.
            self.assertEqual([v.name for v in variables], [u"v:0"])
        else:
            self.assertEqual([v.name for v in variables],
                             [u"v:0", u"v_additional:0"])
Example #30
  def call(self, inputs, mask=None, training=None):
    # We must copy for thread safety, but it only needs to be a shallow copy.
    kwargs = {k: v for k, v in self.arguments.items()}
    if self._fn_expects_mask_arg:
      kwargs['mask'] = mask
    if self._fn_expects_training_arg:
      kwargs['training'] = training

    created_variables = []
    def _variable_creator(next_creator, **kwargs):
      var = next_creator(**kwargs)
      created_variables.append(var)
      return var

    with backprop.GradientTape(watch_accessed_variables=True) as tape,\
        variable_scope.variable_creator_scope(_variable_creator):
      result = self.function(inputs, **kwargs)
    self._check_variables(created_variables, tape.watched_variables())
    return result
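A hedged standalone sketch of the check performed above (the real layer also compares against tape.watched_variables()); check_creates_no_variables is a hypothetical helper, assumes TensorFlow 2.x:

import tensorflow as tf

def check_creates_no_variables(fn, *args):
  """Call `fn(*args)` and raise if it creates any tf.Variable."""
  created = []

  def recording_creator(next_creator, **kwargs):
    v = next_creator(**kwargs)
    created.append(v)
    return v

  with tf.variable_creator_scope(recording_creator):
    out = fn(*args)
  if created:
    raise ValueError("fn created variables: %s" % [v.name for v in created])
  return out

print(check_creates_no_variables(lambda x: x * 2.0, tf.constant(3.0)))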
Example #31
  def colocate_vars_with(self, colocate_with_variable):
    """Scope that controls which devices variables will be created on.

    No operations should be added to the graph inside this scope, it
    should only be used when creating variables (some implementations
    work by changing variable creation, others work by using a
    tf.colocate_with() scope).

    This may only be used inside `self.scope()`.

    Example usage:

    ```
    with distribution_strategy.scope():
      var1 = tf.get_variable(...)
      with distribution_strategy.colocate_vars_with(v1):
        # var2 and var3 will be created on the same device(s) as var1
        var2 = tf.get_variable(...)
        var3 = tf.get_variable(...)

      def fn(v1, v2, v3):
        # operates on v1 from var1, v2 from var2, and v3 from var3

      # `fn` runs on every device `v1` is on, `v2` and `v3` will be there too.
      distribution_strategy.update(v1, fn, v2, v3)
    ```

    Args:
      colocate_with_variable: A variable created in `self.scope()`. Variables created
        while in the returned context manager will be on the same set of
        devices as `colocate_with_variable`.

    Returns:
      A context manager.
    """
    def create_colocated_variable(next_creator, *args, **kwargs):
      _require_distribution_strategy_scope(self)
      kwargs["use_resource"] = True
      kwargs["colocate_with"] = colocate_with_variable
      return next_creator(*args, **kwargs)

    _require_distribution_strategy_scope(self)
    return variable_scope.variable_creator_scope(create_colocated_variable)
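A hedged TensorFlow 2.x sketch of the usage described in the docstring above, via the public tf.distribute.StrategyExtended.colocate_vars_with API (the single-CPU device list is a stand-in for a real multi-device setup):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
with strategy.scope():
  var1 = tf.Variable(1.0)
  with strategy.extended.colocate_vars_with(var1):
    var2 = tf.Variable(2.0)  # placed on the same device(s) as var1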
Example #32
  def colocate_vars_with(self, colocate_with_variable):
    """Scope that controls which devices variables will be created on.

    No operations should be added to the graph inside this scope, it
    should only be used when creating variables (some implementations
    work by changing variable creation, others work by using a
    tf.colocate_with() scope).

    This may only be used inside `self.scope()`.

    Example usage:

    ```
    with distribution_strategy.scope():
      var1 = tf.get_variable(...)
      with distribution_strategy.colocate_vars_with(v1):
        # var2 and var3 will be created on the same device(s) as var1
        var2 = tf.get_variable(...)
        var3 = tf.get_variable(...)

      def fn(v1, v2, v3):
        # operates on v1 from var1, v2 from var2, and v3 from var3

      # `fn` runs on every device `v1` is on, `v2` and `v3` will be there too.
      distribution_strategy.update(v1, fn, v2, v3)
    ```

    Args:
      colocate_with_variable: A variable created in `self.scope()`. Variables created
        while in the returned context manager will be on the same set of
        devices as `colocate_with_variable`.

    Returns:
      A context manager.
    """
    def create_colocated_variable(next_creator, *args, **kwargs):
      _require_distribution_strategy_scope(self)
      kwargs["use_resource"] = True
      kwargs["colocate_with"] = colocate_with_variable
      return next_creator(*args, **kwargs)

    _require_distribution_strategy_scope(self)
    return variable_scope.variable_creator_scope(create_colocated_variable)
Example #33
    def save_v2(self, filepath):
        # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
        # slightly with a custom numbering option.
        if context.executing_eagerly():
            save_counter = self._checkpoint.save_counter
            save_counter.assign_add(1)
            session = None
        else:
            session = ops.get_default_session()

            def _initializing_creator(next_creator, **kwargs):
                """Initialize the save counter if it has been newly created."""
                v = next_creator(**kwargs)
                session.run(v.initializer)
                return v

            with variable_scope.variable_creator_scope(_initializing_creator):
                save_counter = self._checkpoint.save_counter
            if self._save_counter_assign is None:
                self._save_counter_assign = save_counter.assign_add(
                    1, read_value=False)
            session.run(self._save_counter_assign)
        save_path = self._checkpoint.write(filepath)
        timestamp = time.time()
        # If this is an overwritten checkpoint we were previously tracking, delete
        # and reinsert it to make sure it goes to the end of the queue.
        if save_path in self._maybe_delete:
            del self._maybe_delete[save_path]
        self._maybe_delete[save_path] = timestamp
        self._latest_checkpoint = save_path
        # Before deleting anything we update the Checkpoint proto with the new
        # checkpoint. We'll go back and correct it after cleaning up old files, but
        # a preemption while deleting will be more likely to see the new checkpoint
        # this way.
        self._record_state()
        self._sweep()
        # Write out the Checkpoint proto a second time, now without the deleted
        # checkpoints.
        self._record_state()
        return save_path
Example #34
  def testVariableCreatingCustomGetter(self, variable_type, stack_entries):
    use_resource = variable_type == "ResourceVariable"

    if tf.executing_eagerly() and not use_resource:
      self.skipTest("Ref variables not supported in eager mode.")

    def my_custom_getter(getter, **kwargs):
      var = getter(**kwargs)
      # Create an additional variable in the getter which is not returned.
      kwargs["name"] += "_additional"
      getter(**kwargs)
      return var

    variables = []

    with contextlib2.ExitStack() as stack:
      stack.enter_context(tf.variable_scope("", use_resource=use_resource))
      for stack_entry in stack_entries:
        if stack_entry == "notify":
          stack.enter_context(util.notify_about_variables(variables.append))
        elif stack_entry == "custom_getter":
          stack.enter_context(
              tf.variable_scope("", custom_getter=my_custom_getter))
        elif stack_entry == "variable_creator":
          stack.enter_context(
              variable_scope_ops.variable_creator_scope(my_custom_getter))
        else:
          raise AssertionError

      v = tf.get_variable("v", [])

    self.assertVariableType(v, use_resource)
    if stack_entries == ["variable_creator", "notify"]:
      # When a variable creator is entered before `notify_about_variables` there
      # is no way for us to identify what additional variables that creator
      # created.
      self.assertEqual([v.name for v in variables], [u"v:0"])
    else:
      self.assertEqual([v.name for v in variables], [u"v:0", u"v_additional:0"])
Example #35
    def scope(self):
        """Returns a context manager selecting this DistributionStrategy as current.

    Inside a `with distribution_strategy.scope():` code block, this thread
    will use a variable creator set by `distribution_strategy`, and will
    enter its "cross-tower context".

    Returns:
      A context manager.
    """
        if has_distribution_strategy():
            _require_cross_tower_context(self)
            return _SameScopeAgainContext(self)

        def creator_with_resource_vars(*args, **kwargs):
            _require_distribution_strategy_scope(self)
            kwargs["use_resource"] = True
            return self._create_variable(*args, **kwargs)

        return _CurrentDistributionContext(
            self,
            variable_scope.variable_creator_scope(creator_with_resource_vars))
Example #36
 def run(self):
   # pylint: disable=protected-access
   self.graph._variable_creator_stack = self._variable_creator_stack
   self.should_run.wait()
   self.should_run.clear()
   try:
     if self.coord.should_stop():
       return
     with self.coord.stop_on_exception(), \
         context.context()._mode(self.context_mode), \
         context.context().device_policy(self.context_device_policy), \
         _enter_graph(self.graph), \
         MirroredTowerContext(self.distribution, self.tower_id), \
         ops.device(self.device), \
         ops.name_scope(self._name_scope), \
         variable_scope.variable_scope(
             self._captured_var_scope, reuse=self.tower_id > 0), \
         variable_scope.variable_creator_scope(self.variable_creator_fn):
       self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
       self.done = True
   finally:
     self.has_paused.set()
Example #37
 def run(self):
   # pylint: disable=protected-access
   self.graph._variable_creator_stack = self._variable_creator_stack
   self.should_run.wait()
   self.should_run.clear()
   try:
     if self.coord.should_stop():
       return
     with self.coord.stop_on_exception(), \
         context.context()._mode(self.context_mode), \
         context.context().device_policy(self.context_device_policy), \
         _enter_graph(self.graph), \
         MirroredTowerContext(self.distribution, self.tower_id), \
         ops.device(self.device), \
         ops.name_scope(self._name_scope), \
         variable_scope.variable_scope(
             self._captured_var_scope, reuse=self.tower_id > 0), \
         variable_scope.variable_creator_scope(self.variable_creator_fn):
       self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
       self.done = True
   finally:
     self.has_paused.set()
Example #38
    def test_tf_optimizer_with_sparse_gradient_using_keras(self):
        import tensorflow as tf

        ids = np.random.randint(0, 10, size=[40])
        labels = np.random.randint(0, 5, size=[40])
        id_rdd = self.sc.parallelize(ids)
        label_rdd = self.sc.parallelize(labels)
        training_rdd = id_rdd.zip(label_rdd).map(lambda x: [x[0], x[1]])
        with tf.Graph().as_default():
            dataset = TFDataset.from_rdd(training_rdd,
                                         names=["ids", "labels"],
                                         shapes=[[], []],
                                         types=[tf.int32, tf.int32],
                                         batch_size=8)
            from tensorflow.python.ops import variable_scope

            def variable_creator(**kwargs):
                kwargs["use_resource"] = False
                return variable_scope.default_variable_creator(None, **kwargs)

            getter = lambda next_creator, **kwargs: variable_creator(**kwargs)
            with variable_scope.variable_creator_scope(getter):
                words_input = tf.keras.layers.Input(shape=(),
                                                    name='words_input')
                embedding_layer = tf.keras.layers.Embedding(
                    input_dim=10, output_dim=5, name='word_embedding')
                word_embeddings = embedding_layer(words_input)
                embedding = tf.keras.layers.Flatten()(word_embeddings)
                output = tf.keras.layers.Dense(5,
                                               activation="softmax")(embedding)
                model = tf.keras.models.Model(inputs=[words_input],
                                              outputs=[output])
                model.compile(optimizer="sgd",
                              loss="sparse_categorical_crossentropy")

            optimizer = TFOptimizer.from_keras(model, dataset)
            optimizer.optimize(end_trigger=MaxEpoch(1))
            optimizer.sess.close()
Example #39
 def run(self):
   # pylint: disable=protected-access
   self.should_run.wait()
   self.should_run.clear()
   try:
     if self.coord.should_stop():
       return
     with self.coord.stop_on_exception(), \
         _enter_graph(self._init_graph, self._init_in_eager), \
         _enter_graph(self.graph, self.in_eager,
                      self._variable_creator_stack), \
         context.context().device_policy(self.context_device_policy), \
         MirroredReplicaContext(self.distribution, constant_op.constant(
             self.replica_id, dtypes.int32)), \
         ops.device(self.device), \
         ops.name_scope(self._name_scope), \
         variable_scope.variable_scope(
             self._captured_var_scope, reuse=self.replica_id > 0), \
         variable_scope.variable_creator_scope(self.variable_creator_fn):
       self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
       self.done = True
   finally:
     self.has_paused.set()
Example #40
    def tower_local_var_scope(self, reduce_method):
        """Inside this scope, new variables will not be mirrored.

    There will still be one component variable per tower, but there is
    no requirement that they stay in sync. Instead, when saving them
    or calling `fetch()`, we use the value that results when calling
    `reduce()` on all the towers' variables.

    Note: tower-local implies not trainable. Instead, it is expected
    that each tower will directly update (using `assign_add()` or
    whatever) its local variable instance but only the aggregated
    value (accessible using `fetch()`) will be exported from the
    model. When it is acceptable to only aggregate on export, we
    greatly reduce communication overhead by using tower-local
    variables.

    Note: All component variables will be initialized to the same
    value, using the initialization expression from the first tower.
    The values will match even if the initialization expression uses
    random numbers.

    Args:
      reduce_method: String used as a `method_string` to `reduce()`
        to get the value to save when checkpointing.

    Returns:
      A context manager.
    """
        def create_tower_local_variable(next_creator, *args, **kwargs):
            _require_distribution_strategy_scope(self)
            kwargs["use_resource"] = True
            kwargs["tower_local_reduce_method"] = reduce_method
            return next_creator(*args, **kwargs)

        _require_distribution_strategy_scope(self)
        return variable_scope.variable_creator_scope(
            create_tower_local_variable)
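A hedged TensorFlow 2.x sketch of the "tower-local" behavior described above; the modern analog is a sync-on-read variable, which each replica updates locally and which is aggregated only when read or saved (assumes TF 2.2+ for Strategy.run; the single-CPU device list is a stand-in):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
with strategy.scope():
  total = tf.Variable(
      0.0,
      trainable=False,  # tower-local implies not trainable
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)

def replica_fn():
  total.assign_add(1.0)  # each replica updates its own copy

strategy.run(replica_fn)
print(total.read_value())  # aggregated (summed) across replicas when read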
Example #41
 def run(self):
   self.should_run.wait()
   self.should_run.clear()
   try:
     if self.coord.should_stop():
       return
     # TODO(josh11b): Use current logical device instead of 0 here.
     with self.coord.stop_on_exception(), \
         _enter_graph(self._init_graph, self._init_in_eager), \
         _enter_graph(self.graph, self.in_eager,
                      self._variable_creator_stack), \
         context.context().device_policy(self.context_device_policy), \
         MirroredReplicaContext(self.distribution, constant_op.constant(
             self.replica_id, dtypes.int32)), \
         ops.device(self.device_map.logical_to_actual_devices(0)[
             self.replica_id]), \
         ops.name_scope(self._name_scope), \
         variable_scope.variable_scope(
             self._captured_var_scope, reuse=self.replica_id > 0), \
         variable_scope.variable_creator_scope(self.variable_creator_fn):
       self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
       self.done = True
   finally:
     self.has_paused.set()
Example #42
  def tower_local_var_scope(self, reduce_method):
    """Inside this scope, new variables will not be mirrored.

    There will still be one component variable per tower, but there is
    no requirement that they stay in sync. Instead, when saving them
    or calling `fetch()/read_var()`, we use the value that
    results when calling `reduce()` on all the towers' variables.

    Note: tower-local implies not trainable. Instead, it is expected
    that each tower will directly update (using `assign_add()` or
    whatever) its local variable instance but only the aggregated
    value (accessible using `fetch()`) will be exported from the
    model. When it is acceptable to only aggregate on export, we
    greatly reduce communication overhead by using tower-local
    variables.

    Note: All component variables will be initialized to the same
    value, using the initialization expression from the first tower.
    The values will match even if the initialization expression uses
    random numbers.

    Args:
      reduce_method: String used as a `method_string` to `reduce()`
        to get the value to save when checkpointing.

    Returns:
      A context manager.
    """
    def create_tower_local_variable(next_creator, *args, **kwargs):
      _require_distribution_strategy_scope(self)
      kwargs["use_resource"] = True
      kwargs["tower_local_reduce_method"] = reduce_method
      return next_creator(*args, **kwargs)

    _require_distribution_strategy_scope(self)
    return variable_scope.variable_creator_scope(create_tower_local_variable)
Example #43
    def call(self, inputs, mask=None, training=None):
        # Disallow two variables with the same name.
        kwargs = {}
        if self._fn_expects_mask_arg:
            kwargs['mask'] = mask
        if self._fn_expects_training_arg:
            kwargs['training'] = training

        call_fn = self._function_with_args
        if kwargs:
            call_fn = functools.partial(call_fn, **kwargs)

        created_variables = []

        def _variable_creator(next_creator, **kwargs):
            var = next_creator(**kwargs)
            created_variables.append(var)
            return var

        with backprop.GradientTape(watch_accessed_variables=True) as tape,\
            variable_scope.variable_creator_scope(_variable_creator):
            result = call_fn(inputs)
        self._check_variables(created_variables, tape.watched_variables())
        return result
Example #44
 def wrapped_fn(*args, **kwds):
   with variable_scope.variable_creator_scope(scope):
     # __wrapped__ allows AutoGraph to swap in a converted function.
     return wrapped_fn.__wrapped__(*args, **kwds)
Example #45
 def wrapped_fn(*args, **kwds):
   with variable_scope.variable_creator_scope(scope):
     # __wrapped__ allows AutoGraph to swap in a converted function. We give
     # the function a weak reference to itself to avoid a reference cycle.
     return weak_wrapped_fn().__wrapped__(*args, **kwds)
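The wrapped_fn pattern above (repeated in several later examples) can be packaged as a decorator; a hedged sketch, where with_creator and logging_creator are hypothetical names (assumes TensorFlow 2.x):

import functools
import tensorflow as tf

def with_creator(creator):
  """Run the decorated function under a fixed variable creator scope."""
  def decorator(fn):
    @functools.wraps(fn)
    def wrapped_fn(*args, **kwds):
      with tf.variable_creator_scope(creator):
        return fn(*args, **kwds)
    return wrapped_fn
  return decorator

def logging_creator(next_creator, **kwargs):
  print("creating variable:", kwargs.get("name"))
  return next_creator(**kwargs)

@with_creator(logging_creator)
def build():
  return tf.Variable(3.0, name="built_here")

build()  # prints: creating variable: built_here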
Example #46
  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to global variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation.  Default to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.
    """
    local_vars = [v for g, v in grads_and_vars if g is not None]
    grads = [g for g, v in grads_and_vars if g is not None]

    def _variable_creator(next_creator, collections, **kwargs):
      if not collections:
        collections = [ops.GraphKeys.LOCAL_VARIABLES]
      elif ops.GraphKeys.GLOBAL_VARIABLES in collections:
        collections = list(collections)
        collections.append(ops.GraphKeys.LOCAL_VARIABLES)
        collections.remove(ops.GraphKeys.GLOBAL_VARIABLES)
      return next_creator(collections=collections, **kwargs)

    # theta = theta - lr * grad
    with variable_scope.variable_creator_scope(_variable_creator):
      local_update_op = self._opt.apply_gradients(grads_and_vars)

    # a = a + grad
    update_ops = []
    update_ops.append(local_update_op)
    grad_vars = [self._grad_map[var] for var in local_vars]
    for g, grad_var in zip(grads, grad_vars):
      update_ops.append(state_ops.assign_add(grad_var, g))

    global_center_vars = [self._global_map[var] for var in local_vars]

    # update global variables.
    def _Update_global_variables():
      global_norm = []
      # a = a / t
      for g in grad_vars:
        global_norm.append(state_ops.assign(g, g / self._period))
      # apply
      with ops.control_dependencies(global_norm):
        apply_global_op = self._opt.apply_gradients(
            zip(grad_vars, global_center_vars))

      # pull
      with ops.control_dependencies([apply_global_op]):
        update_ops = []
        if global_step:
          with ops.colocate_with(global_step):
            update_ops.append(state_ops.assign_add(global_step, 1))

        for lvar in local_vars:
          g_val = self._global_map[lvar].read_value()
          update_ops.append(state_ops.assign(lvar, g_val))
        for grad_var in grad_vars:
          update_ops.append(
              state_ops.assign(grad_var, array_ops.zeros_like(grad_var)))
        variable_update = control_flow_ops.group(*(update_ops))
      return variable_update

    local_update = state_ops.assign_add(
        self._local_step, 1, name='local_step_update').op

    with ops.control_dependencies([local_update]):
      condition = math_ops.equal(
          math_ops.mod(self._local_step, self._period), 0)
    with ops.control_dependencies(update_ops):
      conditional_update = control_flow_ops.cond(
          condition, _Update_global_variables, control_flow_ops.no_op)
    return conditional_update
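A hedged, self-contained sketch of the collections-rewriting creator used above (TF1-style collections via tf.compat.v1; assumes TensorFlow 2.x with v1 compatibility): variables created inside the scope land in LOCAL_VARIABLES instead of GLOBAL_VARIABLES.

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

def local_collections_creator(next_creator, collections=None, **kwargs):
  if not collections:
    collections = [tf.compat.v1.GraphKeys.LOCAL_VARIABLES]
  elif tf.compat.v1.GraphKeys.GLOBAL_VARIABLES in collections:
    collections = list(collections)
    collections.append(tf.compat.v1.GraphKeys.LOCAL_VARIABLES)
    collections.remove(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
  return next_creator(collections=collections, **kwargs)

with tf.variable_creator_scope(local_collections_creator):
  slot = tf.compat.v1.Variable(0.0, name="slot")

print([v.name for v in tf.compat.v1.local_variables()])   # includes "slot:0"
print([v.name for v in tf.compat.v1.global_variables()])  # does not include it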
Example #47
 def wrapped_fn(*args, **kwds):
   with variable_scope.variable_creator_scope(scope):
     return fn(*args, **kwds)
Example #48
 def wrapped(*args, **kwargs):
     with variable_scope.variable_creator_scope(
             self.variable_creator_scope):
         return fn(*args, **kwargs)
Example #49
 def wrapped_fn(*args, **kwds):
     with variable_scope.variable_creator_scope(scope):
         # __wrapped__ allows AutoGraph to swap in a converted function.
         return wrapped_fn.__wrapped__(*args, **kwds)
Example #50
 def call(self, inputs, mask=None):
   arguments = self.arguments
   if generic_utils.has_arg(self.function, 'mask'):
     arguments['mask'] = mask
   with variable_scope.variable_creator_scope(self._variable_creator):
     return self.function(inputs, **arguments)
Example #51
 def wrapped_fn(*args, **kwds):
   with variable_scope.variable_creator_scope(scope):
     return fn(*args, **kwds)
Example #52
 def wrapped_fn(*args, **kwds):
     with variable_scope.variable_creator_scope(scope):
         # __wrapped__ allows AutoGraph to swap in a converted function. We give
         # the function a weak reference to itself to avoid a reference cycle.
         return weak_wrapped_fn().__wrapped__(*args, **kwds)
Example #53
    def testOptimizerInsideModelFn(self, distribution, optimizer_fn, is_tpu):
        created_variables = []
        trainable_variables = []

        def appending_creator(next_creator, *args, **kwargs):
            v = next_creator(*args, **kwargs)
            created_variables.append(v.name)
            if "trainable" in kwargs and kwargs["trainable"]:
                trainable_variables.append(v.name)
            return v

        # Creator scope needs to be set before it's used inside
        # `distribution.scope`.
        with variable_scope.variable_creator_scope(
                appending_creator), distribution.scope():
            model_fn, dataset_fn, layer = minimize_loss_example(
                optimizer_fn,
                use_bias=True,
                use_callable_loss=True,
                create_optimizer_inside_model_fn=True)

            iterator = distribution.distribute_dataset(
                dataset_fn).make_one_shot_iterator()

            def run_step():
                return distribution.group(
                    distribution.call_for_each_tower(
                        model_fn,
                        iterator.get_next(),
                        run_concurrently=layer.built))

            if not context.executing_eagerly():
                with self.test_session() as sess:
                    if is_tpu:
                        sess.run(tpu.initialize_system())
                    run_step = sess.make_callable(run_step())
            self.evaluate(variables_lib.global_variables_initializer())

            run_step()

            if is_tpu:
                with self.test_session() as sess:
                    sess.run(tpu.shutdown_system())

            def get_expected_variables(optimizer_fn, num_parameter_devices):
                variables_map = {
                    "GradientDescent": ["dense/kernel", "dense/bias"],
                    "Adam": [
                        "dense/kernel", "dense/bias", "beta1_power",
                        "beta2_power", "dense/kernel/Adam",
                        "dense/kernel/Adam_1", "dense/bias/Adam",
                        "dense/bias/Adam_1"
                    ]
                }
                variables = variables_map[optimizer_fn().get_name()]
                variables.extend([
                    v + "/replica_{}".format(replica) for v in variables
                    for replica in range(1, num_parameter_devices)
                ])
                return set([v + ":0" for v in variables])

            self.assertEqual(
                get_expected_variables(optimizer_fn,
                                       len(distribution.parameter_devices)),
                set(created_variables))
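As the comment in Example #53 notes, the creator scope has to be entered before `distribution.scope()` so that variables built under the strategy (including per-replica copies such as `dense/kernel/replica_1`) still pass through the creator. A rough TF2 sketch of that ordering, assuming the public `tf.distribute.MirroredStrategy` and Keras APIs; the device list and layer are arbitrary:

import tensorflow as tf

created = []

def recording_creator(next_creator, **kwargs):
  v = next_creator(**kwargs)
  created.append(v.name)
  return v

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
# Creator scope first (outermost), then the strategy scope, as in the test.
with tf.variable_creator_scope(recording_creator), strategy.scope():
  layer = tf.keras.layers.Dense(1)
  layer.build((None, 4))

print(sorted(created))  # e.g. ['dense/bias:0', 'dense/kernel:0']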
Example #54
0
def lazy_init_scope():
  with variable_scope.variable_creator_scope(_lazy_init_variable_creator):
    yield
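Example #54 is a generator that yields once, i.e. the body of a context manager. A self-contained sketch of that shape, with a hypothetical pass-through creator standing in for `_lazy_init_variable_creator`:

import contextlib
import tensorflow as tf

def _pass_through_creator(next_creator, **kwargs):
  # Hypothetical stand-in: a real lazy-init creator would customise or defer
  # construction instead of simply forwarding.
  return next_creator(**kwargs)

@contextlib.contextmanager
def lazy_init_scope():
  with tf.variable_creator_scope(_pass_through_creator):
    yield

with lazy_init_scope():
  v = tf.Variable(1.0)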
Example #55
0
 def __call__(self, *args, **kwargs):
   with variable_scope.variable_creator_scope(self.variable_creator_scope):
     return self._fn(*args, **kwargs)
Example #56
0
 def wrapped(*args, **kwargs):
   with variable_scope.variable_creator_scope(self.variable_creator_scope):
     return fn(*args, **kwargs)
Example #57
0
 def __call__(self, *args, **kwargs):
     with variable_scope.variable_creator_scope(
             self.variable_creator_scope):
         return self._fn(*args, **kwargs)
Example #58
0
  def save(self, checkpoint_number=None, check_interval=True, options=None):
    """Creates a new checkpoint and manages it.

    Args:
      checkpoint_number: An optional integer, or an integer-dtype `Variable` or
        `Tensor`, used to number the checkpoint. If `None` (default),
        checkpoints are numbered using `checkpoint.save_counter`. Even if
        `checkpoint_number` is provided, `save_counter` is still incremented. A
        user-provided `checkpoint_number` is not incremented even if it is a
        `Variable`.
      check_interval: An optional boolean. The argument is only effective when
        `checkpoint_interval` is passed into the manager. If `True`, the manager
        will only save the checkpoint if the interval between checkpoints is
        larger than `checkpoint_interval`. Otherwise it will always save the
        checkpoint unless a checkpoint has already been saved for the current
        step.
      options: Optional `tf.train.CheckpointOptions` object. This argument only
        works with TF2 checkpoint objects. For example, options =
        tf.train.CheckpointOptions(experimental_io_device='/job:localhost')

    Returns:
      The path to the new checkpoint. It is also recorded in the `checkpoints`
      and `latest_checkpoint` properties. `None` if no checkpoint is saved.
    """
    if self._checkpoint_interval is not None:
      current_step = _evaluate(self._step_counter)
      if self._last_checkpoint_step is not None:
        if current_step == self._last_checkpoint_step:
          return None
        if check_interval and current_step < (
            self._last_checkpoint_step + self._checkpoint_interval):
          return None
      self._last_checkpoint_step = current_step

    # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
    # slightly with a custom numbering option.
    if context.executing_eagerly():
      save_counter = self._checkpoint.save_counter
      save_counter.assign_add(1)
      session = None
    else:
      session = ops.get_default_session()

      def _initializing_creator(next_creator, **kwargs):
        """Initialize the save counter if it has been newly created."""
        v = next_creator(**kwargs)
        session.run(v.initializer)
        return v

      with variable_scope.variable_creator_scope(_initializing_creator):
        save_counter = self._checkpoint.save_counter
      if self._save_counter_assign is None:
        self._save_counter_assign = save_counter.assign_add(1, read_value=False)
      session.run(self._save_counter_assign)
    if checkpoint_number is None:
      checkpoint_number = save_counter
    if not isinstance(checkpoint_number, compat.integral_types):
      checkpoint_number = training_util.global_step(
          sess=session, global_step_tensor=checkpoint_number)
    prefix = "%s-%d" % (self._prefix, checkpoint_number)
    if options is None:
      save_path = self._checkpoint.write(prefix)
    else:
      save_path = self._checkpoint.write(prefix, options=options)
    timestamp = time.time()
    # If this is an overwritten checkpoint we were previously tracking, delete
    # and reinsert it to make sure it goes to the end of the queue.
    if save_path in self._maybe_delete:
      del self._maybe_delete[save_path]
    self._maybe_delete[save_path] = timestamp
    self._latest_checkpoint = save_path
    # Before deleting anything we update the Checkpoint proto with the new
    # checkpoint. We'll go back and correct it after cleaning up old files, but
    # a preemption while deleting will be more likely to see the new checkpoint
    # this way.
    self._record_state()
    self._sweep()
    # Write out the Checkpoint proto a second time, now without the deleted
    # checkpoints.
    self._record_state()
    return save_path
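A brief usage sketch for the `save` method above, assuming TF2's public `tf.train.Checkpoint`/`tf.train.CheckpointManager` API; the directory, step counts and interval here are made up:

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(
    ckpt, directory="/tmp/ckpt_demo", max_to_keep=3,
    step_counter=step, checkpoint_interval=100)

for _ in range(250):
  step.assign_add(1)
  # Returns None unless at least `checkpoint_interval` steps have passed since
  # the last saved checkpoint; pass check_interval=False to force a save.
  path = manager.save(check_interval=True)
  if path:
    print("saved", path)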
Example #59
0
    def testOptimizerInsideModelFn(self, distribution, optimizer_fn):
        created_variables = []
        trainable_variables = []

        def appending_creator(next_creator, *args, **kwargs):
            v = next_creator(*args, **kwargs)
            created_variables.append(v.name)
            if "trainable" in kwargs and kwargs["trainable"]:
                trainable_variables.append(v.name)
            return v

        # Creator scope needs to be set before it's used inside
        # `distribution.scope`.
        with variable_scope.variable_creator_scope(
                appending_creator), distribution.scope():
            model_fn, dataset_fn, layer = minimize_loss_example(
                optimizer_fn,
                use_bias=True,
                use_callable_loss=True,
                create_optimizer_inside_model_fn=True)

            def step_fn(ctx, inputs):
                del ctx  # Unused
                return distribution.group(
                    distribution.call_for_each_replica(model_fn,
                                                       args=(inputs, )))

            iterator = self._get_iterator(
                distribution.distribute_dataset(dataset_fn))

            def run_step():
                return distribution.run_steps_on_dataset(step_fn,
                                                         iterator,
                                                         iterations=1).run_op

            self.evaluate(distribution.initialize())
            if not context.executing_eagerly():
                with self.cached_session() as sess:
                    run_step = sess.make_callable(run_step())
            self.evaluate(variables_lib.global_variables_initializer())

            run_step()

            self.evaluate(distribution.finalize())

            def get_expected_variables(optimizer_fn, num_parameter_devices):
                variables_map = {
                    "GradientDescent": ["dense/kernel", "dense/bias"],
                    "Adagrad": [
                        "dense/kernel/Adagrad", "dense/kernel",
                        "dense/bias/Adagrad", "dense/bias"
                    ]
                }
                variables = variables_map[optimizer_fn().get_name()]
                variables.extend([
                    v + "/replica_{}".format(replica) for v in variables
                    for replica in range(1, num_parameter_devices)
                ])
                return set([v + ":0" for v in variables])

            self.assertEqual(
                get_expected_variables(optimizer_fn,
                                       len(distribution.parameter_devices)),
                set(created_variables))