  def testVarKeyWithVarCreatedInEager(self):
    # Variables created eagerly: _var_key falls back to unique, uid-based keys.
    a = variables.Variable([1., 2.], name='var')
    b = variables.Variable([1.], name='var')

    def var_key_test():
      self.assertFalse(a._in_graph_mode)
      self.assertFalse(b._in_graph_mode)
      var_key_a = optimizer_v2._var_key(a)
      self.assertStartsWith(var_key_a, 'var_')
      var_key_b = optimizer_v2._var_key(b)
      self.assertStartsWith(var_key_b, 'var_')
      self.assertNotEqual(var_key_a, var_key_b)

    var_key_test()
Example #4
 def testVarKey(self):
   with ops.get_default_graph().as_default():
     a = variables.Variable([1., 2.], name='var')
     b = variables.Variable([1.], name='var')
     self.assertTrue(a._in_graph_mode)
     self.assertTrue(b._in_graph_mode)
     var_key = optimizer_v2._var_key(a)
     self.assertEqual('var', var_key)
     var_key = optimizer_v2._var_key(b)
     self.assertEqual('var_1', var_key)
 def testVarKey(self):
   with context.graph_mode():
     a = variables.Variable([1., 2.], name='var')
     b = variables.Variable([1.], name='var')
     self.assertTrue(a._in_graph_mode)
     self.assertTrue(b._in_graph_mode)
     var_key = optimizer_v2._var_key(a)
     self.assertEqual('var', var_key)
     var_key = optimizer_v2._var_key(b)
     self.assertEqual('var_1', var_key)
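
These tests show that _var_key is the dictionary key the optimizer uses to index its per-variable slot state (the add_slot snippets below store slots under it). As a quick orientation, here is a minimal sketch of the public side of that mechanism, assuming a TF 2.x release where tf.keras.optimizers.SGD is still the optimizer_v2 implementation (on TF >= 2.11 the equivalent lives under tf.keras.optimizers.legacy):

import tensorflow as tf

# Minimal sketch: slots are stored per _var_key and surfaced via get_slot().
var = tf.Variable([1.0, 2.0], name='var')
opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)

# Applying gradients lazily creates the 'momentum' slot for `var`.
opt.apply_gradients([(tf.constant([0.1, 0.1]), var)])

momentum = opt.get_slot(var, 'momentum')
print(momentum.shape)  # same shape as `var`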
  def _track_optimizer_slots(self, slots):
    if not all(isinstance(s, TrainableWrapper) for s in slots):
      raise TypeError(
          'Can only track TrainableWrapper slots, but got {}'.format(
              [type(s) for s in slots]))
    identifiers = [optimizer_v2._var_key(s) for s in self._tracked_slots]
    for s in slots:
      if optimizer_v2._var_key(s) not in identifiers:
        self._tracked_slots.append(s)

    if self.params.restrict_policy is not None:
      self.params.restrict_policy._track_params_from_optimizer_slots(slots)
 def add_slot(var, slot_name, initializer="zeros"):
     """Add a new slot variable for `var`."""
     if slot_name not in self._slot_names:
         self._slot_names.append(slot_name)
     var_key = optimizer_v2._var_key(var)
     slot_dict = self._slots.setdefault(var_key, {})
     weight = slot_dict.get(slot_name, None)
     if weight is None:
         if isinstance(initializer,
                       six.string_types) or callable(initializer):
             initializer = initializers.get(initializer)
             initial_value = functools.partial(initializer,
                                               shape=var.shape,
                                               dtype=var.dtype)
         else:
             initial_value = initializer
         strategy = distribute_ctx.get_strategy()
         with strategy.extended.colocate_vars_with(var):
             if isinstance(var, de.TrainableWrapper):
                 weight = de.create_slots(var, initial_value, slot_name,
                                          var._shared_name)
             else:
                 weight = variables.Variable(
                     name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
                     dtype=var.dtype,
                     trainable=False,
                     initial_value=initial_value)
         backend.track_variable(weight)
         slot_dict[slot_name] = weight
         self._restore_slot_variable(slot_name=slot_name,
                                     variable=var,
                                     slot_variable=weight)
         self._weights.append(weight)
     return weight
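
For context, a minimal hedged sketch of how add_slot is normally driven from a custom optimizer_v2 subclass: _create_slots registers one slot per variable and _resource_apply_dense reads it back with get_slot. The class name AccumulatorSGD and the slot name 'accumulator' are made up for illustration; the sketch assumes a TF 2.x release where tf.keras.optimizers.Optimizer is optimizer_v2 (tf.keras.optimizers.legacy.Optimizer on newer releases).

import tensorflow as tf

class AccumulatorSGD(tf.keras.optimizers.Optimizer):
  """Hedged sketch of a custom optimizer_v2 subclass using add_slot()."""

  def __init__(self, learning_rate=0.01, name='AccumulatorSGD', **kwargs):
    super().__init__(name, **kwargs)
    self._set_hyper('learning_rate', learning_rate)

  def _create_slots(self, var_list):
    # One 'accumulator' slot per variable, keyed internally by _var_key.
    for var in var_list:
      self.add_slot(var, 'accumulator', initializer='zeros')

  def _resource_apply_dense(self, grad, var, apply_state=None):
    acc = self.get_slot(var, 'accumulator')
    acc.assign_add(grad)
    lr = self._get_hyper('learning_rate', var.dtype.base_dtype)
    return var.assign_sub(lr * grad)

  def get_config(self):
    config = super().get_config()
    config['learning_rate'] = self._serialize_hyperparameter('learning_rate')
    return config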
    def add_slot(var, slot_name, initializer="zeros", shape=None):
        """Add a new slot variable for `var`."""
        if slot_name not in self._slot_names:
            self._slot_names.append(slot_name)
        var_key = optimizer_v2._var_key(var)
        slot_dict = self._slots.setdefault(var_key, {})
        weight = slot_dict.get(slot_name, None)
        if weight is None:
            if isinstance(initializer,
                          six.string_types) or callable(initializer):
                initializer = initializers.get(initializer)
                if isinstance(initializer,
                              trackable.CheckpointInitialValueCallable) or (
                                  shape is not None):
                    slot_shape = shape
                else:
                    slot_shape = var.shape
                initial_value = functools.partial(initializer,
                                                  shape=slot_shape,
                                                  dtype=var.dtype)
            else:
                initial_value = initializer
            with self._distribution_strategy_scope():
                strategy = distribute_ctx.get_strategy()
                if not strategy.extended.variable_created_in_scope(var):
                    raise ValueError(
                        "Trying to create optimizer slot variable under the scope for "
                        "tf.distribute.Strategy ({}), which is different from the scope "
                        "used for the original variable ({}). Make sure the slot "
                        "variables are created under the same strategy scope. This may "
                        "happen if you're restoring from a checkpoint outside the scope"
                        .format(strategy, var))

                with strategy.extended.colocate_vars_with(var):

                    if isinstance(var, de.TrainableWrapper):
                        weight = de.create_slots(var, initial_value, slot_name,
                                                 var._shared_name, self._bp_v2)
                    else:
                        weight = variables.Variable(
                            name="%s/%s" % (
                                var._shared_name,
                                slot_name,
                            ),  # pylint: disable=protected-access
                            dtype=var.dtype,
                            trainable=False,
                            initial_value=initial_value,
                        )
                backend.track_variable(weight)
                slot_dict[slot_name] = weight
                self._restore_slot_variable(slot_name=slot_name,
                                            variable=var,
                                            slot_variable=weight)
                self._weights.append(weight)
        return weight
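
The strategy-scope check above is the notable difference in this variant: slot variables must be created under the same tf.distribute.Strategy scope as the variables they track. A minimal hedged sketch of the pattern it enforces (plain Keras Adam optimizer, single-machine MirroredStrategy; Adam's 'm'/'v' slots are created lazily inside the scope):

import tensorflow as tf

# Create the variables and the optimizer under the same strategy scope,
# then apply gradients from a replica context so the lazily created
# slots land in the matching scope.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  var = tf.Variable([1.0, 2.0], name='var')
  opt = tf.keras.optimizers.Adam()

def train_step():
  grad = tf.constant([0.1, 0.1])
  opt.apply_gradients([(grad, var)])

strategy.run(train_step)  # slot creation happens here, inside the scope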
def create_slots(primary, init, slot_name, op_name, bp_v2):
    """Helper function for creating a slot variable for statefull optimizers."""
    params_var_, params_ids_ = primary.params, primary.ids

    scope_store = variable_scope._get_default_variable_store()
    full_name = params_var_.name + "/" + op_name + "/" + slot_name
    if full_name not in scope_store._vars:
        with ops.colocate_with(primary, ignore_existing=True):
            slot_variable_ = de.Variable(
                name=full_name,
                key_dtype=params_var_.key_dtype,
                value_dtype=params_var_.value_dtype,
                dim=params_var_.dim,
                devices=params_var_.devices,
                partitioner=params_var_.partition_fn,
                initializer=init,
                kv_creator=params_var_.kv_creator,
                trainable=False,
                checkpoint=params_var_.checkpoint,
                bp_v2=bp_v2 if bp_v2 is not None else params_var_.bp_v2,
            )

        scope_store._vars[full_name] = slot_variable_
        # Record the optimizer Variable into trace.
        primary._optimizer_vars.append(slot_variable_)

    slot_trainable = None
    if context.executing_eagerly():
        slot_tw_name = slot_name + '-' + str(optimizer_v2._var_key(primary))
    else:
        # In graph mode, earlier versions used only slot_name to name the
        # trainable wrappers of slots, so keep using slot_name here for
        # compatibility with that behavior.
        slot_tw_name = slot_name
    if isinstance(primary, de.shadow_ops.ShadowVariable):
        slot_trainable = de.shadow_ops.ShadowVariable(
            params=scope_store._vars[full_name],
            ids=primary.ids,
            exists=primary.exists,
            name=full_name,
            trainable=False,
        )
    else:
        _, slot_trainable = de.embedding_lookup(
            params=scope_store._vars[full_name],
            ids=params_ids_,
            name=slot_tw_name,
            return_trainable=True,
        )

    return slot_trainable
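
create_slots is the dynamic-embedding counterpart of ordinary slot creation: the slot itself is a de.Variable plus a trainable wrapper over the currently looked-up ids. For orientation, a heavily hedged sketch of the surrounding workflow; the module path tfra.dynamic_embedding and the exact call pattern are assumptions based on the tensorflow_recommenders_addons API, not taken from the snippet above:

import tensorflow as tf
import tensorflow_recommenders_addons as tfra

# Hedged sketch: a dynamic embedding table trained with a wrapped Keras
# optimizer; slots for the looked-up rows are created through a helper
# like create_slots() when apply_gradients runs.
embeddings = tfra.dynamic_embedding.get_variable(name='emb', dim=8)
optimizer = tfra.dynamic_embedding.DynamicEmbeddingOptimizer(
    tf.keras.optimizers.Adam(1e-3))

ids = tf.constant([3, 7, 42], dtype=tf.int64)
with tf.GradientTape() as tape:
  emb, trainable = tfra.dynamic_embedding.embedding_lookup(
      params=embeddings, ids=ids, return_trainable=True)
  loss = tf.reduce_sum(emb * emb)

grads = tape.gradient(loss, [trainable])
optimizer.apply_gradients(zip(grads, [trainable]))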
Example #10
    def add_slot(self,
                 var,
                 slot_name,
                 initializer="zeros",
                 manifold_wise=False):
        """Add a new slot variable for `var`."""
        rank = self.manifold.rank  # rank of the manifold's tensors
        if slot_name not in self._slot_names:
            self._slot_names.append(slot_name)
        var_key = opt._var_key(var)
        slot_dict = self._slots.setdefault(var_key, {})
        weight = slot_dict.get(slot_name, None)
        if weight is None:
            if isinstance(initializer,
                          six.string_types) or callable(initializer):
                initializer = initializers.get(initializer)
                if manifold_wise:
                    initial_value = functools.partial(
                        initializer,
                        shape=var.shape[:-rank - 1] + rank * (1, ) + (2, ),
                        dtype=var.dtype)
                else:
                    initial_value = functools.partial(initializer,
                                                      shape=var.shape,
                                                      dtype=var.dtype)
            else:
                initial_value = initializer
            strategy = distribute_ctx.get_strategy()
            if not strategy.extended.variable_created_in_scope(var):
                raise ValueError(
                    "Trying to create optimizer slot variable under the scope for "
                    "tf.distribute.Strategy ({}), which is different from the scope "
                    "used for the original variable ({}). Make sure the slot "
                    "variables are created under the same strategy scope. This may "
                    "happen if you're restoring from a checkpoint outside the scope"
                    .format(strategy, var))

            with strategy.extended.colocate_vars_with(var):
                weight = tf_variables.Variable(
                    name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
                    dtype=var.dtype,
                    trainable=False,
                    initial_value=initial_value)
            backend.track_variable(weight)
            slot_dict[slot_name] = weight
            self._restore_slot_variable(slot_name=slot_name,
                                        variable=var,
                                        slot_variable=weight)
            self._weights.append(weight)
        return weight
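
The manifold_wise branch drops the trailing rank + 1 axes of the variable shape and replaces them with rank singleton axes plus a final axis of size 2, instead of mirroring the full variable shape. A quick worked check of that shape arithmetic; the variable shape and rank below are illustrative only:

# Illustrative shape arithmetic for a manifold_wise slot; not library code.
rank = 1                    # assumed rank of the manifold's tensors
var_shape = (4, 3, 5)       # assumed variable shape
slot_shape = var_shape[:-rank - 1] + rank * (1,) + (2,)
print(slot_shape)           # (4, 1, 2): leading dims kept, manifold dims
                            # reduced to 1, plus a trailing axis of size 2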