예제 #1
0
 def _split_indexed_slices_v2(sp_input=None,
                              num_split=None,
                              dim_size=0,
                              name=None):
     """Split an IndexedSlices `sp_input` into `num_split` pieces.

     The partitioned dimension of size `dim_size` is divided as evenly as
     possible: when `dim_size % num_split != 0`, the first
     `dim_size % num_split` partitions each hold one extra id. The indices
     of every returned piece are rebased so they are local to that piece.

     Args:
         sp_input: IndexedSlices to split along its first dimension.
         num_split: number of pieces to produce.
         dim_size: total size of the partitioned dimension.
         name: optional name scope for the created ops.

     Returns:
         A list of `num_split` IndexedSlices with partition-local indices.
     """
     ids_per_partition = dim_size // num_split
     extras = dim_size % num_split
     with ops.name_scope(name):
         # When the partitioned dim cannot be divided by num_split, the
         # remainders are evenly assigned from the first partition to the last.
         p_assignments = math_ops.maximum(
             sp_input.indices // (ids_per_partition + 1),
             (sp_input.indices - extras) // ids_per_partition)
         split_grads = []
         for i in range(0, num_split):
             with ops.name_scope(f"part_{i}"):
                 ids_not_in_i = array_ops.where(
                     math_ops.not_equal(p_assignments, i))
                 flat_ids_not_in_i = array_ops.reshape(ids_not_in_i, [-1])
                 # array_ops.where always yields int64; cast once to the dtype
                 # of sp_input.indices instead of branching on int64 vs int32.
                 flat_ids_not_in_i = math_ops.cast(flat_ids_not_in_i,
                                                   sp_input.indices.dtype)
                 s = array_ops.sparse_mask(sp_input, flat_ids_not_in_i)
                 # Rebase the surviving global ids to partition-local offsets;
                 # the first `extras` partitions are one id larger.
                 if i < extras:
                     s._indices = math_ops.floor_mod(
                         s.indices, ids_per_partition + 1)
                 else:
                     s._indices = math_ops.floor_mod(
                         s.indices - extras, ids_per_partition)
             split_grads.append(s)
     return split_grads
 def _stabilize(self, var):
     """Periodically re-project `var` (and its momentum slot, when momentum
     is enabled) via the variable's manifold."""
     # Only act every `self.stabilize` iterations.
     if math_ops.floor_mod(self.iterations, self.stabilize) != 0:
         return
     manifold = get_manifold(var)
     var.assign(manifold.projx(var))
     if not self._momentum:
         return
     mom = self.get_slot(var, "momentum")
     mom.assign(manifold.proju(var, mom))
 def testConsistent(self):
   """Floor-division identity: (n // d) * d + (n % d) == n.

   Verifies that math_ops.floor_div/floor_mod agree with numpy and with
   the overloaded // and % operators, then checks that both the truncation
   and floor forms reconstruct the (broadcast) numerator.
   """
   nums, divs = self.intTestData()
   with self.test_session():
     tf_result = (
         math_ops.floor_div(nums, divs) * divs + math_ops.floor_mod(nums, divs)
     ).eval()
     tf_nums = array_ops.constant(nums)
     tf_divs = array_ops.constant(divs)
     tf2_result = (tf_nums // tf_divs * tf_divs + tf_nums % tf_divs).eval()
     np_result = (nums // divs) * divs + (nums % divs)
     # consistency with numpy
     self.assertAllEqual(tf_result, np_result)
     # consistency with two forms of divide
     self.assertAllEqual(tf_result, tf2_result)
     # consistency for truncation form
     tf3_result = (
         math_ops.truncatediv(nums, divs) * divs
         + math_ops.truncatemod(nums, divs)
     ).eval()
     # nums tiled to the broadcast shape of (nums, divs); the identity
     # should recover it exactly.
     expanded_nums = np.reshape(np.tile(nums, divs.shape[1]),
                                (nums.shape[0], divs.shape[1]))
     # Truncation form reconstructs the numerator
     self.assertAllEqual(tf3_result, expanded_nums)
     # Floor form reconstructs the numerator as well
     self.assertAllEqual(tf_result, expanded_nums)
예제 #4
0
    def _apply_sparse_shared(self, grad, var, indices, scatter_add):
        """Sparse rectified-Adam update plus a periodic Lookahead step.

        Args:
            grad: gradient values for the rows named by `indices`.
            var: variable being optimized.
            indices: indices of `var` rows that `grad` applies to.
            scatter_add: callable `(ref, indices, updates)` performing a
                sparse add; lets the sparse/resource variants share this code.

        Returns:
            A grouped op applying the variable and m/v slot updates.
        """
        slow_weight = self.get_slot(var, 'slow')
        alpha = math_ops.cast(self._alpha_t, var.dtype.base_dtype)

        step, beta1_power, beta2_power = self._get_beta_accumulators()
        beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
        beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)

        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)

        # Maximum length of the approximated SMA and its value at the
        # current step (variance-rectification terms, cf. RAdam).
        sma_inf = 2.0 / (1.0 - beta2_t) - 1.0
        sma_t = sma_inf - 2.0 * step * beta2_power / (1.0 - beta2_power)

        # First moment: decay in place, then sparse-add the scaled gradient.
        m = self.get_slot(var, "m")
        m_scaled_g_values = grad * (1 - beta1_t)
        m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)

        # The decay assign must finish before the scatter into the same slot.
        with ops.control_dependencies([m_t]):
            m_t = scatter_add(m, indices, m_scaled_g_values)

        # Bias-corrected first moment.
        mhat_t = m_t / (1.0 - beta1_power)

        # Second moment, updated the same two-step way.
        v = self.get_slot(var, "v")
        v_scaled_g_values = (grad * grad) * (1 - beta2_t)
        v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)

        with ops.control_dependencies([v_t]):
            v_t = scatter_add(v, indices, v_scaled_g_values)

        # Bias-corrected, epsilon-stabilized second-moment root.
        vhat_t = math_ops.sqrt(v_t / (1.0 - beta2_power) + epsilon_t)

        # Variance rectification term r_t.
        r_t = math_ops.sqrt(((sma_t - 4.0) * (sma_t - 2.0) * sma_inf) /
                            ((sma_inf - 4.0) * (sma_inf - 2.0) * sma_t))

        # Rectified adaptive step once sma_t passes the threshold; plain
        # bias-corrected momentum step before that.
        # NOTE(review): `_sma_thtrshhold` is spelled this way where declared.
        var_t = tf.cond(sma_t >= self._sma_thtrshhold,
                        lambda: r_t * mhat_t / vhat_t, lambda: mhat_t)

        # Decoupled weight decay, if enabled.
        if self._weight_decay > 0.0:
            var_t += math_ops.cast(self._weight_decay_t,
                                   var.dtype.base_dtype) * var

        var_update = state_ops.assign_sub(var,
                                          lr_t * var_t,
                                          use_locking=self._use_locking)

        # Lookahead: every `self._k_t` steps pull the slow weights toward
        # the fast weights by `alpha` and copy the result back into `var`.
        var_update = tf.cond(
            math_ops.equal(math_ops.floor_mod(step, self._k_t), 0),
            lambda: state_ops.assign(
                var_update,
                state_ops.assign_add(slow_weight,
                                     (var_update - slow_weight) * alpha)),
            lambda: var_update)

        updates = [var_update, m_t, v_t]

        return control_flow_ops.group(*updates)
예제 #5
0
 def testConsistent(self):
   """Floor-division identity: (n // d) * d + (n % d) == n.

   Verifies that math_ops.floor_div/floor_mod agree with numpy and with
   the overloaded // and % operators, then checks that both the truncation
   and floor forms reconstruct the (broadcast) numerator.
   """
   nums, divs = self.intTestData()
   with self.test_session():
     tf_result = (
         math_ops.floor_div(nums, divs) * divs + math_ops.floor_mod(nums, divs)
     ).eval()
     tf_nums = array_ops.constant(nums)
     tf_divs = array_ops.constant(divs)
     tf2_result = (tf_nums // tf_divs * tf_divs + tf_nums % tf_divs).eval()
     np_result = (nums // divs) * divs + (nums % divs)
     # consistency with numpy
     self.assertAllEqual(tf_result, np_result)
     # consistency with two forms of divide
     self.assertAllEqual(tf_result, tf2_result)
     # consistency for truncation form
     tf3_result = (
         math_ops.truncatediv(nums, divs) * divs
         + math_ops.truncatemod(nums, divs)
     ).eval()
     # nums tiled to the broadcast shape of (nums, divs); the identity
     # should recover it exactly.
     expanded_nums = np.reshape(np.tile(nums, divs.shape[1]),
                                (nums.shape[0], divs.shape[1]))
     # Truncation form reconstructs the numerator
     self.assertAllEqual(tf3_result, expanded_nums)
     # Floor form reconstructs the numerator as well
     self.assertAllEqual(tf_result, expanded_nums)
예제 #6
0
    def _resource_apply_dense(self, grad, var):
        """Dense rectified-Adam update plus a periodic Lookahead step.

        Args:
            grad: dense gradient for `var`.
            var: resource variable being optimized.

        Returns:
            A grouped op applying the variable and m/v slot updates.
        """
        slow_weight = self.get_slot(var, 'slow')
        alpha = math_ops.cast(self._alpha_t, var.dtype.base_dtype)

        step, beta1_power, beta2_power = self._get_beta_accumulators()
        beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
        beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)

        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)

        # Maximum length of the approximated SMA and its value at the
        # current step (variance-rectification terms, cf. RAdam).
        sma_inf = 2.0 / (1.0 - beta2_t) - 1.0

        sma_t = sma_inf - 2.0 * step * beta2_power / (1.0 - beta2_power)

        # First moment and its bias correction.
        m = self.get_slot(var, "m")
        m_t = state_ops.assign(m,
                               beta1_t * m + (1.0 - beta1_t) * grad,
                               use_locking=self._use_locking)
        mhat_t = m_t / (1.0 - beta1_power)

        # Second moment.
        v = self.get_slot(var, "v")
        v_t = state_ops.assign(v,
                               beta2_t * v +
                               (1.0 - beta2_t) * math_ops.square(grad),
                               use_locking=self._use_locking)
        # BUGFIX: epsilon belongs outside the bias-correction denominator,
        # i.e. sqrt(v_t / (1 - beta2^t) + eps), matching _apply_sparse_shared.
        # The previous form sqrt(v_t / ((1 - beta2^t) + eps)) let vhat_t be
        # exactly zero when v_t is zero, so the rectified branch below could
        # divide by zero.
        vhat_t = math_ops.sqrt(v_t / (1.0 - beta2_power) + epsilon_t)

        # Variance rectification term r_t.
        r_t = math_ops.sqrt(((sma_t - 4.0) * (sma_t - 2.0) * sma_inf) /
                            ((sma_inf - 4.0) * (sma_inf - 2.0) * sma_t))

        # Rectified adaptive step once sma_t passes the threshold; plain
        # bias-corrected momentum step before that.
        var_t = tf.cond(sma_t >= self._sma_thtrshhold,
                        lambda: r_t * mhat_t / vhat_t, lambda: mhat_t)

        # Decoupled weight decay, if enabled.
        if self._weight_decay > 0.0:
            var_t += math_ops.cast(self._weight_decay_t,
                                   var.dtype.base_dtype) * var

        var_update = state_ops.assign_sub(var,
                                          lr_t * var_t,
                                          use_locking=self._use_locking)

        # Lookahead: every `self._k_t` steps pull the slow weights toward
        # the fast weights by `alpha` and copy the result back into `var`.
        var_update = tf.cond(
            math_ops.equal(math_ops.floor_mod(step, self._k_t), 0),
            lambda: state_ops.assign(
                var_update,
                state_ops.assign_add(slow_weight,
                                     (var_update - slow_weight) * alpha)),
            lambda: var_update)

        updates = [var_update, m_t, v_t]

        return control_flow_ops.group(*updates)
예제 #7
0
 def testConsistent(self):
   """(n // d) * d + (n % d) reconstructs n the same way in TF and numpy."""
   numerators, denominators = self.intTestData()
   with self.test_session():
     # Op-level form of the division identity.
     quotient = math_ops.floor_div(numerators, denominators)
     remainder = math_ops.floor_mod(numerators, denominators)
     via_ops = (quotient * denominators + remainder).eval()
     # Same identity through the overloaded // and % operators.
     n_t = array_ops.constant(numerators)
     d_t = array_ops.constant(denominators)
     via_operators = (n_t // d_t * d_t + n_t % d_t).eval()
     # numpy reference.
     expected = (numerators // denominators) * denominators + (
         numerators % denominators)
     self.assertAllEqual(via_ops, expected)
     self.assertAllEqual(via_ops, via_operators)
예제 #8
0
 def testConsistent(self):
   """Checks (n // d) * d + (n % d) == n across TF ops, overloaded
   operators, and numpy."""
   nums, divs = self.intTestData()
   with self.test_session():
     tf_result = (
         math_ops.floor_div(nums, divs) * divs + math_ops.floor_mod(nums, divs)
     ).eval()
     tf_nums = array_ops.constant(nums)
     tf_divs = array_ops.constant(divs)
     tf2_result = (tf_nums // tf_divs * tf_divs + tf_nums % tf_divs).eval()
     np_result = (nums // divs) * divs + (nums % divs)
     # TF op result must match the numpy reference ...
     self.assertAllEqual(tf_result, np_result)
     # ... and the operator-overload form.
     self.assertAllEqual(tf_result, tf2_result)
예제 #9
0
 def testFloorModFloat(self):
   """floor_mod on floats matches numpy's % (floored modulo)."""
   numerators, denominators = self.floatTestData()
   with self.test_session():
     expected = numerators % denominators
     actual = math_ops.floor_mod(numerators, denominators).eval()
     self.assertAllEqual(actual, expected)
예제 #10
0
 def testFloorModFloat(self):
   """Checks math_ops.floor_mod against numpy's % for float inputs."""
   nums, divs = self.floatTestData()
   with self.test_session():
     tf_result = math_ops.floor_mod(nums, divs).eval()
     # numpy's % implements floored modulo, the reference behavior.
     np_result = nums % divs
     self.assertAllEqual(tf_result, np_result)
예제 #11
0
 def _stabilize(self, var):
     """Every `self.stabilize` iterations, re-project `var` and its "m"
     slot using the variable's manifold (projx / proju)."""
     if math_ops.floor_mod(self.iterations, self.stabilize) != 0:
         return
     manifold = get_manifold(var)
     momentum = self.get_slot(var, "m")
     var.assign(manifold.projx(var))
     momentum.assign(manifold.proju(var, momentum))