Example #1
    def _resource_apply_sparse(self, grad, var, indices):
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        beta_1_t = self._get_hyper('beta_1', var_dtype)
        beta_2_t = self._get_hyper('beta_2', var_dtype)
        local_step = math_ops.cast(self.iterations + 1, var_dtype)
        beta_1_power = math_ops.pow(beta_1_t, local_step)
        beta_2_power = math_ops.pow(beta_2_t, local_step)
        epsilon_t = self._get_hyper('epsilon', var_dtype)
        lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))

        # \\(m := beta1 * m + (1 - beta1) * g_t\\)
        m = self.get_slot(var, "m")
        m_t_slice = beta_1_t * array_ops.gather(
            m, indices) + (1 - beta_1_t) * grad
        m_update_op = resource_variable_ops.resource_scatter_update(
            m.handle, indices, m_t_slice)

        # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
        v = self.get_slot(var, "v")
        v_t_slice = (beta_2_t * array_ops.gather(v, indices) +
                     (1 - beta_2_t) * math_ops.square(grad))
        v_update_op = resource_variable_ops.resource_scatter_update(
            v.handle, indices, v_t_slice)

        # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
        var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t)
        var_update_op = resource_variable_ops.resource_scatter_sub(
            var.handle, indices, var_slice)

        return control_flow_ops.group(
            *[var_update_op, m_update_op, v_update_op])
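Note: the three scatter ops above are plain Adam restricted to the rows selected by indices. As a sanity check, here is a minimal NumPy sketch of the same slice algebra (toy values; every name in it is illustrative, not part of the snippet):

import numpy as np

lr_t, beta_1, beta_2, epsilon, step = 0.001, 0.9, 0.999, 1e-7, 1
grad = np.array([0.5, -0.25])      # gradient for the touched rows only
m_slice = np.zeros_like(grad)      # gathered rows of the "m" slot
v_slice = np.zeros_like(grad)      # gathered rows of the "v" slot

lr = lr_t * np.sqrt(1 - beta_2 ** step) / (1 - beta_1 ** step)  # bias correction
m_slice = beta_1 * m_slice + (1 - beta_1) * grad
v_slice = beta_2 * v_slice + (1 - beta_2) * np.square(grad)
delta = lr * m_slice / (np.sqrt(v_slice) + epsilon)  # what scatter_sub subtracts
print(delta)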
Example #2
  def _resource_apply_sparse(self, grad, var, indices):
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))

    # \\(m := beta1 * m + (1 - beta1) * g_t\\)
    m = self.get_slot(var, "m")
    m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad
    m_update_op = resource_variable_ops.resource_scatter_update(m.handle,
                                                                indices,
                                                                m_t_slice)

    # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
    v = self.get_slot(var, "v")
    v_t_slice = (beta2_t * array_ops.gather(v, indices) +
                 (1 - beta2_t) * math_ops.square(grad))
    v_update_op = resource_variable_ops.resource_scatter_update(v.handle,
                                                                indices,
                                                                v_t_slice)

    # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
    var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t)
    var_update_op = resource_variable_ops.resource_scatter_sub(var.handle,
                                                               indices,
                                                               var_slice)

    return control_flow_ops.group(var_update_op, m_update_op, v_update_op)
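Note: Example #2 performs the same update through the TF1-style optimizer plumbing (_get_beta_accumulators and cast hyperparameters instead of _get_hyper). The primitive itself is also reachable from the public API; a minimal sketch, assuming TF 2.x eager mode:

import tensorflow as tf

# Variable.scatter_sub takes an IndexedSlices and dispatches to
# resource_scatter_sub for resource variables.
var = tf.Variable([[1.0], [2.0], [3.0]])
delta = tf.IndexedSlices(values=tf.constant([[0.5]]), indices=tf.constant([1]))
var.scatter_sub(delta)
print(var.numpy())  # [[1. ], [1.5], [3. ]]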
Example #4
    def _resource_apply_sparse(self, grad, var, indices):
        beta1_power, beta2_power = self._get_beta_accumulators()
        beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
        beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
        global_step = self._get_step_accumulators()
        global_step = math_ops.cast(global_step, var.dtype.base_dtype)
        pre_step = self.get_slot(var, "pre_step")

        pre_step_slice = array_ops.gather(pre_step, indices)
        skipped_steps = global_step - pre_step_slice

        m = self.get_slot(var, "m")
        m_slice = array_ops.gather(m, indices)
        v = self.get_slot(var, "v")
        v_slice = array_ops.gather(v, indices)

        # For performance reasons, math_ops.exp(a * math_ops.log(b)) is used
        # here instead of math_ops.pow(b, a).
        # \\(lr := ext_learning_rate * sqrt(1 - beta2 ** pre_step) /
        #          (1 - beta1 ** pre_step) * (1 - beta1 ** skipped_step) /
        #          (1 - beta1)\\)
        lr = ((lr_t * math_ops.sqrt(1 - math_ops.exp(pre_step_slice *
                                                     math_ops.log(beta2_t))) /
               (1 - math_ops.exp(pre_step_slice * math_ops.log(beta1_t)))) *
              (1 - math_ops.exp(math_ops.log(beta1_t) * skipped_steps)) /
              (1 - beta1_t))
        # \\(variable -= learning_rate * m / (epsilon + sqrt(v))\\)
        var_slice = lr * m_slice / (math_ops.sqrt(v_slice) + epsilon_t)
        var_update_op = resource_variable_ops.resource_scatter_sub(
            var.handle, indices, var_slice)

        with ops.control_dependencies([var_update_op]):
            # \\(m := m * beta1 ** skipped_step + (1 - beta1) * g_t\\)
            # For performance reasons, math_ops.exp(a * math_ops.log(b)) is used
            # here instead of math_ops.pow(b, a).
            m_t_slice = (
                math_ops.exp(math_ops.log(beta1_t) * skipped_steps) * m_slice +
                (1 - beta1_t) * grad)
            m_update_op = resource_variable_ops.resource_scatter_update(
                m.handle, indices, m_t_slice)

            # \\(v := v * beta2 ** skipped_step + (1 - beta2) * (g_t * g_t)\\)
            # For performance reasons, math_ops.exp(a * math_ops.log(b)) is used
            # here instead of math_ops.pow(b, a).
            v_t_slice = (
                math_ops.exp(math_ops.log(beta2_t) * skipped_steps) * v_slice +
                (1 - beta2_t) * math_ops.square(grad))
            v_update_op = resource_variable_ops.resource_scatter_update(
                v.handle, indices, v_t_slice)

        with ops.control_dependencies([m_update_op, v_update_op]):
            pre_step_update_op = resource_variable_ops.resource_scatter_update(
                pre_step.handle, indices, global_step)

        return control_flow_ops.group(var_update_op, m_update_op, v_update_op,
                                      pre_step_update_op)
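Note: the pre_step bookkeeping lets a row that received no gradient for several steps catch up in one multiplication, because decaying by beta ** k once equals k single-step decays with a zero gradient. A small NumPy check of that identity (illustrative values):

import numpy as np

beta, k = 0.9, 5
m = np.array([0.4, -0.2])
stepwise = m.copy()
for _ in range(k):                    # k ordinary decays, gradient == 0
    stepwise = beta * stepwise
fused = np.exp(np.log(beta) * k) * m  # the exp/log form used in the snippet
print(np.allclose(stepwise, fused))   # True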
Example #5
 def testScatterSubScalar(self):
   handle = resource_variable_ops.var_handle_op(
       dtype=dtypes.int32, shape=[1, 1])
   self.evaluate(
       resource_variable_ops.assign_variable_op(
           handle, constant_op.constant([[1]], dtype=dtypes.int32)))
   self.evaluate(
       resource_variable_ops.resource_scatter_sub(
           handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
   read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
   self.assertEqual(self.evaluate(read), [[-1]])
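Note: the test passes a scalar update with a one-element index list; the kernel broadcasts the scalar across the selected row, so [[1]] becomes [[-1]]. An eager-mode sketch of the same behaviour through tf.raw_ops, assuming TF 2.x:

import tensorflow as tf

v = tf.Variable([[1]], dtype=tf.int32)
# The scalar update 2 is broadcast across row 0 before the subtraction.
tf.raw_ops.ResourceScatterSub(resource=v.handle, indices=[0],
                              updates=tf.constant(2, dtype=tf.int32))
print(v.numpy())  # [[-1]]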
Example #7
 def testScatterSub(self):
   with self.session() as sess, self.test_scope():
     handle = resource_variable_ops.var_handle_op(
         dtype=dtypes.int32, shape=[2, 1])
     sess.run(
         resource_variable_ops.assign_variable_op(
             handle, constant_op.constant([[4], [1]], dtype=dtypes.int32)))
     sess.run(
         resource_variable_ops.resource_scatter_sub(
             handle, [1], constant_op.constant([[2]], dtype=dtypes.int32)))
     read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
     self.assertAllEqual(self.evaluate(read), [[4], [-1]])
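Note: here the update has shape [1, 1], which matches indices.shape + var.shape[1:], so only row 1 is modified. The same check in eager mode, assuming TF 2.x:

import tensorflow as tf

v = tf.Variable([[4], [1]], dtype=tf.int32)
tf.raw_ops.ResourceScatterSub(resource=v.handle, indices=[1],
                              updates=tf.constant([[2]], dtype=tf.int32))
print(v.numpy())  # [[ 4], [-1]]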
Example #8
 def testScatterSub(self):
   with self.test_session() as sess, self.test_scope():
     handle = resource_variable_ops.var_handle_op(
         dtype=dtypes.int32, shape=[2, 1])
     sess.run(
         resource_variable_ops.assign_variable_op(
             handle, constant_op.constant([[4], [1]], dtype=dtypes.int32)))
     sess.run(
         resource_variable_ops.resource_scatter_sub(
             handle, [1], constant_op.constant([[2]], dtype=dtypes.int32)))
     read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
     self.assertAllEqual(self.evaluate(read), [[4], [-1]])
Example #9
 def _resource_scatter_sub(self, x, i, v):
   sub_op = resource_variable_ops.resource_scatter_sub(x.handle, i, v)
   with ops.control_dependencies([sub_op]):
     return x.value()
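Note: the control dependency guarantees that the returned read observes the subtraction; in graph mode a read without it could be scheduled before the scatter. A sketch of the same pattern with public symbols, assuming TF 2.x (the function name is illustrative):

import tensorflow as tf

@tf.function
def scatter_sub_then_read(x, i, v):
    sub_op = tf.raw_ops.ResourceScatterSub(resource=x.handle,
                                           indices=i, updates=v)
    with tf.control_dependencies([sub_op]):  # order the read after the update
        return x.value()

x = tf.Variable([[4.0], [1.0]])
print(scatter_sub_then_read(x, [1], [[2.0]]).numpy())  # [[ 4.], [-1.]]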