예제 #1
0
 def new_update(x, new_x):
     """Stand-in for ``K.update`` that applies layer-wise trust-ratio scaling.

     For a variable ``x`` that belongs to ``params`` and that
     ``self._do_layer_adaptation`` accepts, the proposed step ``new_x - x`` is
     multiplied by ``||x|| / ||(new_x - x) / lr||`` before being written via
     ``old_update``. The factor falls back to 1 when ``||x||`` is not positive
     or the lr-normalised step norm is not above ``K.epsilon()``. Every other
     variable is forwarded to ``old_update`` unchanged.

     ``params``, ``self`` and ``old_update`` are free variables taken from the
     enclosing scope.
     """
     if not (is_one_of(x, params) and self._do_layer_adaptation(x)):
         # Not subject to layer adaptation: pass the write straight through.
         return old_update(x, new_x)
     step = new_x - x
     # Keep the learning rate away from zero so the division below is safe.
     safe_lr = K.clip(self.learning_rate, K.epsilon(), 1e10)
     weight_norm = tf.norm(x)
     update_norm = tf.norm(step / safe_lr)
     trust = K.switch(
         weight_norm > 0.,
         K.switch(update_norm > K.epsilon(), weight_norm / update_norm, 1.),
         1.)
     return old_update(x, x + step * trust)
예제 #2
0
 def new_update(x, new_x):
     """``K.update`` replacement that rescales the step for ``var`` only.

     When ``x`` is the very object ``var`` and ``self._do_layer_adaptation(x)``
     is true, the step ``new_x - x`` is scaled by
     ``||x|| / ||(new_x - x) / lr||``, where ``lr`` is
     ``self._decayed_lr(...)`` clipped to ``[K.epsilon(), 1e10]``. The factor
     is 1 when ``||x||`` is not positive or the normalised step norm is not
     above ``K.epsilon()``. The result is written with ``old_update``.

     ``var``, ``self`` and ``old_update`` come from the enclosing scope.
     """
     if x is not var or not self._do_layer_adaptation(x):
         # Some other variable, or adaptation disabled: forward untouched.
         return old_update(x, new_x)
     step = new_x - x
     # Decayed learning rate, clipped so dividing by it cannot blow up.
     safe_lr = K.clip(self._decayed_lr(x.dtype.base_dtype), K.epsilon(), 1e10)
     weight_norm = tf.norm(x)
     update_norm = tf.norm(step / safe_lr)
     trust = K.switch(
         weight_norm > 0.,
         K.switch(update_norm > K.epsilon(), weight_norm / update_norm, 1.),
         1.)
     return old_update(x, x + step * trust)
예제 #3
0
 def compute_position_ids(self, inputs):
     """Map query/key relative positions to T5 relative-position buckets.

     Translated directly from the official T5 source.

     Args:
         inputs: a pair ``(q, v)``; only the size of axis 1 of each is read
             (presumably the sequence length -- confirm with callers).

     Returns:
         An int32 tensor of shape ``(len_q, len_v)`` whose ``[i, j]`` entry is
         the bucket id of the relative position ``j - i``; every id is below
         ``self.input_dim``.
     """
     q, v = inputs
     # A column of query indices minus a row of key indices broadcasts to a
     # (len_q, len_v) grid holding i - j, i.e. the negated offset j - i.
     n = (K.expand_dims(K.arange(0, K.shape(q)[1], dtype='int32'), 1) -
          K.expand_dims(K.arange(0, K.shape(v)[1], dtype='int32'), 0))
     buckets, max_distance = self.input_dim, self.max_distance
     if self.bidirectional:
         # Split the buckets evenly between the two directions; negative n
         # (key after query) is shifted into the upper half.
         buckets //= 2
         offset = K.cast(K.less(n, 0), 'int32') * buckets
         n = K.abs(n)
     else:
         # One-directional: offsets with j > i are clamped onto n == 0.
         offset = 0
         n = K.maximum(n, 0)
     # From here on n >= 0.
     exact = buckets // 2
     # Offsets below `exact` get one bucket each; larger ones are binned
     # logarithmically between `exact` and `max_distance` and clamped to the
     # last bucket. (n == 0 gives log(0) = -inf in this branch, but the switch
     # below discards it whenever exact > 0.)
     scaled_log = (K.log(K.cast(n, K.floatx()) / exact) /
                   np.log(max_distance / exact))
     large = exact + K.cast(scaled_log * (buckets - exact), 'int32')
     large = K.minimum(large, buckets - 1)
     return offset + K.switch(K.less(n, exact), n, large)
예제 #4
0
        def _resource_apply_op(self, grad, var, indices=None):
            """Apply the parent optimizer's update, then a Lookahead-style sync.

            Every ``steps_per_slow_update`` iterations, the slow copy of ``var``
            (slot ``'slow_var'``) moves ``slow_step_size`` of the way toward the
            freshly updated ``var``, and ``var`` is then overwritten with it.
            On other iterations both variables are re-assigned their own value.

            Returns:
                The op that assigns the (possibly synced) value to ``var``.
            """
            op = super(new_optimizer,
                       self)._resource_apply_op(grad, var, indices)

            k, alpha = self.steps_per_slow_update, self.slow_step_size
            cond = K.equal(self.iterations % k, 0)
            slow_var = self.get_slot(var, 'slow_var')

            with tf.control_dependencies([op]):
                # Build the interpolation inside the dependency scope so `var`
                # is read after the fast update. Built outside it, graph mode
                # may read the stale pre-update weights.
                slow_var_t = slow_var + alpha * (var - slow_var)
                slow_update = K.update(slow_var,
                                       K.switch(cond, slow_var_t, slow_var))
                with tf.control_dependencies([slow_update]):
                    # Copy the new slow weights back into the fast weights.
                    copy_update = K.update(var, K.switch(cond, slow_var, var))

            return copy_update
예제 #5
0
        def _resource_apply_op(self, grad, var, indices=None):
            """Accumulate ``grad`` and apply it every ``grad_accum_steps`` steps.

            On iterations where ``self.iterations % self.grad_accum_steps == 0``,
            the parent optimizer is applied to ``var`` with the accumulated
            gradient (slot ``'ag'``) divided by ``grad_accum_steps``. On other
            iterations, writes the parent makes through ``K.update`` become
            self-assignments. Afterwards the slot is cleared (on applying
            iterations only) and ``grad`` is added to it: densely, or via
            ``indices`` for sparse gradients.

            Returns:
                The op that adds ``grad`` into the ``'ag'`` slot.
            """
            # Update criterion: apply only on every grad_accum_steps-th step.
            cond = K.equal(self.iterations % self.grad_accum_steps, 0)
            # Accumulated gradient for this variable.
            ag = self.get_slot(var, 'ag')

            old_update = K.update

            def new_update(x, new_x):
                # Leave x unchanged on iterations where cond is false.
                new_x = K.switch(cond, new_x, x)
                return old_update(x, new_x)

            # Temporarily route K.update through the gate (presumably the
            # parent's _resource_apply_op assigns via K.update). The finally
            # clause restores the global even if the parent raises; without
            # it, every later K.update call would stay gated on this `cond`.
            K.update = new_update
            try:
                ag_t = ag / self.grad_accum_steps
                op = super(new_optimizer, self)._resource_apply_op(ag_t, var)
            finally:
                K.update = old_update

            # Accumulate the gradient once the (possibly gated) update has run.
            with tf.control_dependencies([op]):
                # Clear the accumulator on iterations where it was just applied.
                ag_t = K.switch(cond, K.zeros_like(ag), ag)
                with tf.control_dependencies([K.update(ag, ag_t)]):
                    if indices is None:
                        ag_t = K.update(ag, ag + grad)
                    else:
                        ag_t = self._resource_scatter_add(ag, indices, grad)

            return ag_t
예제 #6
0
        def get_updates(self, loss, params):
            """Return the parent's updates followed by a Lookahead-style sync.

            Every ``steps_per_slow_update`` iterations, each slow copy ``q`` of
            a parameter ``p`` is set to ``q + slow_step_size * (p - q)``. This
            happens after the parent's updates, and ``p`` is then overwritten
            with ``q``. On other iterations both are re-assigned their current
            values.

            Returns:
                The list of ops assigning the (possibly synced) values to
                ``params``; they depend on the parent's updates through control
                dependencies.
            """
            updates = super(new_optimizer, self).get_updates(loss, params)

            k, alpha = self.steps_per_slow_update, self.slow_step_size
            cond = K.equal(self.iterations % k, 0)
            # Slow weights must start at the current fast weights. A zero init
            # makes the first sync set every parameter to alpha * p, silently
            # shrinking the model.
            # NOTE(review): values are snapshotted when this graph is built;
            # weights loaded after that are not reflected in the slow copies.
            slow_vars = [
                K.variable(K.get_value(p),
                           dtype=K.dtype(p),
                           name='slow_var_%s' % i) for i, p in enumerate(params)
            ]

            with tf.control_dependencies(updates):
                # Move the slow weights toward the freshly updated params.
                slow_updates = [
                    K.update(q, K.switch(cond, q + alpha * (p - q), q))
                    for p, q in zip(params, slow_vars)
                ]
                with tf.control_dependencies(slow_updates):
                    # Copy the slow weights back into the params.
                    copy_updates = [
                        K.update(p, K.switch(cond, q, p))
                        for p, q in zip(params, slow_vars)
                    ]

            return copy_updates
예제 #7
0
 def new_update(x, new_x):
     """Gated ``K.update`` replacement: write ``new_x`` only while ``cond``
     holds, otherwise write ``x`` back unchanged (both via ``old_update``)."""
     return old_update(x, K.switch(cond, new_x, x))