def _split_indexed_slices_v2(sp_input=None,
                             num_split=None,
                             dim_size=0,
                             name=None):
  ids_per_partition = dim_size // num_split
  extras = dim_size % num_split
  with ops.name_scope(name):
    # When the partitioned dim cannot be divided by num_split, the remainders
    # are evenly assigned from the first partition to the last.
    p_assignments = math_ops.maximum(
        sp_input.indices // (ids_per_partition + 1),
        (sp_input.indices - extras) // ids_per_partition)
    split_grads = []
    for i in range(0, num_split):
      with ops.name_scope(f"part_{i}"):
        # Mask out the ids assigned to the other partitions.
        ids_not_in_i = array_ops.where(math_ops.not_equal(p_assignments, i))
        flat_ids_not_in_i = array_ops.reshape(ids_not_in_i, [-1])
        # sparse_mask requires the mask dtype to match sp_input.indices.
        if sp_input.indices.dtype == dtypes.int64:
          flat_ids_not_in_i = math_ops.cast(flat_ids_not_in_i, dtypes.int64)
        else:
          flat_ids_not_in_i = math_ops.cast(flat_ids_not_in_i, dtypes.int32)
        s = array_ops.sparse_mask(sp_input, flat_ids_not_in_i)
        # Re-base the surviving indices to be local to partition i.
        if i < extras:
          s._indices = math_ops.floor_mod(s.indices, ids_per_partition + 1)
        else:
          s._indices = math_ops.floor_mod(s.indices - extras,
                                          ids_per_partition)
      split_grads.append(s)
    return split_grads
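# Illustrative sketch (not part of the module above): the partition-assignment
# formula in plain NumPy, assuming dim_size=10 and num_split=3. The first
# `extras` partitions receive ids_per_partition + 1 ids each, so taking the
# max of the two floor-divisions picks the correct partition for every id.
import numpy as np

dim_size, num_split = 10, 3
ids_per_partition = dim_size // num_split   # 3
extras = dim_size % num_split               # 1
indices = np.arange(dim_size)
p_assignments = np.maximum(indices // (ids_per_partition + 1),
                           (indices - extras) // ids_per_partition)
print(p_assignments)  # [0 0 0 0 1 1 1 2 2 2]: partition 0 holds the extra id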
def _stabilize(self, var):
  # Periodically re-project the variable onto its manifold, and the momentum
  # onto the tangent space at the new point, to counter numerical drift.
  if math_ops.floor_mod(self.iterations, self.stabilize) == 0:
    manifold = get_manifold(var)
    var.assign(manifold.projx(var))
    if self._momentum:
      momentum = self.get_slot(var, "momentum")
      momentum.assign(manifold.proju(var, momentum))
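# Illustrative sketch of the manifold operations used above, assuming a
# hypothetical unit-sphere manifold (the real get_manifold() is defined
# elsewhere): projx retracts a drifted point back onto the manifold, and
# proju projects a vector onto the tangent space at that point.
import numpy as np

def projx(x):
  return x / np.linalg.norm(x)      # re-normalize onto the unit sphere

def proju(x, u):
  return u - np.dot(x, u) * x       # drop the component normal to the sphere

x = projx(np.array([3.0, 4.0]))     # -> [0.6, 0.8]
u = np.array([1.0, 0.0])
print(proju(x, u))                  # -> [0.64, -0.48], tangent at x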
def testConsistent(self):
  nums, divs = self.intTestData()
  with self.test_session():
    tf_result = (math_ops.floor_div(nums, divs) * divs +
                 math_ops.floor_mod(nums, divs)).eval()
    tf_nums = array_ops.constant(nums)
    tf_divs = array_ops.constant(divs)
    tf2_result = (tf_nums // tf_divs * tf_divs + tf_nums % tf_divs).eval()
    np_result = (nums // divs) * divs + (nums % divs)
    # Consistency with numpy.
    self.assertAllEqual(tf_result, np_result)
    # Consistency between the two forms of divide.
    self.assertAllEqual(tf_result, tf2_result)
    # Consistency for the truncation form.
    tf3_result = (math_ops.truncatediv(nums, divs) * divs +
                  math_ops.truncatemod(nums, divs)).eval()
    expanded_nums = np.reshape(
        np.tile(nums, divs.shape[1]), (nums.shape[0], divs.shape[1]))
    # Both identities should reconstruct the numerator exactly.
    self.assertAllEqual(tf3_result, expanded_nums)
    self.assertAllEqual(tf_result, expanded_nums)
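# Quick NumPy illustration of why the test above checks both identities:
# floor division rounds toward -inf while truncation rounds toward zero, so
# the quotient/remainder pairs differ for negative numerators even though
# both reconstruct the numerator exactly.
import numpy as np

print(np.floor_divide(-7, 3), np.mod(-7, 3))   # -3 2  -> -3 * 3 + 2  == -7
print(int(-7 / 3), np.fmod(-7, 3))             # -2 -1 -> -2 * 3 + -1 == -7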
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
  slow_weight = self.get_slot(var, 'slow')
  alpha = math_ops.cast(self._alpha_t, var.dtype.base_dtype)
  step, beta1_power, beta2_power = self._get_beta_accumulators()
  beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
  beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
  lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
  beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
  beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
  epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
  # Maximum and current length of the approximated SMA (RAdam).
  sma_inf = 2.0 / (1.0 - beta2_t) - 1.0
  sma_t = sma_inf - 2.0 * step * beta2_power / (1.0 - beta2_power)
  # m_t = beta1 * m + (1 - beta1) * g_t
  m = self.get_slot(var, "m")
  m_scaled_g_values = grad * (1 - beta1_t)
  m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
  with ops.control_dependencies([m_t]):
    m_t = scatter_add(m, indices, m_scaled_g_values)
  mhat_t = m_t / (1.0 - beta1_power)
  # v_t = beta2 * v + (1 - beta2) * g_t * g_t
  v = self.get_slot(var, "v")
  v_scaled_g_values = (grad * grad) * (1 - beta2_t)
  v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
  with ops.control_dependencies([v_t]):
    v_t = scatter_add(v, indices, v_scaled_g_values)
  vhat_t = math_ops.sqrt(v_t / (1.0 - beta2_power) + epsilon_t)
  # Variance rectification term.
  r_t = math_ops.sqrt(((sma_t - 4.0) * (sma_t - 2.0) * sma_inf) /
                      ((sma_inf - 4.0) * (sma_inf - 2.0) * sma_t))
  # Use the rectified adaptive step only once the SMA is long enough.
  var_t = control_flow_ops.cond(sma_t >= self._sma_thtrshhold,
                                lambda: r_t * mhat_t / vhat_t,
                                lambda: mhat_t)
  if self._weight_decay > 0.0:
    var_t += math_ops.cast(self._weight_decay_t, var.dtype.base_dtype) * var
  var_update = state_ops.assign_sub(var, lr_t * var_t,
                                    use_locking=self._use_locking)
  # Every k steps, move the slow weights toward the fast weights by alpha
  # and reset the fast weights onto them (Lookahead-style synchronization).
  var_update = control_flow_ops.cond(
      math_ops.equal(math_ops.floor_mod(step, self._k_t), 0),
      lambda: state_ops.assign(
          var_update,
          state_ops.assign_add(slow_weight,
                               (var_update - slow_weight) * alpha)),
      lambda: var_update)
  updates = [var_update, m_t, v_t]
  return control_flow_ops.group(*updates)
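# Standalone sketch of the slow-weight update gated by floor_mod(step, k) == 0
# above, assuming the inner update is any fast optimizer step: every k steps
# the slow weights move a fraction alpha toward the fast weights, and the
# fast weights are reset onto them.
import numpy as np

alpha, k = 0.5, 5
fast, slow = np.array([1.0]), np.array([1.0])
for step in range(1, 16):
  fast = fast - 0.1                  # stand-in for the RAdam-style step
  if step % k == 0:
    slow = slow + alpha * (fast - slow)
    fast = slow.copy()
print(fast, slow)                    # both end on the synchronized value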
def _resource_apply_dense(self, grad, var):
  slow_weight = self.get_slot(var, 'slow')
  alpha = math_ops.cast(self._alpha_t, var.dtype.base_dtype)
  step, beta1_power, beta2_power = self._get_beta_accumulators()
  beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
  beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
  lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
  beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
  beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
  epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
  # Maximum and current length of the approximated SMA (RAdam).
  sma_inf = 2.0 / (1.0 - beta2_t) - 1.0
  sma_t = sma_inf - 2.0 * step * beta2_power / (1.0 - beta2_power)
  # m_t = beta1 * m + (1 - beta1) * g_t
  m = self.get_slot(var, "m")
  m_t = state_ops.assign(m, beta1_t * m + (1.0 - beta1_t) * grad,
                         use_locking=self._use_locking)
  mhat_t = m_t / (1.0 - beta1_power)
  # v_t = beta2 * v + (1 - beta2) * g_t * g_t
  v = self.get_slot(var, "v")
  v_t = state_ops.assign(v,
                         beta2_t * v + (1.0 - beta2_t) * math_ops.square(grad),
                         use_locking=self._use_locking)
  # Note: parenthesization fixed to match the sparse path, i.e. epsilon is
  # added to the bias-corrected second moment, not to (1 - beta2_power).
  vhat_t = math_ops.sqrt(v_t / (1.0 - beta2_power) + epsilon_t)
  # Variance rectification term.
  r_t = math_ops.sqrt(((sma_t - 4.0) * (sma_t - 2.0) * sma_inf) /
                      ((sma_inf - 4.0) * (sma_inf - 2.0) * sma_t))
  # Use the rectified adaptive step only once the SMA is long enough.
  var_t = control_flow_ops.cond(sma_t >= self._sma_thtrshhold,
                                lambda: r_t * mhat_t / vhat_t,
                                lambda: mhat_t)
  if self._weight_decay > 0.0:
    var_t += math_ops.cast(self._weight_decay_t, var.dtype.base_dtype) * var
  var_update = state_ops.assign_sub(var, lr_t * var_t,
                                    use_locking=self._use_locking)
  # Every k steps, move the slow weights toward the fast weights by alpha
  # and reset the fast weights onto them (Lookahead-style synchronization).
  var_update = control_flow_ops.cond(
      math_ops.equal(math_ops.floor_mod(step, self._k_t), 0),
      lambda: state_ops.assign(
          var_update,
          state_ops.assign_add(slow_weight,
                               (var_update - slow_weight) * alpha)),
      lambda: var_update)
  updates = [var_update, m_t, v_t]
  return control_flow_ops.group(*updates)
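# Standalone sketch of the variance-rectification term computed in both
# methods above (the RAdam formulation): r_t starts near zero and climbs
# toward 1, damping the adaptive step while the second-moment estimate is
# still unreliable. Printed values are approximate.
import numpy as np

def rectifier(step, beta2=0.999):
  sma_inf = 2.0 / (1.0 - beta2) - 1.0
  beta2_power = beta2 ** step
  sma_t = sma_inf - 2.0 * step * beta2_power / (1.0 - beta2_power)
  return np.sqrt(((sma_t - 4.0) * (sma_t - 2.0) * sma_inf) /
                 ((sma_inf - 4.0) * (sma_inf - 2.0) * sma_t))

for step in (10, 100, 1000, 10000):
  print(step, rectifier(step))       # ~0.05, ~0.22, ~0.65, ~1.0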
def testConsistent(self):
  nums, divs = self.intTestData()
  with self.test_session():
    tf_result = (math_ops.floor_div(nums, divs) * divs +
                 math_ops.floor_mod(nums, divs)).eval()
    tf_nums = array_ops.constant(nums)
    tf_divs = array_ops.constant(divs)
    tf2_result = (tf_nums // tf_divs * tf_divs + tf_nums % tf_divs).eval()
    np_result = (nums // divs) * divs + (nums % divs)
    self.assertAllEqual(tf_result, np_result)
    self.assertAllEqual(tf_result, tf2_result)
def testFloorModFloat(self):
  nums, divs = self.floatTestData()
  with self.test_session():
    tf_result = math_ops.floor_mod(nums, divs).eval()
    np_result = nums % divs
    self.assertAllEqual(tf_result, np_result)
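# Quick NumPy check of the float floor-mod semantics exercised above: the
# result carries the sign of the divisor, matching Python's % operator.
import numpy as np

print(np.mod(-7.5, 3.0))    # 1.5
print(np.mod(7.5, -3.0))    # -1.5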
def _stabilize(self, var):
  if math_ops.floor_mod(self.iterations, self.stabilize) == 0:
    manifold = get_manifold(var)
    m = self.get_slot(var, "m")
    var.assign(manifold.projx(var))
    m.assign(manifold.proju(var, m))