Exemplo n.º 1
0
 def model_fn():
   """Create four variables with overlapping scope-like names, then merge.

   Returns:
     List of the created variables, in creation order.
   """
   names = ("foo/bar", "foo_1/bar", "foo_1/bar_1", "foo/bar_1")
   created = [variable_scope.variable(1.0, name=n) for n in names]
   # The merge function simply hands back the argument it receives.
   distribute_lib.get_tower_context().merge_call(lambda _: _)
   return created
Exemplo n.º 2
0
    def model_fn():
      """Create one variable outside and two inside the "common" scope.

      A `merge_call` is issued between the two in-scope variable creations.

      Returns:
        Tuple of the three variables, in creation order.
      """
      outer_var = variable_scope.get_variable("var-thread0", [1])
      with variable_scope.variable_scope("common"):
        before_merge = variable_scope.get_variable("var-thread1", [1])
        # merge_call pauses the current thread and lets the other thread run.
        tower_ctx = distribute_lib.get_tower_context()
        tower_ctx.merge_call(lambda _: _)
        after_merge = variable_scope.get_variable("var-thread2", [1])

      return outer_var, before_merge, after_merge
Exemplo n.º 3
0
 def set_non_tensor_output(self, name, output):
   """Record `output` under `name` as a captured non-tensor output.

   In cross-tower context the value is stored as given. Otherwise the value
   is passed through `merge_call` and the unwrapped result is stored.

   Args:
     name: Key under which the output is stored.
     output: The non-tensor value to capture.
   """
   if not distribute_lib.get_cross_tower_context():
     def _store_unwrapped(strategy, per_tower_value):
       # Aggregation is meaningless for non-tensor values, so whatever
       # `unwrap` returns for all towers is kept as-is.
       self._non_tensor_outputs[name] = strategy.unwrap(per_tower_value)
     distribute_lib.get_tower_context().merge_call(_store_unwrapped, output)
     return
   self._non_tensor_outputs[name] = output
Exemplo n.º 4
0
    def model_fn(device_id):
      """Create a variable under a creator scope tagged with `device_id`.

      Args:
        device_id: Integer id appended (as ":thread_<id>") to whatever the
          next creator returns.

      Returns:
        The value produced by `variable_scope.variable` inside the scope.
      """
      assert isinstance(device_id, int)

      def thread_creator_fn(next_creator, *args, **kwargs):
        created = next_creator(*args, **kwargs)
        return created + ":thread_" + str(device_id)

      with variable_scope.variable_creator_scope(thread_creator_fn):
        # The variable is created while the custom creator is active.
        result = variable_scope.variable(1.0)

        # merge_call pauses the current thread and lets the other thread run.
        tower_ctx = distribute_lib.get_tower_context()
        tower_ctx.merge_call(lambda _: _)
      return result
Exemplo n.º 5
0
 def model_fn(features):
   """Build three `Dense(1)` layers in the "common" scope around a merge call.

   Two layers are built and applied before `merge_call`, one after it.

   Args:
     features: Input passed to each layer.

   Returns:
     List of `(kernel, bias)` pairs, one per layer, in creation order.
   """
   with variable_scope.variable_scope("common"):
     built = []
     for _ in range(2):
       dense = core.Dense(1)
       dense(features)
       built.append(dense)
     # merge_call pauses the current thread and lets the other thread run.
     distribute_lib.get_tower_context().merge_call(lambda _: _)
     dense = core.Dense(1)
     dense(features)
     built.append(dense)
     return [(layer.kernel, layer.bias) for layer in built]
Exemplo n.º 6
0
 def merge_fn(dist, s):
   """Check the distribute state seen inside the merge call.

   Args:
     dist: Strategy handed to the merge function.
     s: Value to be prefixed.

   Returns:
     `"foo_" + s`.
   """
   default_strategy = distribute._default_distribution_strategy
   self.assertIs(default_strategy, dist)
   self.assertIs(None, distribute.get_tower_context())
   for getter in (distribute.get_cross_tower_context,
                  distribute.get_distribution_strategy):
     self.assertIs(dist, getter())
   self.assertFalse(distribute.has_distribution_strategy())
   return "foo_" + s
Exemplo n.º 7
0
    def model_fn():
      """Create four variables, two of them with explicit sync/aggregation.

      Returns:
        Tuple `(v0, v1, v2, v3)` in creation order.
      """
      sync = variable_scope.VariableSynchronization
      agg = variable_scope.VariableAggregation

      first = variable_scope.get_variable("var0", [1])
      with variable_scope.variable_scope("common"):
        second = variable_scope.get_variable("var1", [1])
        # merge_call pauses the current thread and lets the other thread run.
        distribute_lib.get_tower_context().merge_call(lambda _: _)
        on_read_sum = variable_scope.get_variable(
            "var2", [1], synchronization=sync.ON_READ, aggregation=agg.SUM)
        on_write_mean = variable_scope.get_variable(
            "var3", [1], synchronization=sync.ON_WRITE, aggregation=agg.MEAN)

      return first, second, on_read_sum, on_write_mean
Exemplo n.º 8
0
 def merge_fn(dist, s):
     """Assert the default-strategy state, then return `s` prefixed by "foo_"."""
     expected_strategy = distribute._default_distribution_strategy
     self.assertIs(expected_strategy, dist)

     tower_ctx = distribute.get_tower_context()
     self.assertIs(None, tower_ctx)

     cross_tower_ctx = distribute.get_cross_tower_context()
     self.assertIs(dist, cross_tower_ctx)

     current_strategy = distribute.get_distribution_strategy()
     self.assertIs(dist, current_strategy)

     self.assertFalse(distribute.has_distribution_strategy())
     return "foo_" + s
Exemplo n.º 9
0
  def _assign_func(self, *args, **kwargs):
    """Apply the assign-style function `f` to this mirrored variable.

    Args:
      *args: Positional arguments forwarded to `f`.
      **kwargs: Must contain `f`, the function to apply; the remaining
        keyword arguments are forwarded to `f`.

    Returns:
      In cross-tower context: the result of `f` on the component for the
      update device, or of the strategy's `update` when no update device is
      set. In tower context: the result of `merge_call`.

    Raises:
      ValueError: In tower context when the aggregation is `NONE`.
    """
    f = kwargs.pop("f")

    if not distribute_lib.get_cross_tower_context():
      # Tower context: the value to assign/add/sub is reduced first (the
      # _reduce method has details on the individual use cases), and the
      # function is then applied to each mirrored variable with the reduced
      # value.
      _assert_tower_context()
      if self._aggregation == vs.VariableAggregation.NONE:
        raise ValueError("You must specify an aggregation method to update a "
                         "MirroredVariable in Tower Context.")

      def merge_fn(strategy, value, *other_args, **other_kwargs):
        reduced = strategy.reduce(
            aggregation=self._aggregation, value=value, destinations=self)
        return strategy.update(self, f, reduced, *other_args, **other_kwargs)

      tower_ctx = distribute_lib.get_tower_context()
      return tower_ctx.merge_call(merge_fn, *args, **kwargs)

    # Cross-tower context: update is being called on the mirrored variable.
    update_device = distribute_lib.get_update_device()
    if update_device is None:
      strategy = distribute_lib.get_distribution_strategy()
      return strategy.update(self, f, *args, **kwargs)

    # An update device is set, so apply `f` directly to that component.
    component = self.get(device=update_device)
    return f(component, *args, **kwargs)
    def model_fn():
      """Create "var0".."var3"; "var2"/"var3" carry explicit sync settings.

      Returns:
        Tuple of the four variables in creation order.
      """
      created = [variable_scope.get_variable("var0", [1])]
      with variable_scope.variable_scope("common"):
        created.append(variable_scope.get_variable("var1", [1]))
        # merge_call pauses the current thread and lets the other thread run.
        tower_ctx = distribute_lib.get_tower_context()
        tower_ctx.merge_call(lambda _: _)
        created.append(variable_scope.get_variable(
            "var2", [1],
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=variable_scope.VariableAggregation.SUM))
        created.append(variable_scope.get_variable(
            "var3", [1],
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation=variable_scope.VariableAggregation.MEAN))

      return tuple(created)
Exemplo n.º 11
0
def _assert_in_default_state(t):
    """Assert that `distribute` reports its default context and strategy.

    Args:
      t: Test case object providing `assertIs` / `assertFalse`.
    """
    default_tower_ctx = distribute._default_tower_context
    default_strategy = distribute._default_distribution_strategy

    t.assertIs(default_tower_ctx, distribute.get_tower_context())
    t.assertIs(None, distribute.get_cross_tower_context())
    t.assertIs(default_strategy, distribute.get_distribution_strategy())
    t.assertFalse(distribute.has_distribution_strategy())
Exemplo n.º 12
0
def _aggregate_across_towers(metrics_collections, metric_value_fn, *args):
    """Aggregate metric value across towers.

    Args:
      metrics_collections: Collections the metric value is added to; nothing
        is added when this is falsy.
      metric_value_fn: Callable invoked as `metric_value_fn(distribution, *a)`
        to produce the metric value.
      *args: Extra arguments handed to `merge_call` along with the merge
        function.

    Returns:
      Whatever `merge_call` returns for the merge function defined here.
    """
    def fn(distribution, *a):
        """Call `metric_value_fn` in the correct control flow context."""
        if hasattr(distribution, '_outer_control_flow_context'):
            # If there was an outer context captured before this method was
            # called, then we enter that context to create the metric value op.
            # If the captured context is `None`, ops.control_dependencies(None)
            # gives the desired behavior. Else we use `Enter` and `Exit` to
            # enter and exit the captured context.
            # This special handling is needed because sometimes the metric is
            # created inside a while_loop (and perhaps a TPU rewrite context).
            # But we don't want the value op to be evaluated every step or on
            # the TPU. So we create it outside so that it can be evaluated at
            # the end on the host, once the update ops have been evaluated.

            # pylint: disable=protected-access
            outer_context = distribution._outer_control_flow_context
            # pylint: enable=protected-access
            if outer_context is None:
                with ops.control_dependencies(None):
                    metric_value = metric_value_fn(distribution, *a)
            else:
                outer_context.Enter()
                try:
                    metric_value = metric_value_fn(distribution, *a)
                finally:
                    # Leave the captured context even if `metric_value_fn`
                    # raises, so the context is never left entered.
                    outer_context.Exit()
        else:
            metric_value = metric_value_fn(distribution, *a)
        if metrics_collections:
            ops.add_to_collections(metrics_collections, metric_value)
        return metric_value

    return distribute_lib.get_tower_context().merge_call(fn, *args)
Exemplo n.º 13
0
def _assert_in_default_state(t):
  """Check, through test case `t`, the default tower context and strategy."""
  tower_ctx = distribute.get_tower_context()
  t.assertIs(distribute._default_tower_context, tower_ctx)

  cross_tower_ctx = distribute.get_cross_tower_context()
  t.assertIs(None, cross_tower_ctx)

  strategy = distribute.get_distribution_strategy()
  t.assertIs(distribute._default_distribution_strategy, strategy)

  t.assertFalse(distribute.has_distribution_strategy())
Exemplo n.º 14
0
 def model_fn():
     """Create a variable in a SUM tower-local scope and check its type.

     Returns:
       The created variable.
     """
     sum_aggregation = variable_scope.VariableAggregation.SUM
     ctx = distribute_lib.get_tower_context()
     with ctx.tower_local_var_scope(sum_aggregation):
         v_sum = variable_scope.variable(1.0)
     is_tower_local = isinstance(v_sum, values.TowerLocalVariable)
     self.assertTrue(is_tower_local)
     return v_sum
Exemplo n.º 15
0
def skip_summary():
    """Tell whether summaries should be skipped for the current tower.

    Returns:
      The falsy tower context itself when there is none; otherwise whether
      the context's `tower_id` is greater than 0.
    """
    # With multiple towers in a distributed strategy, only the first tower
    # (tower_id=0) writes summaries; all others skip them.
    # TODO(priyag): Add a new optional argument that will provide multiple
    # alternatives to override default behavior. (e.g. run on last tower,
    # compute sum or mean across towers).
    ctx = distribute.get_tower_context()
    if not ctx:
        return ctx
    return ctx.tower_id > 0
Exemplo n.º 16
0
 def run_fn():
   """Check the distribute state in tower context, then merge and create."""
   ctx = distribute.get_tower_context()
   self.assertTrue(ctx is not None)
   self.assertIs(None, distribute.get_cross_tower_context())
   self.assertTrue(distribute.has_distribution_strategy())
   self.assertIs(dist, distribute.get_distribution_strategy())

   merged = ctx.merge_call(None, test_arg="foo")
   self.assertEqual("foo", merged)

   created = variable_scope.variable(1.0, name="bar")
   self.assertEqual("bar", created)
Exemplo n.º 17
0
def skip_summary():
  """Return a truthy value when the current tower should skip summaries."""
  # Summaries are written only by the first tower (tower_id=0) when several
  # towers are used in a distributed strategy.
  # TODO(priyag): Add a new optional argument that will provide multiple
  # alternatives to override default behavior. (e.g. run on last tower,
  # compute sum or mean across towers).
  current = distribute.get_tower_context()
  # A falsy context is returned unchanged, mirroring `a and b` semantics.
  return (current.tower_id > 0) if current else current
Exemplo n.º 18
0
 def run_fn():
     """Assert tower-context state, the merge result and the created value."""
     tower_ctx = distribute.get_tower_context()
     self.assertTrue(tower_ctx is not None)

     cross_tower_ctx = distribute.get_cross_tower_context()
     self.assertIs(None, cross_tower_ctx)

     self.assertTrue(distribute.has_distribution_strategy())
     self.assertIs(dist, distribute.get_distribution_strategy())

     merge_result = tower_ctx.merge_call(None, test_arg="foo")
     self.assertEqual("foo", merge_result)

     bar = variable_scope.variable(1.0, name="bar")
     self.assertEqual("bar", bar)
Exemplo n.º 19
0
 def testScope(self):
     """Strategy scope switches to cross-tower state and restores defaults."""
     _assert_in_default_state(self)
     strategy = _TestStrategy()
     with strategy.scope():
         self.assertIs(None, distribute.get_tower_context())
         self.assertIs(strategy, distribute.get_cross_tower_context())
         self.assertTrue(distribute.has_distribution_strategy())
         self.assertIs(strategy, distribute.get_distribution_strategy())
         created = variable_scope.variable(1.0, name="baz")
         self.assertEqual("baz", created)
     _assert_in_default_state(self)
Exemplo n.º 20
0
 def testScope(self):
   """Check the distribute state inside and after `dist.scope()`."""
   _assert_in_default_state(self)

   dist = _TestStrategy()
   with dist.scope():
     tower_ctx = distribute.get_tower_context()
     self.assertIs(None, tower_ctx)
     cross_tower_ctx = distribute.get_cross_tower_context()
     self.assertIs(dist, cross_tower_ctx)
     self.assertTrue(distribute.has_distribution_strategy())
     current = distribute.get_distribution_strategy()
     self.assertIs(dist, current)
     baz = variable_scope.variable(1.0, name="baz")
     self.assertEqual("baz", baz)

   _assert_in_default_state(self)
Exemplo n.º 21
0
 def run_fn():
   """Check tower-context state, the merge result and the created variable."""
   ctx = distribute.get_tower_context()
   self.assertTrue(ctx is not None)
   self.assertIs(None, distribute.get_cross_tower_context())
   self.assertTrue(distribute.has_distribution_strategy())
   self.assertIs(dist, distribute.get_distribution_strategy())

   merged = ctx.merge_call(None, test_arg="foo")
   self.assertEqual("foo", merged)

   sync_auto = variable_scope.VariableSynchronization.AUTO
   agg_none = variable_scope.VariableAggregation.NONE
   expected_value = _get_test_variable("bar", sync_auto, agg_none)
   actual_value = variable_scope.variable(1.0, name="bar")
   self.assertDictEqual(expected_value, actual_value)
Exemplo n.º 22
0
  def set_last_step_output(self, name, output,
                           aggregation=variables_lib.VariableAggregation.NONE):
    """Register `output` under `name` as an output of the last step.

    Args:
      name: String identifying the output. It does not have to match a tensor
        name.
      output: The tensors to output under `name`. Supported types are
        described below.
      aggregation: How outputs from multiple towers are aggregated. Must be
        given when called in a tower context; optional in
        cross_tower_context.
        When given, the outputs of all towers are aggregated with the current
        distribution strategy's `reduce` method, so `output` has to be of a
        type that `reduce` supports. For example, with MirroredStrategy and
        an aggregation set, `output` must be a `PerDevice` value.
        The aggregation is additionally stored in the
        `_last_step_outputs_aggregations` dictionary so that the outputs can
        later be interpreted as already reduced or not.

    """
    no_aggregation = variables_lib.VariableAggregation.NONE
    reduce_destination = "/device:CPU:0"

    if not distribute_lib.get_cross_tower_context():
      assert aggregation is not no_aggregation

      def _reduce_and_record(strategy, per_tower_output):
        self._last_step_outputs[name] = strategy.reduce(
            aggregation, per_tower_output, destinations=reduce_destination)
        # Recorded inside the merge function: all towers share one context
        # object, so writing it once is more robust (even though every tower
        # would write the same value).
        self._last_step_outputs_aggregations[name] = aggregation

      tower_ctx = distribute_lib.get_tower_context()
      tower_ctx.merge_call(_reduce_and_record, output)
      return

    self._last_step_outputs_aggregations[name] = aggregation
    if aggregation is no_aggregation:
      self._last_step_outputs[name] = output
      return

    strategy = distribute_lib.get_distribution_strategy()
    self._last_step_outputs[name] = strategy.reduce(
        aggregation, output, destinations=reduce_destination)
Exemplo n.º 23
0
 def run_fn():
     """Assert tower-context state and the dict produced for variable "bar"."""
     tower_ctx = distribute.get_tower_context()
     self.assertTrue(tower_ctx is not None)

     cross_tower_ctx = distribute.get_cross_tower_context()
     self.assertIs(None, cross_tower_ctx)

     self.assertTrue(distribute.has_distribution_strategy())
     self.assertIs(dist, distribute.get_distribution_strategy())

     merge_result = tower_ctx.merge_call(None, test_arg="foo")
     self.assertEqual("foo", merge_result)

     expected_value = _get_test_variable(
         "bar",
         variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     bar = variable_scope.variable(1.0, name="bar")
     self.assertDictEqual(expected_value, bar)
Exemplo n.º 24
0
 def model_fn(device_id):
   """Create "sum" and "mean" tower-local variables and update them.

   Args:
     device_id: Index used in the update values and as the key under which
       the variables are recorded in `all_v_sum` / `all_v_mean`.

   Returns:
     Tuple `(updates, v_sum, v_mean)`, where `updates` is the list holding
     the `assign_add` and `assign` results.
   """
   ctx = distribute_lib.get_tower_context()
   with ctx.tower_local_var_scope("sum"):
     v_sum = variable_scope.variable(1.0)
   with ctx.tower_local_var_scope("mean"):
     v_mean = variable_scope.variable(4.0)

   for tower_local in (v_sum, v_mean):
     self.assertTrue(isinstance(tower_local, values.TowerLocalVariable))

   sum_update = v_sum.assign_add(2.0 + device_id)
   mean_update = v_mean.assign(6.0 * device_id)
   all_v_sum[device_id] = v_sum
   all_v_mean[device_id] = v_mean
   return [sum_update, mean_update], v_sum, v_mean
Exemplo n.º 25
0
 def testScope(self):
     """Strategy scope exposes cross-tower state and restores the defaults."""
     _assert_in_default_state(self)
     strategy = _TestStrategy()
     with strategy.scope():
         self.assertIs(None, distribute.get_tower_context())
         self.assertIs(strategy, distribute.get_cross_tower_context())
         self.assertTrue(distribute.has_distribution_strategy())
         self.assertIs(strategy, distribute.get_distribution_strategy())

         sync_auto = variable_scope.VariableSynchronization.AUTO
         agg_none = variable_scope.VariableAggregation.NONE
         expected_value = _get_test_variable("baz", sync_auto, agg_none)
         created = variable_scope.variable(1.0, name="baz")
         self.assertDictEqual(expected_value, created)
     _assert_in_default_state(self)
Exemplo n.º 26
0
 def testScope(self):
   """Check state and the created variable dict inside `dist.scope()`."""
   _assert_in_default_state(self)

   dist = _TestStrategy()
   with dist.scope():
     tower_ctx = distribute.get_tower_context()
     self.assertIs(None, tower_ctx)
     cross_tower_ctx = distribute.get_cross_tower_context()
     self.assertIs(dist, cross_tower_ctx)
     self.assertTrue(distribute.has_distribution_strategy())
     current = distribute.get_distribution_strategy()
     self.assertIs(dist, current)
     expected_value = _get_test_variable(
         "baz",
         variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     baz = variable_scope.variable(1.0, name="baz")
     self.assertDictEqual(expected_value, baz)

   _assert_in_default_state(self)
Exemplo n.º 27
0
    def _decay_weights_op(self, var):
        """Build the weight-decay op for `var`.

        Args:
          var: Variable to decay.

        Returns:
          A no-op when `_decay_var_list` is non-empty and does not contain
          `var`; otherwise the result of the `merge_call` that reduces the
          weight decay (MEAN) and subtracts `decay * v` from the variable.
        """
        # A non-empty decay list restricts decay to the variables it names.
        if self._decay_var_list and var not in self._decay_var_list:
            return control_flow_ops.no_op()

        def _subtract_decay(target, decay):
            return state_ops.assign_sub(target, decay * target,
                                        self._use_locking)

        def _reduce_then_update(strategy, target, decay):
            mean_decay = strategy.reduce(
                variable_scope.VariableAggregation.MEAN, decay, target)
            return strategy.update(target, _subtract_decay, mean_decay)

        tower_ctx = distribute_lib.get_tower_context()
        return tower_ctx.merge_call(_reduce_then_update, var,
                                    self._weight_decay)
Exemplo n.º 28
0
  def testMergeCall(self):
    """`merge_call` on the default tower context runs in the default strategy."""
    _assert_in_default_state(self)

    def merge_fn(dist, s):
      default_strategy = distribute._default_distribution_strategy
      self.assertIs(default_strategy, dist)
      self.assertIs(None, distribute.get_tower_context())
      for getter in (distribute.get_cross_tower_context,
                     distribute.get_distribution_strategy):
        self.assertIs(dist, getter())
      self.assertFalse(distribute.has_distribution_strategy())
      return "foo_" + s

    default_ctx = distribute.get_tower_context()
    self.assertIs(distribute._default_tower_context, default_ctx)
    merged = default_ctx.merge_call(merge_fn, "bar")
    self.assertEqual("foo_bar", merged)
    _assert_in_default_state(self)
Exemplo n.º 29
0
    def testMergeCall(self):
        """Merge through the default tower context and verify the state."""
        _assert_in_default_state(self)

        def merge_fn(dist, s):
            expected_strategy = distribute._default_distribution_strategy
            self.assertIs(expected_strategy, dist)

            tower_ctx_in_merge = distribute.get_tower_context()
            self.assertIs(None, tower_ctx_in_merge)

            cross_tower_ctx = distribute.get_cross_tower_context()
            self.assertIs(dist, cross_tower_ctx)

            current_strategy = distribute.get_distribution_strategy()
            self.assertIs(dist, current_strategy)

            self.assertFalse(distribute.has_distribution_strategy())
            return "foo_" + s

        ctx = distribute.get_tower_context()
        self.assertIs(distribute._default_tower_context, ctx)
        result = ctx.merge_call(merge_fn, "bar")
        self.assertEqual("foo_bar", result)
        _assert_in_default_state(self)
Exemplo n.º 30
0
 def get(self, device=None):
   """Returns the value for the current device or raises a ValueError.

   Args:
     device: Device to look up. When `None`, the tower context's device is
       used if there is a tower context; otherwise the update device, and
       failing that the current device.

   Returns:
     The entry of `self._index` for the canonicalized device.

   Raises:
     ValueError: If the canonicalized device is not a key of `self._index`.
   """
   if device is None:
     ctx = distribute_lib.get_tower_context()
     if not ctx:
       device = distribute_lib.get_update_device()
       if device is None:
         device = device_util.current()
     else:
       device = ctx.device

   canonical = device_util.canonicalize(device)
   try:
     return self._index[canonical]
   except KeyError:
     raise ValueError("Device %s not found in %s (current device %s)" %
                      (canonical, self._index.keys(), device_util.current()))
Exemplo n.º 31
0
    def decorated(*args):
        """Run `result_fn`, going through `merge_call` when in tower context.

        Args:
          *args: Arguments forwarded to `result_fn` on the `merge_call` path.

        Returns:
          `result_fn()` when there is no tower context; otherwise the result
          of `merge_call`.
        """
        ctx = distribute_lib.get_tower_context()
        if ctx is None:
            # No tower context means we are already in cross tower context.
            # NOTE(review): `args` are not forwarded on this path — confirm
            # that this is intended.
            return result_fn()

        # TODO(psv): Test distribution of metrics using different distribution
        # strategies.

        # merge_call passes the distribution object as the first argument of
        # the merge function. This adapter absorbs it so that the result
        # function does not need to accept that parameter.
        def _call_first_copy(strategy, per_device_fn, *fn_args):
            # The function arrives as a `PerDevice` value. All entries are
            # identical copies of the function passed below, so the first
            # one is used.
            first_copy = strategy.unwrap(per_device_fn)[0]
            return first_copy(*fn_args)

        # merge_call leaves tower mode so the value is computed in cross
        # tower mode.
        return ctx.merge_call(_call_first_copy, result_fn, *args)
Exemplo n.º 32
0
  def decorated(*args):
    """Decorated function that wraps `result_fn` in `merge_call`."""

    # merge_call invokes its merge function with the distribution object as
    # first parameter. This wrapper takes that parameter so the result
    # function itself does not have to.
    def merge_fn_wrapper(distribution, merge_fn, *merge_args):
      # `merge_fn` is received as a `PerDevice` value whose entries are
      # identical copies of the function passed below; use the first one.
      unwrapped = distribution.unwrap(merge_fn)
      return unwrapped[0](*merge_args)

    tower_context = distribute_lib.get_tower_context()
    if tower_context is not None:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.

      # merge_call is used to leave tower mode and compute the value in
      # cross tower mode.
      return tower_context.merge_call(merge_fn_wrapper, result_fn, *args)

    # Already in cross tower context: call the result function directly.
    return result_fn()
Exemplo n.º 33
0
 def model_fn(device_id):
   """Create SUM and MEAN tower-local variables, update and read them.

   Args:
     device_id: Index used in the update values and as the key under which
       variables and components are recorded in the enclosing containers.

   Returns:
     Tuple `(updates, v_sum, v_mean, c_sum, c_mean)`.
   """
   aggregation = variable_scope.VariableAggregation
   ctx = distribute_lib.get_tower_context()
   with ctx.tower_local_var_scope(aggregation.SUM):
     v_sum = variable_scope.variable(1.0)
   with ctx.tower_local_var_scope(aggregation.MEAN):
     v_mean = variable_scope.variable(4.0)

   for tower_local in (v_sum, v_mean):
     self.assertTrue(isinstance(tower_local, values.TowerLocalVariable))

   sum_update = v_sum.assign_add(2.0 + device_id)
   mean_update = v_mean.assign(6.0 * device_id)
   all_v_sum[device_id] = v_sum
   all_v_mean[device_id] = v_mean

   c_sum = v_sum.get()
   c_mean = v_mean.get()
   components_sum[device_id] = c_sum
   components_mean[device_id] = c_mean
   # `get()` must hand back something other than the variable object itself.
   self.assertIsNot(v_sum, c_sum)
   self.assertIsNot(v_mean, c_mean)
   return [sum_update, mean_update], v_sum, v_mean, c_sum, c_mean
Exemplo n.º 34
0
 def model_fn(name):
   """Create a variable called `name`, issue a merge call, and return it."""
   created = variable_scope.variable(1.0, name=name)
   tower_ctx = distribute_lib.get_tower_context()
   tower_ctx.merge_call(lambda _: _)
   return created
Exemplo n.º 35
0
 def model_fn(device_id):
   """Create variable "foo_<device_id>", issue a merge call, and return it."""
   var_name = "foo_" + str(device_id)
   created = variable_scope.variable(1.0, name=var_name)
   distribute_lib.get_tower_context().merge_call(lambda _: _)
   return created
 def model_fn():
   """Create "b", then run `in_cross_tower` via merge_call in scope "foo".

   Returns:
     Tuple `(b, c)` of the variable and the merge call's result.
   """
   b = variable_scope.get_variable("b", [1])
   with ops.name_scope("foo"):
     tower_ctx = distribute_lib.get_tower_context()
     c = tower_ctx.merge_call(in_cross_tower)
   return b, c
Exemplo n.º 37
0
 def model_fn():
   """Create variable "foo", issue a merge call, and return the variable."""
   # `dist.call_for_each_tower` installs special variable_creator functions,
   # so this variable should be created only once across the threads.
   created = variable_scope.variable(1.0, name="foo")
   tower_ctx = distribute_lib.get_tower_context()
   tower_ctx.merge_call(lambda _: _)
   return created
 def model_fn():
   """Subtract the tower id (cast to the variable's dtype) from `mirrored_var`."""
   tower_id = distribute_lib.get_tower_context().tower_id
   delta = math_ops.cast(tower_id, mirrored_var.dtype)
   return mirrored_var.assign_sub(delta)
 def model_fn():
   """Create a variable in a "sum" tower-local scope and check its type."""
   ctx = distribute_lib.get_tower_context()
   with ctx.tower_local_var_scope("sum"):
     v_sum = variable_scope.variable(1.0)
   is_tower_local = isinstance(v_sum, values.TowerLocalVariable)
   self.assertTrue(is_tower_local)
   return v_sum
 def model_fn():
   """Return variable "b" and the result of merging under name scope "foo"."""
   var_b = variable_scope.get_variable("b", [1])
   with ops.name_scope("foo"):
     merged = distribute_lib.get_tower_context().merge_call(in_cross_tower)
   return var_b, merged
Exemplo n.º 41
0
 def mark_devices_fn():
   """Mark the current tower's slot in `expected_devices` as visited.

   Asserts that the tower id is within `d.worker_devices` and that the slot
   has not been marked before.
   """
   ctx = distribute_lib.get_tower_context()
   tower_id = ctx.tower_id
   num_worker_devices = len(d.worker_devices)
   self.assertLess(tower_id, num_worker_devices)
   already_marked = expected_devices[tower_id]
   self.assertFalse(already_marked)
   expected_devices[tower_id] = True
Exemplo n.º 42
0
 def model_fn():
   """Create constants "a" and "b" in one name scope, split by a merge call.

   Returns:
     Tuple `(a, b)` of the two constants.
   """
   with ops.name_scope(None, "foo"):
     const_a = constant_op.constant(1.0, name="a")
     tower_ctx = distribute_lib.get_tower_context()
     tower_ctx.merge_call(lambda _: _)
     const_b = constant_op.constant(2.0, name="b")
   return const_a, const_b
Exemplo n.º 43
0
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
    """Initializes current variables with tensors loaded from given checkpoint.

    Note: This overrides default initialization ops of specified variables and
    redefines dtype.

    Assignment map supports following syntax:

    * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
      current `scope_name` from `checkpoint_scope_name` with matching tensor
      names.
    * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
      will initialize `scope_name/variable_name` variable
      from `checkpoint_scope_name/some_other_variable`.
    * `'scope_variable_name': variable` - will initialize given `tf.Variable`
      object with tensor 'scope_variable_name' from the checkpoint.
    * `'scope_variable_name': list(variable)` - will initialize list of
      partitioned variables with tensor 'scope_variable_name' from the
      checkpoint.
    * `'/': 'scope_name/'` - will load all variables in current `scope_name`
      from checkpoint's root (e.g. no scope).

    Supports loading into partitioned variables, which are represented as
    `'<variable>/part_<part #>'`.

    Example:

    ```python

    # Say, '/tmp/model.ckpt' has the following tensors:
    #  -- name='old_scope_1/var1', shape=[20, 2]
    #  -- name='old_scope_1/var2', shape=[50, 4]
    #  -- name='old_scope_2/var3', shape=[100, 100]

    # Create new model's variables
    with tf.variable_scope('new_scope_1'):
      var1 = tf.get_variable('var1', shape=[20, 2],
                             initializer=tf.zeros_initializer())
    with tf.variable_scope('new_scope_2'):
      var2 = tf.get_variable('var2', shape=[50, 4],
                             initializer=tf.zeros_initializer())
      # Partition into 5 variables along the first axis.
      var3 = tf.get_variable(name='var3', shape=[100, 100],
                             initializer=tf.zeros_initializer(),
                             partitioner=lambda shape, dtype: [5, 1])

    # Initialize all variables in `new_scope_1` from `old_scope_1`.
    init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

    # Use names to specify which variables to initialize from checkpoint.
    init_from_checkpoint('/tmp/model.ckpt',
                         {'old_scope_1/var1': 'new_scope_1/var1',
                          'old_scope_1/var2': 'new_scope_2/var2'})

    # Or use tf.Variable objects to identify what to initialize.
    init_from_checkpoint('/tmp/model.ckpt',
                         {'old_scope_1/var1': var1,
                          'old_scope_1/var2': var2})

    # Initialize partitioned variables using variable's name
    init_from_checkpoint('/tmp/model.ckpt',
                         {'old_scope_2/var3': 'new_scope_2/var3'})

    # Or specify the list of tf.Variable objects.
    init_from_checkpoint('/tmp/model.ckpt',
                         {'old_scope_2/var3': var3._get_variable_list()})

    ```

    Args:
      ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
      assignment_map: Dict, where keys are names of the variables in the
        checkpoint and values are current variables or names of current
        variables (in default graph).

    Raises:
      tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
      ValueError: If missing variables in current graph.
    """
    if distribute_lib.get_cross_tower_context():
        # Already in cross-tower context: call the implementation directly,
        # passing `None` as its first argument.
        _init_from_checkpoint(None, ckpt_dir_or_file, assignment_map)
        return

    # In tower context the implementation is run through merge_call.
    tower_ctx = distribute_lib.get_tower_context()
    tower_ctx.merge_call(_init_from_checkpoint, ckpt_dir_or_file,
                         assignment_map)
Exemplo n.º 44
0
 def model_fn():
   """Create variables "foo0".."foo4", issue a merge call, and return them."""
   created = [variable_scope.variable(1.0, name="foo" + str(index))
              for index in range(5)]
   distribute_lib.get_tower_context().merge_call(lambda _: _)
   return created
Exemplo n.º 45
0
def _merge_call_merge_raises_fn():
  """Run `_call_merge_raises_fn` as the merge function of a merge call."""
  tower_ctx = distribute_lib.get_tower_context()
  tower_ctx.merge_call(_call_merge_raises_fn)
Exemplo n.º 46
0
 def _add_tower_local_variable(self, *args, **kwargs):
   """Call `add_weight` inside a 'mean' tower-local variable scope.

   Args:
     *args: Positional arguments forwarded to `add_weight`.
     **kwargs: Keyword arguments forwarded to `add_weight`.

   Returns:
     Whatever `add_weight` returns.
   """
   ctx = distribute_lib.get_tower_context()
   with ctx.tower_local_var_scope('mean'):
     weight = self.add_weight(*args, **kwargs)
     return weight