Example No. 1
    def _create_slots(self, var_list):
        """Creates all slots needed by the variables.

    Args:
      var_list: A list of `Variable` objects.
    """
        # We currently use 3 slots; we could use fewer.
        for var in var_list:
            self.relevant_vars.add(var)

            # The gradient estimate.
            estimate = slot_creator.create_zeros_slot(var, 'estimate')
            estimate_slots = self._slots.setdefault('estimate', {})
            estimate_slots[_var_key(var)] = estimate

            # The true parameter values (the variables contain shifted parameters).
            true_param = slot_creator.create_slot(var, var.initialized_value(),
                                                  'true_param')
            true_slots = self._slots.setdefault('true_param', {})
            true_slots[_var_key(var)] = true_param

            # Storage for the update of the "apply" optimizer.
            update = slot_creator.create_zeros_slot(var, 'update')
            update_slots = self._slots.setdefault('update', {})
            update_slots[_var_key(var)] = update
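The three slots above are stored in `self._slots` under a string key and then under `_var_key(var)`. A hypothetical companion lookup helper, mirroring that storage convention (the name `_get_slot` and its presence in the original class are assumptions):

    def _get_slot(self, var, slot_name):
        """Return the slot created in `_create_slots`, or None if absent."""
        # Same two-level lookup that the storage code above populates.
        return self._slots.get(slot_name, {}).get(_var_key(var))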
Example No. 2
def BlocksparseAdam(grads, params,
        lr=0.001, decay_mean=0.9, decay_var=0.999, epsilon=1e-8, clip_sigma=0.0, global_step=None, gated=False,
        norm_scale=None, grad_scale=1.0, saturate=0.0, zero_infs=False, zero_nans=False,
        param_qspec=None, mean_qspec=None, var_qspec=None):

    with tf.device("/cpu:0"), tf.variable_scope("adam_lr"):

        if global_step is None:
            t = tf.Variable(initial_value=0.0, name="t", trainable=False)
            t = t.assign_add(1.0)
        else:
            t = tf.cast(global_step.assign_add(1), tf.float32)
        one = tf.constant(1.0)

        lr = lr * tf.sqrt(one - tf.pow(decay_var, t)) / (one - tf.pow(decay_mean, t))

        if type(grad_scale) is float:
            grad_scale = tf.constant(grad_scale)
        if type(clip_sigma) is float:
            clip_sigma = tf.constant(clip_sigma)

    norm_scale = [] if norm_scale is None else [norm_scale]

    updates = list()
    for grad, param in zip(grads, params):

        mean = slot_creator.create_zeros_slot(param, "adam_mean")
        var  = slot_creator.create_zeros_slot(param, "adam_variance")
        gate = getattr(param, "gate", None)

        colon = param.name.find(":")
        name  = param.name if colon < 0 else param.name[0:colon]

        with tf.device("/gpu:0"), tf.variable_scope("adam/" + name):
            if gated and gate is not None:
                op = adam_gated_op(gate, grad, param, mean, var, lr, grad_scale, clip_sigma, norm_scale,
                        decay_mean=decay_mean, decay_var=decay_var, epsilon=epsilon,
                        saturate=saturate, zero_infs=zero_infs, zero_nans=zero_nans)
            else:
                op = adam_op(grad, param, mean, var, lr, grad_scale, clip_sigma, norm_scale,
                        decay_mean=decay_mean, decay_var=decay_var, epsilon=epsilon,
                        saturate=saturate, zero_infs=zero_infs, zero_nans=zero_nans)

            if param_qspec is not None:
                updates.append(param.assign(quantize(op.out_param, param_qspec, name="param")))
            else:
                updates.append(op.out_param)

            if mean_qspec is not None:
                updates.append(mean.assign(quantize(op.out_mean, mean_qspec, name="mean")))

            if var_qspec is not None:
                updates.append(var.assign(quantize(op.out_var, var_qspec, name="var")))

    return tf.group(*updates)
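A minimal wiring sketch for `BlocksparseAdam` above, assuming a standard TF1 graph-mode setup; `loss` and the session loop are placeholders, not part of the original code:

import tensorflow as tf

params = tf.trainable_variables()
grads = tf.gradients(loss, params)  # `loss` is assumed to be defined elsewhere
global_step = tf.train.get_or_create_global_step()

train_op = BlocksparseAdam(grads, params,
                           lr=0.001, decay_mean=0.9, decay_var=0.999,
                           global_step=global_step)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)  # one optimization step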
Example No. 3
    def _apply_sparse_shared(self, grad, var, indices, scatter_add):
        beta1_weight, beta2_weight = self._get_beta_weights()

        learning_rate_tensor = math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype)
        beta1_tensor = math_ops.cast(self._beta1_tensor, var.dtype.base_dtype)
        beta2_tensor = math_ops.cast(self._beta2_tensor, var.dtype.base_dtype)
        nu1_tensor = math_ops.cast(self._nu1_tensor, var.dtype.base_dtype)
        nu2_tensor = math_ops.cast(self._nu2_tensor, var.dtype.base_dtype)
        epsilon_tensor = math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype)

        beta1_weight = math_ops.cast(beta1_weight, var.dtype.base_dtype) * beta1_tensor + 1.0
        beta2_weight = math_ops.cast(beta2_weight, var.dtype.base_dtype) * beta2_tensor + 1.0

        beta1_adj = 1.0 - (1.0 / beta1_weight)
        beta2_adj = 1.0 - (1.0 / beta2_weight)

        exp_avg = self.get_slot(var, "exp_avg")
        exp_avg_sq = self.get_slot(var, "exp_avg_sq")

        grad_sq = grad * grad

        exp_avg_tensor = state_ops.assign(exp_avg, beta1_adj * exp_avg, use_locking=self._use_locking)
        with ops.control_dependencies([exp_avg_tensor]):
            exp_avg_tensor = scatter_add(exp_avg, indices, (1.0 - beta1_adj) * grad)

        exp_avg_sq_tensor = state_ops.assign(exp_avg_sq, beta2_adj * exp_avg_sq, use_locking=self._use_locking)
        with ops.control_dependencies([exp_avg_sq_tensor]):
            exp_avg_sq_tensor = scatter_add(exp_avg_sq, indices, (1.0 - beta2_adj) * grad_sq)

        avg_grad = slot_creator.create_zeros_slot(var, self._name)
        avg_grad_tensor = state_ops.assign(avg_grad, nu1_tensor * exp_avg_tensor, use_locking=self._use_locking)
        with ops.control_dependencies([avg_grad_tensor]):
            avg_grad_tensor = scatter_add(avg_grad, indices, (1.0 - nu1_tensor) * grad)

        avg_grad_sq = slot_creator.create_zeros_slot(var, self._name)
        avg_grad_sq_tensor = state_ops.assign(
            avg_grad_sq, nu2_tensor * exp_avg_sq_tensor, use_locking=self._use_locking
        )
        with ops.control_dependencies([avg_grad_sq_tensor]):
            avg_grad_sq_tensor = scatter_add(avg_grad_sq, indices, (1.0 - nu2_tensor) * grad_sq)

        avg_grad_rms_tensor = math_ops.sqrt(avg_grad_sq_tensor)

        var_update = state_ops.assign_add(
            var,
            -learning_rate_tensor * avg_grad_tensor / (avg_grad_rms_tensor + epsilon_tensor),
            use_locking=self._use_locking,
        )

        return control_flow_ops.group(*[var_update, exp_avg_tensor, exp_avg_sq_tensor])
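For reference, the dense equivalent of the sparse update above, written as a plain NumPy sketch (this is a restatement of the ops in the method, not code from the original optimizer):

import numpy as np

def dense_step(var, grad, exp_avg, exp_avg_sq,
               lr, beta1_adj, beta2_adj, nu1, nu2, eps):
    # First- and second-moment moving averages (decay-then-add form).
    exp_avg[:] = beta1_adj * exp_avg + (1.0 - beta1_adj) * grad
    exp_avg_sq[:] = beta2_adj * exp_avg_sq + (1.0 - beta2_adj) * grad * grad
    # Averages that blend the updated moments with the raw gradient.
    avg_grad = nu1 * exp_avg + (1.0 - nu1) * grad
    avg_grad_sq = nu2 * exp_avg_sq + (1.0 - nu2) * grad * grad
    # Parameter step.
    var -= lr * avg_grad / (np.sqrt(avg_grad_sq) + eps)
    return var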
Example No. 4
    def testCreateSlotWithCustomSplitXlaSharding(self):
        # slot_creator is used only in optimizer V1.
        # We insert our own custom split XLA sharding that overrides the SPMD
        # sharding copied over by the slot_creator.
        with ops.Graph().as_default(), self.cached_session():
            v = variables.Variable([1.0, 2.5, 10.0, 15.1], name="var")
            v = xla_sharding.mesh_split(v,
                                        np.array([0, 1]), [0],
                                        use_sharding_op=False)
            with ops.control_dependencies(None):
                slot = slot_creator.create_zeros_slot(v,
                                                      name="slot",
                                                      dtype=dtypes.float64,
                                                      copy_xla_sharding=True)
                slot = xla_sharding.split(slot,
                                          split_dimension=0,
                                          num_devices=4,
                                          use_sharding_op=False)

            self.assertNotEqual(xla_sharding.get_tensor_sharding(v),
                                xla_sharding.get_tensor_sharding(slot))

            slot_sharding = xla_sharding.get_tensor_sharding(slot)
            slot_proto = xla_data_pb2.OpSharding()
            slot_proto.ParseFromString(slot_sharding)
            self.assertEqual(
                slot_proto,
                xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.OTHER,
                                        tile_assignment_dimensions=[4],
                                        tile_assignment_devices=range(4)))
Example No. 5
    def testCreateSlotWithCustomReplicatedXlaSharding(self):
        # slot_creator is used only in optimizer V1.
        # We insert our own custom replicated XLA sharding that overrides the SPMD
        # sharding copied over by the slot_creator.
        with ops.Graph().as_default(), self.cached_session():
            v = variables.Variable([1.0, 2.5], name="var")
            v = xla_sharding.mesh_split(v,
                                        np.array([0, 1]), [0],
                                        use_sharding_op=False)
            with ops.control_dependencies(None):
                slot = slot_creator.create_zeros_slot(v,
                                                      name="slot",
                                                      dtype=dtypes.float64,
                                                      copy_xla_sharding=True)
                slot = xla_sharding.replicate(slot, use_sharding_op=False)

            self.assertNotEqual(xla_sharding.get_tensor_sharding(v),
                                xla_sharding.get_tensor_sharding(slot))

            slot_sharding = xla_sharding.get_tensor_sharding(slot)
            slot_proto = xla_data_pb2.OpSharding()
            slot_proto.ParseFromString(slot_sharding)
            self.assertEqual(
                slot_proto,
                xla_data_pb2.OpSharding(
                    type=xla_data_pb2.OpSharding.REPLICATED))
Example No. 6
    def apply(self, var_list=None):
        # TODO(touts): op_scope
        if var_list is None:
            var_list = variables.trainable_variables()
        for var in var_list:
            if var.dtype.base_dtype not in [dtypes.float32, dtypes.float64]:
                raise TypeError(
                    "The variables must be float or double: %s" % var)
            if var in self._averages:
                raise ValueError(
                    "Moving average already computed for: %s" % var)

            # For variables: to lower communication bandwidth across devices we keep
            # the moving averages on the same device as the variables. For other
            # tensors, we rely on the existing device allocation mechanism.
            if isinstance(var, variables.Variable):
                avg = slot_creator.create_slot(
                    var, var.initialized_value(), self._name,
                    colocate_with_primary=True)
            else:
                avg = slot_creator.create_zeros_slot(
                    var, self._name, colocate_with_primary=(var.op.type == "Variable"))
            self._averages[var] = avg

        with ops.name_scope(self._name) as scope:
            decay = self._num_updates / (self._num_updates + 1)
            updates = []
            updates.append(self._num_updates_op)
            for var in var_list:
                updates.append(assign_moving_average(
                    self._averages[var], var, decay))
            return control_flow_ops.group(*updates, name=scope)
Example No. 7
    def _zeros_slot(self, var, slot_name, op_name):
        """Find or create a slot initialized with 0.0.

        Args:
          var: A `Variable` object.
          slot_name: Name for the slot.
          op_name: Name to use when scoping the Variable that
            needs to be created for the slot.

        Returns:
          A `Variable` object.
        """
        named_slots = self._slot_dict(slot_name)
        if optimizer._var_key(var) not in named_slots:
            if isinstance(var, de.TrainableWrapper):
                new_slot_variable = de.create_slots(var, 0.0, slot_name,
                                                    op_name)
            else:
                new_slot_variable = slot_creator.create_zeros_slot(
                    var, op_name)
            self._restore_slot_variable(slot_name=slot_name,
                                        variable=var,
                                        slot_variable=new_slot_variable)
            named_slots[optimizer._var_key(var)] = new_slot_variable
        return named_slots[optimizer._var_key(var)]
Example No. 8
    def _zero_slots(self, v, slot_name, op_name):
        # get the slot for the accumulative gradient
        name_slot = self._get_slot_in_dict(slot_name)

        # if the variable is not inside name_slot we create it
        if v not in name_slot:
            name_slot[v] = slot_creator.create_zeros_slot(v, op_name)
        return name_slot[v]
Example No. 9
    def getMakeSlot(self, var, slotName, opName, value=None, zeroesSlot=False):
        namedSlots = self.slots[slotName]
        if var not in namedSlots:
            if zeroesSlot:
                namedSlots[var] = slot_creator.create_zeros_slot(var, opName)
            else:
                namedSlots[var] = slot_creator.create_slot(var, value, opName)
        return namedSlots[var]
Example No. 10
    def apply(self, var_list=None):

        if var_list is None:
            var_list = variables.trainable_variables()

        for var in var_list:
            if var.dtype.base_dtype not in [
                    dtypes.float16, dtypes.float32, dtypes.float64
            ]:
                raise TypeError(
                    "The variables must be half, float, or double: %s" %
                    var.name)

            if var not in self._averages:
                # For variables: to lower communication bandwidth across devices we keep
                # the moving averages on the same device as the variables. For other
                # tensors, we rely on the existing device allocation mechanism.
                with ops.init_scope():
                    if isinstance(var, variables.Variable):
                        avg = slot_creator.create_slot(
                            var,
                            var.initialized_value(),
                            self.name,
                            colocate_with_primary=True)
                        # NOTE(mrry): We only add `tf.Variable` objects to the
                        # `MOVING_AVERAGE_VARIABLES` collection.
                        ops.add_to_collection(
                            ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
                    else:
                        avg = slot_creator.create_zeros_slot(
                            var,
                            self.name,
                            colocate_with_primary=(var.op.type in [
                                "Variable", "VariableV2", "VarHandleOp"
                            ]))
                self._averages[var] = avg

        with ops.device('/cpu:0'):
            self._n_models = variable_scope.get_variable(
                shape=[],
                dtype=dtypes.float32,
                name='n_models',
                initializer=init_ops.constant_initializer(0.),
                trainable=False)

        with ops.name_scope(self.name) as scope:
            updates = []
            for var in var_list:
                updates.append(
                    assign_stochastic_average(self._averages[var], var,
                                              self._n_models))
            with ops.control_dependencies(updates):
                update_n_models = state_ops.assign_add(self._n_models,
                                                       1.,
                                                       name=scope)
            return update_n_models
Example No. 11
  def testCreateZerosSlotFromVariable(self):
    with self.test_session():
      v = tf.Variable([1.0, 2.5], name="var")
      slot = slot_creator.create_zeros_slot(v, name="slot", dtype=tf.float64)

      tf.initialize_all_variables().run()

      self.assertEqual(slot.op.name, "var/slot")
      self.assertEqual(slot.get_shape().as_list(), [2])
      self.assertEqual(slot.dtype.base_dtype, tf.float64)
      self.assertAllEqual(slot.eval(), [0.0, 0.0])
Example No. 12
  def testCreateZerosSlotFromTensor(self):
    with self.cached_session():
      v = constant_op.constant([1.0, 2.5], name="const")
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(v, name="slot")

      variables.global_variables_initializer().run()

      self.assertEqual("const/slot", slot.op.name)
      self.assertEqual([2], slot.get_shape().as_list())
      self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], self.evaluate(slot))
Example No. 13
    def testCreateZerosSlotFromTensor(self):
        with self.cached_session():
            v = constant_op.constant([1.0, 2.5], name="const")
            with ops.control_dependencies(None):
                slot = slot_creator.create_zeros_slot(v, name="slot")

            self.evaluate(variables.global_variables_initializer())

            self.assertEqual("const/slot", slot.op.name)
            self.assertEqual([2], slot.get_shape().as_list())
            self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
            self.assertAllEqual([0.0, 0.0], self.evaluate(slot))
Example No. 14
    def testCreateZerosSlotFromTensor(self):
        with self.test_session():
            v = constant_op.constant([1.0, 2.5], name="const")
            with ops.control_dependencies(None):
                slot = slot_creator.create_zeros_slot(v, name="slot")

            variables.global_variables_initializer().run()

            self.assertEqual(slot.op.name, "const/slot")
            self.assertEqual(slot.get_shape().as_list(), [2])
            self.assertEqual(slot.dtype.base_dtype, dtypes.float32)
            self.assertAllEqual(slot.eval(), [0.0, 0.0])
Example No. 15
  def testCreateZerosSlotFromTensor(self):
    with self.test_session():
      v = tf.constant([1.0, 2.5], name="const")
      with tf.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(v, name="slot")

      tf.initialize_all_variables().run()

      self.assertEqual(slot.op.name, "const/slot")
      self.assertEqual(slot.get_shape().as_list(), [2])
      self.assertEqual(slot.dtype.base_dtype, tf.float32)
      self.assertAllEqual(slot.eval(), [0.0, 0.0])
Example No. 16
    def testCreateZerosSlotFromTensor(self):
        with self.test_session():
            v = tf.constant([1.0, 2.5], name="const")

            slot = slot_creator.create_zeros_slot(v, name="slot")

            tf.initialize_all_variables().run()

            self.assertEqual(slot.op.name, "const/slot")
            self.assertEqual(slot.get_shape().as_list(), [2])
            self.assertEqual(slot.dtype.base_dtype, tf.float32)
            self.assertAllEqual(slot.eval(), [0.0, 0.0])
Example No. 17
  def testCreateZerosSlotFromVariable(self):
    with self.test_session():
      v = tf.Variable([1.0, 2.5], name="var")
      with tf.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(v, name="slot", dtype=tf.float64)

      tf.initialize_all_variables().run()

      self.assertEqual(slot.op.name, "var/slot")
      self.assertEqual(slot.get_shape().as_list(), [2])
      self.assertEqual(slot.dtype.base_dtype, tf.float64)
      self.assertAllEqual(slot.eval(), [0.0, 0.0])
Example No. 18
  def testCreateZerosSlotFromVariable(self):
    with self.test_session():
      v = variables.Variable([1.0, 2.5], name="var")
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(
            v, name="slot", dtype=dtypes.float64)

      variables.global_variables_initializer().run()

      self.assertEqual("var/slot", slot.op.name)
      self.assertEqual([2], slot.get_shape().as_list())
      self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], slot.eval())
Example No. 19
  def testCreateZerosSlotFromVariable(self):
    with self.cached_session():
      v = variables.Variable([1.0, 2.5], name="var")
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(
            v, name="slot", dtype=dtypes.float64)

      variables.global_variables_initializer().run()

      self.assertEqual("var/slot", slot.op.name)
      self.assertEqual([2], slot.get_shape().as_list())
      self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], self.evaluate(slot))
Example No. 20
  def testCreateZerosSlotFromDynamicShapedTensor(self):
    with self.cached_session():
      v = random_ops.random_uniform([2], dtype=dtypes.float64)
      v = array_ops.placeholder_with_default(v, shape=[None], name="const")
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(
            v, name="slot", dtype=dtypes.float64)

      self.evaluate(variables.global_variables_initializer())

      self.assertEqual("const/slot", slot.op.name)
      self.assertEqual([2], array_ops.shape(slot).eval())
      self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], self.evaluate(slot))
Example No. 21
  def testCreateZerosSlotFromDynamicShapedTensor(self):
    with self.cached_session():
      v = random_ops.random_uniform([2], dtype=dtypes.float64)
      v = array_ops.placeholder_with_default(v, shape=[None], name="const")
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(
            v, name="slot", dtype=dtypes.float64)

      variables.global_variables_initializer().run()

      self.assertEqual("const/slot", slot.op.name)
      self.assertEqual([2], array_ops.shape(slot).eval())
      self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], self.evaluate(slot))
Example No. 22
    def testCreateZerosSlotFromVariableCopyXlaSharding(self):
        # slot_creator is used only in optimizer V1.
        with ops.Graph().as_default(), self.cached_session():
            v = variables.Variable([1.0, 2.5], name="var")
            v = xla_sharding.mesh_split(v,
                                        np.array([0, 1]), [0],
                                        use_sharding_op=False)
            with ops.control_dependencies(None):
                slot = slot_creator.create_zeros_slot(v,
                                                      name="slot",
                                                      dtype=dtypes.float64,
                                                      copy_xla_sharding=True)
            self.assertEqual(xla_sharding.get_tensor_sharding(v),
                             xla_sharding.get_tensor_sharding(slot))
Example No. 23
    def _zeros_slot(self, var, slot_name, op_name):
        """Find or create a slot initialized with 0.0.

    Args:
      var: A `Variable` object.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
        named_slots = self._slot_dict(slot_name)
        if var not in named_slots:
            named_slots[var] = slot_creator.create_zeros_slot(var, op_name)
        return named_slots[var]
Example No. 24
    def _zeros_slot(self, var, slot_name, op_name):
        """Find or create a slot initialized with 0.0.

    Args:
      var: A `Variable` object.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
        named_slots = self._slot_dict(slot_name)
        if var not in named_slots:
            named_slots[var] = slot_creator.create_zeros_slot(var, op_name)
        return named_slots[var]
Example No. 25
  def testCreateZerosSlotFromDynamicShapedVariable(self):
    with self.cached_session():
      dyn_shape = constant_op.constant([2], dtype=dtypes.int32)
      dyn_shape = array_ops.placeholder_with_default(dyn_shape,
                                                     shape=[None])
      v = variable_scope.get_variable(
          "var",
          initializer=random_ops.random_uniform(dyn_shape,
                                                dtype=dtypes.float64),
          validate_shape=False)
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(
            v, name="slot", dtype=dtypes.float64)

      self.evaluate(variables.global_variables_initializer())

      self.assertEqual("var/slot", slot.op.name)
      self.assertEqual([2], array_ops.shape(slot).eval())
      self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], self.evaluate(slot))
Example No. 26
  def _zeros_slot(self, var, slot_name, op_name):
    """Find or create a slot initialized with 0.0.

    Args:
      var: A `Variable` object.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_zeros_slot(var, op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]
Example No. 27
  def testCreateZerosSlotFromDynamicShapedVariable(self):
    with self.cached_session():
      dyn_shape = constant_op.constant([2], dtype=dtypes.int32)
      dyn_shape = array_ops.placeholder_with_default(dyn_shape,
                                                     shape=[None])
      v = variable_scope.get_variable(
          "var",
          initializer=random_ops.random_uniform(dyn_shape,
                                                dtype=dtypes.float64),
          validate_shape=False)
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(
            v, name="slot", dtype=dtypes.float64)

      variables.global_variables_initializer().run()

      self.assertEqual("var/slot", slot.op.name)
      self.assertEqual([2], array_ops.shape(slot).eval())
      self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], self.evaluate(slot))
Example No. 28
  def _zeros_slot(self, var, slot_name, op_name):
    """Find or create a slot initialized with 0.0.

    Args:
      var: A `Variable` object.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_zeros_slot(var, op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]
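A minimal sketch of how this helper is typically consumed inside an optimizer-V1 subclass; the class name, the "accum" slot, and the update rule are illustrative, not taken from any specific optimizer:

from tensorflow.python.ops import control_flow_ops, state_ops
from tensorflow.python.training import optimizer


class GradientSumOptimizer(optimizer.Optimizer):
  """Toy optimizer: accumulates gradients and steps against the running sum."""

  def __init__(self, learning_rate=0.01, use_locking=False, name="GradientSum"):
    super(GradientSumOptimizer, self).__init__(use_locking, name)
    self._learning_rate = learning_rate

  def _create_slots(self, var_list):
    # Find-or-create a zero-initialized accumulator slot for every variable.
    for v in var_list:
      self._zeros_slot(v, "accum", self._name)

  def _apply_dense(self, grad, var):
    accum = self.get_slot(var, "accum")
    # Illustrative rule: accumulate the raw gradient, then step against the sum.
    accum_t = state_ops.assign_add(accum, grad, use_locking=self._use_locking)
    var_update = state_ops.assign_sub(
        var, self._learning_rate * accum_t, use_locking=self._use_locking)
    return control_flow_ops.group(var_update, accum_t)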
Example No. 29
        def _zeros_slot(self, var, slot_name, op_name):
            """Find or create a slot initialized with 0.0.
            This is effectively a copy of the original TF optimizer method,
            except that this one allows passing a dtype to `create_zeros_slot`.
            Args:
              var: A `Variable` object.
              slot_name: Name for the slot.
              op_name: Name to use when scoping the Variable that
                needs to be created for the slot.
            Returns:
              A `Variable` object.
            """
            named_slots = self._slot_dict(slot_name)
            if _var_key(var) not in named_slots:
                new_slot_variable = slot_creator.create_zeros_slot(var, op_name,
                                                                   dtype=self.slots_dtype)
                self._restore_slot_variable(
                    slot_name=slot_name, variable=var,
                    slot_variable=new_slot_variable)
                named_slots[_var_key(var)] = new_slot_variable

            return tf.cast(named_slots[_var_key(var)], var.dtype)
Example No. 30
  def apply(self, var_list=None):
    """Maintains moving averages of variables.

    `var_list` must be a list of `Variable` or `Tensor` objects.  This method
    creates shadow variables for all elements of `var_list`.  Shadow variables
    for `Variable` objects are initialized to the variable's initial value.
    They will be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection.
    For `Tensor` objects, the shadow variables are initialized to 0 and zero
    debiased (see docstring in `assign_moving_average` for more details).

    Shadow variables are created with `trainable=False` and added to the
    `GraphKeys.ALL_VARIABLES` collection.  They will be returned by calls to
    `tf.global_variables()`.

    Returns an op that updates all shadow variables from the current value of
    their associated variables.

    Note that `apply()` can be called multiple times. When eager execution is
    enabled each call to apply will update the variables once, so this needs to
    be called in a loop.

    Args:
      var_list: A list of Variable or Tensor objects. The variables
        and Tensors must be of types bfloat16, float16, float32, or float64.

    Returns:
      An Operation that updates the moving averages.

    Raises:
      TypeError: If the arguments are not an allowed type.
    """
    # TODO(touts): op_scope
    if var_list is None:
      var_list = variables.trainable_variables()
    zero_debias_true = set()  # set of vars to set `zero_debias=True`
    for var in var_list:
      if var.dtype.base_dtype not in [
          dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64
      ]:
        raise TypeError("The variables must be half, float, or double: %s" %
                        var.name)

      if var not in self._averages:
        # For variables: to lower communication bandwidth across devices we keep
        # the moving averages on the same device as the variables. For other
        # tensors, we rely on the existing device allocation mechanism.
        with ops.init_scope():
          if isinstance(var, variables.Variable):
            avg = slot_creator.create_slot(var,
                                           var.initialized_value(),
                                           self.name,
                                           colocate_with_primary=True)
            # NOTE(mrry): We only add `tf.Variable` objects to the
            # `MOVING_AVERAGE_VARIABLES` collection.
            ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
          else:
            avg = slot_creator.create_zeros_slot(
                var,
                self.name,
                colocate_with_primary=(var.op.type in ["Variable",
                                                       "VariableV2",
                                                       "VarHandleOp"]))
            if self._zero_debias:
              zero_debias_true.add(avg)
        self._averages[var] = avg

    with ops.name_scope(self.name) as scope:
      decay = ops.convert_to_tensor(self._decay, name="decay")
      if self._num_updates is not None:
        num_updates = math_ops.cast(self._num_updates,
                                    dtypes.float32,
                                    name="num_updates")
        decay = math_ops.minimum(decay,
                                 (1.0 + num_updates) / (10.0 + num_updates))
      updates = []
      for var in var_list:
        zero_debias = self._averages[var] in zero_debias_true
        updates.append(assign_moving_average(
            self._averages[var], var, decay, zero_debias=zero_debias))
      return control_flow_ops.group(*updates, name=scope)
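The usual graph-mode wiring for `apply()`, following the class's documented usage; `train_op`, `global_step`, and `w` are assumed to be defined elsewhere:

import tensorflow.compat.v1 as tf

ema = tf.train.ExponentialMovingAverage(decay=0.999, num_updates=global_step)

# Run the shadow-variable update after each training step.
with tf.control_dependencies([train_op]):
    train_op_with_ema = ema.apply(tf.trainable_variables())

# At evaluation time, read the shadow (averaged) value of a variable.
shadow_w = ema.average(w)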
Example No. 31
    def apply(self, var_list=None):
        """Maintains moving averages of variables.

    `var_list` must be a list of `Variable` or `Tensor` objects.  This method
    creates shadow variables for all elements of `var_list`.  Shadow variables
    for `Variable` objects are initialized to the variable's initial value.
    They will be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection.
    For `Tensor` objects, the shadow variables are initialized to 0 and zero
    debiased (see docstring in `assign_moving_average` for more details).

    Shadow variables are created with `trainable=False` and added to the
    `GraphKeys.ALL_VARIABLES` collection.  They will be returned by calls to
    `tf.global_variables()`.

    Returns an op that updates all shadow variables as described above.

    Note that `apply()` can be called multiple times with different lists of
    variables.

    Args:
      var_list: A list of Variable or Tensor objects. The variables
        and Tensors must be of types float16, float32, or float64.

    Returns:
      An Operation that updates the moving averages.

    Raises:
      TypeError: If the arguments are not all float16, float32, or float64.
      ValueError: If the moving average of one of the variables is already
        being computed.
    """
        # TODO(touts): op_scope
        if var_list is None:
            var_list = variables.trainable_variables()
        zero_debias_true = set()  # set of vars to set `zero_debias=True`
        for var in var_list:
            if var.dtype.base_dtype not in [
                    dtypes.float16, dtypes.float32, dtypes.float64
            ]:
                raise TypeError(
                    "The variables must be half, float, or double: %s" %
                    var.name)
            if var in self._averages:
                raise ValueError("Moving average already computed for: %s" %
                                 var.name)

            # For variables: to lower communication bandwidth across devices we keep
            # the moving averages on the same device as the variables. For other
            # tensors, we rely on the existing device allocation mechanism.
            with ops.init_scope():
                if isinstance(var, variables.Variable):
                    avg = slot_creator.create_slot(var,
                                                   var.initialized_value(),
                                                   self._name,
                                                   colocate_with_primary=True)
                    # NOTE(mrry): We only add `tf.Variable` objects to the
                    # `MOVING_AVERAGE_VARIABLES` collection.
                    ops.add_to_collection(
                        ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
                else:
                    avg = slot_creator.create_zeros_slot(
                        var,
                        self._name,
                        colocate_with_primary=(var.op.type in [
                            "Variable", "VariableV2", "VarHandleOp"
                        ]))
                    if self._zero_debias:
                        zero_debias_true.add(avg)
            self._averages[var] = avg

        with ops.name_scope(self._name) as scope:
            num_updates = math_ops.cast(self._num_updates,
                                        dtypes.float32,
                                        name="num_updates")
            decay = num_updates / (1. + num_updates)
            decay = array_ops.identity(decay, name='decay')

            updates = []
            for var in var_list:
                zero_debias = self._averages[var] in zero_debias_true
                updates.append(
                    assign_moving_average(self._averages[var],
                                          var,
                                          decay,
                                          zero_debias=zero_debias))
            return control_flow_ops.group(*updates, name=scope)
Example No. 32
    def _create_q(self, d_oo_d_state):
        self._qs.append(slot_creator.create_zeros_slot(d_oo_d_state, 'q'))
        return self._qs[-1]
Example No. 33
  def apply(self, var_list=None):
    """Maintains moving averages of variables.

    `var_list` must be a list of `Variable` objects.  This method
    creates shadow variables (holding the moving averages)
    for all elements of `var_list`, and
    updates the moving averages using the current `var_list` values. Shadow
    variables for `Variable` objects are initialized to the variable's initial
    value.

    Shadow variables are created with `trainable=False`. To access them you
    can use the EMA object's `average` method. Note that `EMA` objects are
    not trackable by checkpoints, so if you want to checkpoint or restore the
    moving variables you will need to manually grab the shadow
    variables via `average()` and assign them as `tf.Module` properties or
    directly pass them to your `tf.train.Checkpoint`.

    Note that `apply()` can be called multiple times. When eager execution is
    enabled each call to apply will update the variables once, so this needs to
    be called in a loop.

    In legacy TF 1.x graphs, this method returns an op that updates all
    shadow variables from the current value of their associated variables. In
    TF 1.x graphs without automatic control dependencies this op needs to be
    manually run.

    Args:
      var_list: A list of Variable objects. The variables
        must be of types bfloat16, float16, float32, or float64.
        (In legacy TF 1.x graphs these may be tensors, but this is unsupported
        when eager execution is enabled.)

    Returns:
      An Operation that updates the moving averages.

    Raises:
      TypeError: If the arguments are not an allowed type.
    """
    # TODO(touts): op_scope
    if var_list is None:
      var_list = variables.trainable_variables()
    for v in var_list:
      if (isinstance(v, ops.Tensor)
          and ops.executing_eagerly_outside_functions()):
        raise TypeError(
            "tf.train.ExponentialMovingAverage does not support non-Variable"
            " tensors when eager execution is enabled.")
    zero_debias_true = set()  # set of vars to set `zero_debias=True`
    for var in var_list:
      if var.dtype.base_dtype not in [
          dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64
      ]:
        raise TypeError("The variables must be half, float, or double: %s" %
                        var.name)

      if var.ref() not in self._averages:
        # For variables: to lower communication bandwidth across devices we keep
        # the moving averages on the same device as the variables. For other
        # tensors, we rely on the existing device allocation mechanism.
        with ops.init_scope():
          if isinstance(var, variables.Variable):
            with ops.device(var.device):
              initialized_value = var.initialized_value()
            avg = slot_creator.create_slot(
                var,
                initialized_value,
                self.name,
                colocate_with_primary=True,
                copy_xla_sharding=True)
            # NOTE(mrry): We only add `tf.Variable` objects to the
            # `MOVING_AVERAGE_VARIABLES` collection.
            ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
          else:
            avg = slot_creator.create_zeros_slot(
                var,
                self.name,
                colocate_with_primary=(var.op.type in [
                    "Variable", "VariableV2", "VarHandleOp"
                ]),
                copy_xla_sharding=True)
            if self._zero_debias:
              zero_debias_true.add(avg.ref())
        self._averages[var.ref()] = avg

    with ops.name_scope(self.name) as scope:
      decay = ops.convert_to_tensor(
          self._decay, dtype=dtypes.float32, name="decay")
      if self._num_updates is not None:
        num_updates = math_ops.cast(
            self._num_updates, dtypes.float32, name="num_updates")
        decay = math_ops.minimum(decay,
                                 (1.0 + num_updates) / (10.0 + num_updates))
      updates = []
      for var in var_list:
        avg = self._averages[var.ref()]
        zero_debias = avg.ref() in zero_debias_true
        updates.append(assign_moving_average(avg, var, decay, zero_debias))
      return control_flow_ops.group(*updates, name=scope)
Example No. 34
  def _zeros_slot(self, var, slot_name, op_name):
    named_slots = self._slot_dict(slot_name)
    if var not in named_slots:
      named_slots[var] = slot_creator.create_zeros_slot(var, op_name)
    return named_slots[var]
Example No. 35
    def compute_gradients(self,
                          outer_objective,
                          optimizer_dict: OptimizerDict,
                          hyper_list=None,
                          clip_value=100.):
        """
        This methods populates the computational graph

        :param outer_objective:  optimization objective for the learning rate (e.g. validation error)
        :param optimizer_dict: on `OptimizerDict` (see marthe.optimizers) e.g.
                                 opt_dict = marthe.GradientDescentOptimizer(lr).minimize(training_error)
        :param hyper_list:   an optional list of hyperparameters (the learning rate). By default all variables
                                in the collection `HYPERPARAMETERS`  are taken
        :param clip_value: optional value for clipping the Marthe lr update
        """
        assert isinstance(
            optimizer_dict,
            OptimizerDict), _ERROR_NOT_OPTIMIZER_DICT.format(optimizer_dict)
        self._opt_dict = optimizer_dict

        if hyper_list is None:  # get default hyperparameters
            hyper_list = utils.hyperparameters(tf.get_variable_scope().name)

        state = list(optimizer_dict.state)
        print('HG - COMPUTE GRADIENTS - LEN STATE: {}'.format(len(state)))

        vs = [tf.ones_like(w) for w in state]  # `ghost variables'

        vec_vs = utils.vectorize_all(vs)
        print('HG - COMPUTE GRADIENTS - TOTAL PARAMETERS: {}'.format(
            vec_vs.get_shape().as_list()))
        dynamics = list(optimizer_dict.dynamics)
        vec_dynamics = utils.vectorize_all(dynamics)

        outer_obj_grads = grad(outer_objective, state)
        self.vec_outer_obj_grads = utils.vectorize_all(
            outer_obj_grads)  # also used for heuristics

        self._w_dots = [[
            slot_creator.create_zeros_slot(w,
                                           name='w_dot_{}'.format(h.op.name))
            for w in state
        ] for h in hyper_list]
        vec_w_dots = [utils.vectorize_all(w_dot) for w_dot in self._w_dots]

        for hyper, w_dot, vec_w_dot in zip(hyper_list, self._w_dots,
                                           vec_w_dots):
            assert hyper.shape.ndims == 0

            A_w_dot = grad(utils.dot(vec_dynamics, vec_w_dot), state)
            B = grad(grad(utils.dot(vec_dynamics, vec_vs), hyper)[0], vs)

            self.Bs.append(utils.vectorize_all(B))  # used in the heuristics

            mu = tf.convert_to_tensor(self.mu_pl, dtype=A_w_dot[0].dtype)

            if self.mu == 'adapt' or self.mu > 0:
                self._w_dots_iterations.append([
                    wd.assign(mu * awd + b)
                    for wd, awd, b in zip(w_dot, A_w_dot, B)
                ])
            else:
                self._w_dots_iterations.append(
                    [wd.assign(b) for wd, awd, b in zip(w_dot, A_w_dot, B)])

            hg = clip_and_count(utils.dot(self.vec_outer_obj_grads, vec_w_dot),
                                self.hg_clip_counter, clip_value)
            # tf.add_to_collection(hg, utils.GraphKeys.HYPERGRADIENTS)

            # todo ADD d E / d lambda when required
            self._hypergrads.append(hg)
            self.hypergrads[hyper] = hg
            self._hyper_list = hyper_list

        def _apply_hg():
            print(self._hypergrads, self._hyper_list)
            return self._outer_object_optimizer.apply_gradients(
                list(zip(self._hypergrads, self._hyper_list)),
                global_step=self.gs)

        with tf.control_dependencies([_apply_hg()]):
            with tf.control_dependencies(self._w_dots_iterations[0]):
                self.step = self._opt_dict.iteration  # hopefully this still must be compiled... otherwise with these
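A wiring sketch following the docstring above; only the `opt_dict` construction comes from the docstring, while `hg`, `training_error`, and `validation_error` are assumed names for the hypergradient object and the two objectives:

import tensorflow as tf

# Scalar learning-rate hyperparameter (compute_gradients asserts ndims == 0).
lr = tf.get_variable('lr', initializer=0.01, trainable=False)

# From the docstring: wrap the training dynamics in an OptimizerDict.
opt_dict = marthe.GradientDescentOptimizer(lr).minimize(training_error)

# Populate the hypergradient graph against the outer (validation) objective.
hg.compute_gradients(validation_error, opt_dict, hyper_list=[lr])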
Example No. 36
  def apply(self, var_list=None):
    """Maintains moving averages of variables.

    `var_list` must be a list of `Variable` or `Tensor` objects.  This method
    creates shadow variables for all elements of `var_list`.  Shadow variables
    for `Variable` objects are initialized to the variable's initial value.
    They will be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection.
    For `Tensor` objects, the shadow variables are initialized to 0.

    Shadow variables are created with `trainable=False` and added to the
    `GraphKeys.ALL_VARIABLES` collection.  They will be returned by calls to
    `tf.all_variables()`.

    Returns an op that updates all shadow variables as described above.

    Note that `apply()` can be called multiple times with different lists of
    variables.

    Args:
      var_list: A list of Variable or Tensor objects. The variables
        and Tensors must be of types float32 or float64.

    Returns:
      An Operation that updates the moving averages.

    Raises:
      TypeError: If the arguments are not all float32 or float64.
      ValueError: If the moving average of one of the variables is already
        being computed.
    """
    # TODO(touts): op_scope
    if var_list is None:
      var_list = variables.trainable_variables()
    for var in var_list:
      if var.dtype.base_dtype not in [dtypes.float32, dtypes.float64]:
        raise TypeError("The variables must be float or double: %s" % var.name)
      if var in self._averages:
        raise ValueError("Moving average already computed for: %s" % var.name)

      # For variables: to lower communication bandwidth across devices we keep
      # the moving averages on the same device as the variables. For other
      # tensors, we rely on the existing device allocation mechanism.
      with ops.control_dependencies(None):
        if isinstance(var, variables.Variable):
          avg = slot_creator.create_slot(
              var, var.initialized_value(), self._name,
              colocate_with_primary=True)
        else:
          avg = slot_creator.create_zeros_slot(
              var, self._name,
              colocate_with_primary=(var.op.type == "Variable"))
      self._averages[var] = avg
      ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)

    with ops.name_scope(self._name) as scope:
      decay = ops.convert_to_tensor(self._decay, name="decay")
      if self._num_updates is not None:
        num_updates = math_ops.cast(self._num_updates, dtypes.float32,
                                    name="num_updates")
        decay = math_ops.minimum(decay,
                                 (1.0 + num_updates) / (10.0 + num_updates))
      updates = []
      for var in var_list:
        updates.append(assign_moving_average(self._averages[var], var, decay))
      return control_flow_ops.group(*updates, name=scope)
Example No. 37
    def apply(self, var_list=None):
        """Maintains moving averages of variables.

    `var_list` must be a list of `Variable` or `Tensor` objects.  This method
    creates shadow variables for all elements of `var_list`.  Shadow variables
    for `Variable` objects are initialized to the variable's initial value.
    They will be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection.
    For `Tensor` objects, the shadow variables are initialized to 0.

    Shadow variables are created with `trainable=False` and added to the
    `GraphKeys.ALL_VARIABLES` collection.  They will be returned by calls to
    `tf.all_variables()`.

    Returns an op that updates all shadow variables as described above.

    Note that `apply()` can be called multiple times with different lists of
    variables.

    Args:
      var_list: A list of Variable or Tensor objects. The variables
        and Tensors must be of types float32 or float64.

    Returns:
      An Operation that updates the moving averages.

    Raises:
      TypeError: If the arguments are not all float32 or float64.
      ValueError: If the moving average of one of the variables is already
        being computed.
    """
        # TODO(touts): op_scope
        if var_list is None:
            var_list = variables.trainable_variables()
        for var in var_list:
            if var.dtype.base_dtype not in [dtypes.float32, dtypes.float64]:
                raise TypeError("The variables must be float or double: %s" %
                                var.name)
            if var in self._averages:
                raise ValueError("Moving average already computed for: %s" %
                                 var.name)

            # For variables: to lower communication bandwidth across devices we keep
            # the moving averages on the same device as the variables. For other
            # tensors, we rely on the existing device allocation mechanism.
            with ops.control_dependencies(None):
                if isinstance(var, variables.Variable):
                    avg = slot_creator.create_slot(var,
                                                   var.initialized_value(),
                                                   self._name,
                                                   colocate_with_primary=True)
                    # NOTE(mrry): We only add `tf.Variable` objects to the
                    # `MOVING_AVERAGE_VARIABLES` collection.
                    ops.add_to_collection(
                        ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
                else:
                    avg = slot_creator.create_zeros_slot(
                        var,
                        self._name,
                        colocate_with_primary=(var.op.type == "Variable"))
            self._averages[var] = avg

        with ops.name_scope(self._name) as scope:
            decay = ops.convert_to_tensor(self._decay, name="decay")
            if self._num_updates is not None:
                num_updates = math_ops.cast(self._num_updates,
                                            dtypes.float32,
                                            name="num_updates")
                decay = math_ops.minimum(decay, (1.0 + num_updates) /
                                         (10.0 + num_updates))
            updates = []
            for var in var_list:
                updates.append(
                    assign_moving_average(self._averages[var], var, decay))
            return control_flow_ops.group(*updates, name=scope)
Example No. 38
    def __init__(self, model, decay, weights_list=None, temp_model='temp_model.h5',
                 name='ExponentialMovingAverage', type='cpu'):
        # EMA for Keras; an example can be seen at https://github.com/ewrfcas/QANet_keras/blob/master/train_QANet.py
        # Initialize before training, but after the model has been built.
        self.model = model
        self.scope_name = name
        self.temp_model = temp_model
        self.type = type
        self.decay = decay
        self._averages = {}

        if weights_list is None:
            weights_list = self.model.trainable_weights

        if self.type == 'gpu':
            self.sess = K.get_session()
            for weight in weights_list:
                if weight.dtype.base_dtype not in [tf.float16, tf.float32,
                                                   tf.float64]:
                    raise TypeError("The variables must be half, float, or double: %s" %
                                    weight.name)
                if weight in self._averages:
                    raise ValueError("Moving average already computed for: %s" % weight.name)

                # For variables: to lower communication bandwidth across devices we keep
                # the moving averages on the same device as the variables. For other
                # tensors, we rely on the existing device allocation mechanism.
                with ops.init_scope():
                    if isinstance(weight, tf.Variable):
                        avg = slot_creator.create_slot(weight,
                                                       weight.initialized_value(),
                                                       self.scope_name,
                                                       colocate_with_primary=True)
                        # NOTE(mrry): We only add `tf.Variable` objects to the
                        # `MOVING_AVERAGE_VARIABLES` collection.
                        ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, weight)
                    else:
                        avg = slot_creator.create_zeros_slot(weight,
                                                             self.scope_name,
                                                             colocate_with_primary=(weight.op.type in ["Variable",
                                                                                                       "VariableV2",
                                                                                                       "VarHandleOp"]))
                self._averages[weight] = avg

            with tf.name_scope(self.scope_name):
                decay = ops.convert_to_tensor(decay, name="decay")
                self.updates = []
                for var in weights_list:
                    self.updates.append(
                        moving_averages.assign_moving_average(self._averages[var], var, decay, zero_debias=False))

                self.assigns = []
                for weight in weights_list:
                    self.assigns.append(tf.assign(weight, self._averages[weight]))

            self.sess.run(tf.global_variables_initializer())

        elif self.type == 'cpu':
            print('CPU EMA getting weights...')
            for weight in tqdm(weights_list):
                self._averages[weight.name] = K.get_value(weight)
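A hedged usage sketch for the constructor above, assuming the enclosing class is exposed as ExponentialMovingAverage (per its name default) and the 'gpu' branch is used; the training loop and `batches` are illustrative:

# Build the Keras model first, then the EMA wrapper (per the comment above).
ema = ExponentialMovingAverage(model, decay=0.999, type='gpu')

for x_batch, y_batch in batches:          # `batches` is assumed to exist
    model.train_on_batch(x_batch, y_batch)
    ema.sess.run(ema.updates)             # refresh the shadow weights

# Before evaluation, copy the shadow weights back into the model.
ema.sess.run(ema.assigns)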
Example No. 39
  def _zeros_slot(self, var, slot_name, op_name):
    named_slots = self._slot_dict(slot_name)
    if var not in named_slots:
      named_slots[var] = slot_creator.create_zeros_slot(var, op_name)
    return named_slots[var]
Example No. 40
    def apply(self, var_list=None):
        """Maintains moving averages of variables.

    `var_list` must be a list of `Variable` or `Tensor` objects.  This method
    creates shadow variables for all elements of `var_list`.  Shadow variables
    for `Variable` objects are initialized to the variable's initial value.
    They will be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection.
    For `Tensor` objects, the shadow variables are initialized to 0 and zero
    debiased (see docstring in `assign_moving_average` for more details).

    Shadow variables are created with `trainable=False` and added to the
    `GraphKeys.ALL_VARIABLES` collection.  They will be returned by calls to
    `tf.compat.v1.global_variables()`.

    Returns an op that updates all shadow variables from the current value of
    their associated variables.

    Note that `apply()` can be called multiple times. When eager execution is
    enabled each call to apply will update the variables once, so this needs to
    be called in a loop.

    Args:
      var_list: A list of Variable or Tensor objects. The variables and Tensors
        must be of types bfloat16, float16, float32, or float64.

    Returns:
      An Operation that updates the moving averages.

    Raises:
      TypeError: If the arguments are not an allowed type.
    """
        # TODO(touts): op_scope
        if var_list is None:
            var_list = variables.trainable_variables()
        for v in var_list:
            if isinstance(v, ops.EagerTensor):
                raise TypeError(
                    "tf.train.ExponentialMovingAverage does not support non-Variable"
                    " tensors when eager execution is enabled.")
        zero_debias_true = set()  # set of vars to set `zero_debias=True`
        for var in var_list:
            if var.dtype.base_dtype not in [
                    dtypes.bfloat16, dtypes.float16, dtypes.float32,
                    dtypes.float64
            ]:
                raise TypeError(
                    "The variables must be half, float, or double: %s" %
                    var.name)

            if var.ref() not in self._averages:
                # For variables: to lower communication bandwidth across devices we keep
                # the moving averages on the same device as the variables. For other
                # tensors, we rely on the existing device allocation mechanism.
                with ops.init_scope():
                    if isinstance(var, variables.Variable):
                        with ops.device(var.device):
                            initialized_value = var.initialized_value()
                        avg = slot_creator.create_slot(
                            var,
                            initialized_value,
                            self.name,
                            colocate_with_primary=True,
                            copy_xla_sharding=True)
                        # NOTE(mrry): We only add `tf.Variable` objects to the
                        # `MOVING_AVERAGE_VARIABLES` collection.
                        ops.add_to_collection(
                            ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
                    else:
                        avg = slot_creator.create_zeros_slot(
                            var,
                            self.name,
                            colocate_with_primary=(var.op.type in [
                                "Variable", "VariableV2", "VarHandleOp"
                            ]),
                            copy_xla_sharding=True)
                        if self._zero_debias:
                            zero_debias_true.add(avg.ref())
                self._averages[var.ref()] = avg

        with ops.name_scope(self.name) as scope:
            decay = ops.convert_to_tensor(self._decay,
                                          dtype=dtypes.float32,
                                          name="decay")
            if self._num_updates is not None:
                num_updates = math_ops.cast(self._num_updates,
                                            dtypes.float32,
                                            name="num_updates")
                decay = math_ops.minimum(decay, (1.0 + num_updates) /
                                         (10.0 + num_updates))
            updates = []
            for var in var_list:
                avg = self._averages[var.ref()]
                zero_debias = avg.ref() in zero_debias_true
                updates.append(
                    assign_moving_average(avg, var, decay, zero_debias))
            return control_flow_ops.group(*updates, name=scope)