def _resource_apply_sparse(self, grad, var, indices):
  """Applies a sparse lazy-Adam step using the keras-v2 hyperparameter API.

  Only the rows of `var` (and of its `m`/`v` slots) selected by `indices`
  are touched; all other rows keep their previous moment estimates.

  Args:
    grad: Gradient values for the gathered rows.
    var: The resource variable being optimized.
    indices: Row indices into `var` and its slot variables.

  Returns:
    An op grouping the variable, `m`, and `v` scatter updates.
  """
  var_dtype = var.dtype.base_dtype
  lr_t = self._decayed_lr(var_dtype)
  beta_1_t = self._get_hyper('beta_1', var_dtype)
  beta_2_t = self._get_hyper('beta_2', var_dtype)
  local_step = math_ops.cast(self.iterations + 1, var_dtype)
  beta_1_power = math_ops.pow(beta_1_t, local_step)
  beta_2_power = math_ops.pow(beta_2_t, local_step)
  epsilon_t = self._get_hyper('epsilon', var_dtype)
  # Bias-corrected step size for the current iteration.
  lr = lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)

  # \\(m := beta1 * m + (1 - beta1) * g_t\\) — first moment, touched rows only.
  m = self.get_slot(var, "m")
  old_m_rows = array_ops.gather(m, indices)
  m_t_slice = beta_1_t * old_m_rows + (1 - beta_1_t) * grad
  m_update_op = resource_variable_ops.resource_scatter_update(
      m.handle, indices, m_t_slice)

  # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) — second moment.
  v = self.get_slot(var, "v")
  old_v_rows = array_ops.gather(v, indices)
  v_t_slice = beta_2_t * old_v_rows + (1 - beta_2_t) * math_ops.square(grad)
  v_update_op = resource_variable_ops.resource_scatter_update(
      v.handle, indices, v_t_slice)

  # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
  var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t)
  var_update_op = resource_variable_ops.resource_scatter_sub(
      var.handle, indices, var_slice)

  return control_flow_ops.group(var_update_op, m_update_op, v_update_op)
def _resource_apply_sparse(self, grad, var, indices):
  """Applies a sparse lazy-Adam step using the v1 optimizer API.

  Hyperparameters come from the optimizer's private `_*_t` tensors and the
  beta-power accumulators. Only rows selected by `indices` are updated in
  `var` and in its `m`/`v` slots.

  Args:
    grad: Gradient values for the gathered rows.
    var: The resource variable being optimized.
    indices: Row indices into `var` and its slot variables.

  Returns:
    An op grouping the variable, `m`, and `v` scatter updates.
  """
  dtype = var.dtype.base_dtype
  beta1_power, beta2_power = self._get_beta_accumulators()
  beta1_power = math_ops.cast(beta1_power, dtype)
  beta2_power = math_ops.cast(beta2_power, dtype)
  lr_t = math_ops.cast(self._lr_t, dtype)
  beta1_t = math_ops.cast(self._beta1_t, dtype)
  beta2_t = math_ops.cast(self._beta2_t, dtype)
  epsilon_t = math_ops.cast(self._epsilon_t, dtype)
  # Bias-corrected step size for the current iteration.
  lr = lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)

  # \\(m := beta1 * m + (1 - beta1) * g_t\\) — first moment, touched rows only.
  m = self.get_slot(var, "m")
  m_rows = array_ops.gather(m, indices)
  m_t_slice = beta1_t * m_rows + (1 - beta1_t) * grad
  m_update_op = resource_variable_ops.resource_scatter_update(
      m.handle, indices, m_t_slice)

  # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) — second moment.
  v = self.get_slot(var, "v")
  v_rows = array_ops.gather(v, indices)
  v_t_slice = beta2_t * v_rows + (1 - beta2_t) * math_ops.square(grad)
  v_update_op = resource_variable_ops.resource_scatter_update(
      v.handle, indices, v_t_slice)

  # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
  var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t)
  var_update_op = resource_variable_ops.resource_scatter_sub(
      var.handle, indices, var_slice)

  return control_flow_ops.group(var_update_op, m_update_op, v_update_op)
def _resource_apply_sparse(self, grad, var, indices):
  """Sparse lazy-Adam step that compensates for rows skipped since last update.

  A per-row `pre_step` slot records the global step at which each row was
  last touched. `skipped_steps = global_step - pre_step` is then used to
  decay the stale `m`/`v` rows as if beta1/beta2 had been applied once per
  skipped step, and to fold the corresponding bias correction into the
  learning rate.

  Note the update order, enforced by control dependencies:
  the variable is updated from the OLD `m`/`v` rows first, then `m`/`v`
  are advanced, then `pre_step` is stamped with the current global step.

  Args:
    grad: Gradient values for the gathered rows.
    var: The resource variable being optimized.
    indices: Row indices into `var` and its slot variables.

  Returns:
    An op grouping the variable, `m`, `v`, and `pre_step` scatter updates.
  """
  beta1_power, beta2_power = self._get_beta_accumulators()
  beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
  beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
  lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
  beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
  beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
  epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
  global_step = self._get_step_accumulators()
  global_step = math_ops.cast(global_step, var.dtype.base_dtype)
  # Per-row step of the last update; skipped_steps is how long each row
  # has been stale.
  pre_step = self.get_slot(var, "pre_step")
  pre_step_slice = array_ops.gather(pre_step, indices)
  skipped_steps = global_step - pre_step_slice
  m = self.get_slot(var, "m")
  m_slice = array_ops.gather(m, indices)
  v = self.get_slot(var, "v")
  v_slice = array_ops.gather(v, indices)
  # for performance reason, here use math_ops.exp(a* math_ops.log(b)) to
  # replace math_ops.pow(b, a)
  # \\(lr := extlearningrate * sqrt(1 - beta2 ** pre_step) /
  #          (1 - beta1 ** pre_step) * (1 - beta1 ** skipped_step) /
  #          (1 - beta1)\\)
  lr = ((lr_t * math_ops.sqrt(
      1 - math_ops.exp(pre_step_slice * math_ops.log(beta2_t))) /
         (1 - math_ops.exp(pre_step_slice * math_ops.log(beta1_t)))) *
        (1 - math_ops.exp(math_ops.log(beta1_t) * skipped_steps)) /
        (1 - beta1_t))
  # \\(variable -= learning_rate * m / (epsilon + sqrt(v))\\)
  # Uses the OLD (pre-update) m and v rows on purpose.
  var_slice = lr * m_slice / (math_ops.sqrt(v_slice) + epsilon_t)
  var_update_op = resource_variable_ops.resource_scatter_sub(
      var.handle, indices, var_slice)
  with ops.control_dependencies([var_update_op]):
    # \\(m := m * beta1 ** skipped_step + (1 - beta1) * g_t\\)
    # for performance reason, here use math_ops.exp(a* math_ops.log(b)) to
    # replace math_ops.pow(b, a)
    m_t_slice = (
        math_ops.exp(math_ops.log(beta1_t) * skipped_steps) * m_slice +
        (1 - beta1_t) * grad)
    m_update_op = resource_variable_ops.resource_scatter_update(
        m.handle, indices, m_t_slice)
    # \\(v := v * beta2 ** skipped_step + (1 - beta2) * (g_t * g_t)\\)
    # for performance reason, here use math_ops.exp(a* math_ops.log(b)) to
    # replace math_ops.pow(b, a)
    v_t_slice = (
        math_ops.exp(math_ops.log(beta2_t) * skipped_steps) * v_slice +
        (1 - beta2_t) * math_ops.square(grad))
    v_update_op = resource_variable_ops.resource_scatter_update(
        v.handle, indices, v_t_slice)
  with ops.control_dependencies([m_update_op, v_update_op]):
    # Stamp the touched rows with the current step only after m/v advance.
    pre_step_update_op = resource_variable_ops.resource_scatter_update(
        pre_step.handle, indices, global_step)
  return control_flow_ops.group(var_update_op, m_update_op, v_update_op,
                                pre_step_update_op)
def testScatterSubScalar(self):
  """Scatter-subtracting a scalar updates a [1, 1] int32 resource variable."""
  handle = resource_variable_ops.var_handle_op(
      dtype=dtypes.int32, shape=[1, 1])
  init_op = resource_variable_ops.assign_variable_op(
      handle, constant_op.constant([[1]], dtype=dtypes.int32))
  self.evaluate(init_op)
  sub_op = resource_variable_ops.resource_scatter_sub(
      handle, [0], constant_op.constant(2, dtype=dtypes.int32))
  self.evaluate(sub_op)
  read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
  # 1 - 2 == -1.
  self.assertEqual(self.evaluate(read), [[-1]])
def testScatterSub(self):
  """resource_scatter_sub changes only the indexed row of the variable."""
  with self.session() as sess, self.test_scope():
    handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[2, 1])
    initial = constant_op.constant([[4], [1]], dtype=dtypes.int32)
    sess.run(resource_variable_ops.assign_variable_op(handle, initial))
    delta = constant_op.constant([[2]], dtype=dtypes.int32)
    sess.run(resource_variable_ops.resource_scatter_sub(handle, [1], delta))
    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
    # Row 0 is untouched; row 1 becomes 1 - 2 == -1.
    self.assertAllEqual(self.evaluate(read), [[4], [-1]])
def testScatterSub(self):
  """resource_scatter_sub changes only the indexed row of the variable.

  Fix: `self.test_session()` is deprecated in `tf.test.TestCase`; use
  `self.session()` instead, matching the sibling test in this file.
  """
  with self.session() as sess, self.test_scope():
    handle = resource_variable_ops.var_handle_op(
        dtype=dtypes.int32, shape=[2, 1])
    sess.run(
        resource_variable_ops.assign_variable_op(
            handle, constant_op.constant([[4], [1]], dtype=dtypes.int32)))
    sess.run(
        resource_variable_ops.resource_scatter_sub(
            handle, [1], constant_op.constant([[2]], dtype=dtypes.int32)))
    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
    # Row 0 is untouched; row 1 becomes 1 - 2 == -1.
    self.assertAllEqual(self.evaluate(read), [[4], [-1]])
def _resource_scatter_sub(self, x, i, v):
  """Scatter-subtracts `v` at rows `i` of resource variable `x`.

  Returns `x.value()`, with a control dependency ensuring the value is
  read only after the scatter-sub has executed.
  """
  update = resource_variable_ops.resource_scatter_sub(x.handle, i, v)
  with ops.control_dependencies([update]):
    return x.value()