def get_updates(self, loss, params):
    """Build the Adafactor-style update ops for ``params``.

    For each parameter, a moving average of the squared gradient is kept:
    either one full-size accumulator ``v`` (when ``self.factored_shape``
    returns ``None``) or two reduced accumulators ``vr``/``vc``. ``vr`` holds
    the mean of g**2 over ``axis1`` and ``vc`` holds it over ``axis2``; the
    two are recombined into a full-size estimate. The normalized update
    ``g / sqrt(v)`` is then handled as follows:

    - optionally clipped so that its RMS does not exceed
      ``self.clipping_threshold``;
    - optionally smoothed with momentum (``self.beta1 > 0``);
    - optionally scaled by ``max(RMS(p), self.epsilon2)``;
    - finally applied as ``p - lr * u``.

    Args:
        loss: loss tensor passed to ``self.get_gradients``.
        params: variables to update (zipped with their gradients).

    Returns:
        The list of update ops. The same list is also stored on
        ``self.updates``. All state variables created here are recorded in
        ``self.weights``, with ``self.iterations`` first.
    """
    grads = self.get_gradients(loss, params)
    self.updates = [K.update_add(self.iterations, 1)]
    self.weights = [self.iterations]
    lr = self.learning_rate

    for i, (p, g) in enumerate(zip(params, grads)):
        # epsilon1 keeps the second-moment statistics strictly positive, so
        # the later division by sqrt(v_t) is safe.
        g2 = K.square(g) + self.epsilon1
        shape, dtype = K.int_shape(p), K.dtype(p)
        factored_shape = self.factored_shape(shape)
        if factored_shape is None:
            # Unfactored: full-size second-moment accumulator.
            v = K.zeros(shape, dtype=dtype, name='v_' + str(i))
            self.weights.append(v)
            v_t = self.beta2 * v + (1.0 - self.beta2) * g2
            self.updates.append(K.update(v, v_t))
        else:
            # Factored: accumulators reduced along axis1 / axis2.
            shape1, axis1, shape2, axis2 = factored_shape
            vr = K.zeros(shape1, dtype=dtype, name='vr_' + str(i))
            vc = K.zeros(shape2, dtype=dtype, name='vc_' + str(i))
            self.weights.extend([vr, vc])
            # Use the same (1 - beta2) weighting as the unfactored branch.
            # Without it the accumulators become discounted sums rather than
            # moving averages, which shrinks the effective step size.
            g2r = K.mean(g2, axis=axis1, keepdims=True)
            g2c = K.mean(g2, axis=axis2, keepdims=True)
            vr_t = self.beta2 * vr + (1.0 - self.beta2) * g2r
            vc_t = self.beta2 * vc + (1.0 - self.beta2) * g2c
            self.updates.extend([K.update(vr, vr_t), K.update(vc, vc_t)])
            # Recombine the two factors into a full-size estimate.
            v_t = vr_t * vc_t / K.mean(vr_t, axis=axis2, keepdims=True)
        # Normalized update.
        u = g / K.sqrt(v_t)
        # Update clipping: u / max(1, RMS(u) / d), where RMS = sqrt(mean(x**2)).
        if self.clipping_threshold is not None:
            u_rms = K.sqrt(K.mean(K.square(u)))
            u = u / K.maximum(1.0, u_rms / self.clipping_threshold)
        # Optional momentum smoothing of the update.
        if self.beta1 > 0.0:
            m = K.zeros(shape, dtype=dtype, name='m_' + str(i))
            self.weights.append(m)
            m_t = self.beta1 * m + (1.0 - self.beta1) * u
            self.updates.append(K.update(m, m_t))
            u = m_t
        # Relative step size: scale by the parameter's RMS, floored at epsilon2.
        if self.multiply_by_parameter_scale:
            p_rms = K.sqrt(K.mean(K.square(p)))
            u = u * K.maximum(p_rms, self.epsilon2)
        # Apply the parameter update.
        self.updates.append(K.update(p, p - lr * u))

    return self.updates
def compute_position_ids(self, inputs):
    """Map every (query position, key position) pair to a T5 relative-position bucket.

    This is a direct translation of the relative-position bucketing in the
    official T5 source code.

    Args:
        inputs: pair ``(q, v)``. Only ``K.shape(...)[1]`` of each is read.
            This is presumably the sequence axis; confirm against callers.

    Returns:
        int32 tensor of shape (len_q, len_v). Entry [i, j] is the bucket id
        for the relative position ``j - i``.
    """
    query, value = inputs

    # Grid of (query_pos - key_pos), i.e. the negated relative position.
    query_pos = K.expand_dims(K.arange(0, K.shape(query)[1], dtype='int32'), 1)
    value_pos = K.expand_dims(K.arange(0, K.shape(value)[1], dtype='int32'), 0)
    distance = query_pos - value_pos

    buckets = self.input_dim
    bucket_ids = 0
    if not self.bidirectional:
        # Unidirectional: keys after the query all collapse onto distance 0.
        distance = K.maximum(distance, 0)
    else:
        # Bidirectional: half the buckets per direction. Keys after the
        # query (negative distance) are shifted into the upper half.
        buckets //= 2
        bucket_ids += K.cast(K.less(distance, 0), 'int32') * buckets
        distance = K.abs(distance)

    # distance is now non-negative. Distances below exact_limit each get
    # their own bucket. Larger ones are binned logarithmically, and anything
    # past max_distance is clamped into the last bucket.
    exact_limit = buckets // 2
    log_ratio = K.log(K.cast(distance, K.floatx()) / exact_limit)
    log_scale = np.log(self.max_distance / exact_limit)
    log_bucket = exact_limit + K.cast(
        log_ratio / log_scale * (buckets - exact_limit),
        'int32',
    )
    log_bucket = K.minimum(log_bucket, buckets - 1)

    bucket_ids += K.switch(K.less(distance, exact_limit), distance, log_bucket)
    return bucket_ids
def _resource_apply(self, grad, var, indices=None):
    """Apply one Adafactor-style step to ``var`` using slot variables.

    Second-moment state is read from and written to the slot ``'v'``
    (unfactored case) or the slots ``'vr'``/``'vc'`` (factored case, when
    ``self.factored_shape`` returns axes). Momentum state lives in slot
    ``'m'`` when ``self.beta1 > 0``. The value returned by each
    ``K.update(...)`` is used as the post-update state in later expressions.

    Args:
        grad: gradient for ``var``. It is used element-wise as given and is
            presumably dense; confirm for sparse callers.
        var: variable to update.
        indices: not read by this body; kept for signature compatibility.

    Returns:
        The result of ``K.update(var, var - lr * u)``.
    """
    lr = self.learning_rate
    # epsilon1 keeps the second-moment statistics strictly positive.
    g2 = K.square(grad) + self.epsilon1
    shape = K.int_shape(var)
    factored_shape = self.factored_shape(shape)
    if factored_shape is None:
        # Unfactored: full-size second-moment accumulator.
        v = self.get_slot(var, 'v')
        v_t = self.beta2 * v + (1.0 - self.beta2) * g2
        v_t = K.update(v, v_t)
    else:
        # Factored: accumulators reduced along axis1 / axis2.
        _, axis1, _, axis2 = factored_shape
        vr = self.get_slot(var, 'vr')
        vc = self.get_slot(var, 'vc')
        # Use the same (1 - beta2) weighting as the unfactored branch.
        # Without it the accumulators become discounted sums rather than
        # moving averages.
        g2r = K.mean(g2, axis=axis1, keepdims=True)
        g2c = K.mean(g2, axis=axis2, keepdims=True)
        vr_t = self.beta2 * vr + (1.0 - self.beta2) * g2r
        vc_t = self.beta2 * vc + (1.0 - self.beta2) * g2c
        vr_t, vc_t = K.update(vr, vr_t), K.update(vc, vc_t)
        # Recombine the two factors into a full-size estimate.
        v_t = vr_t * vc_t / K.mean(vr_t, axis=axis2, keepdims=True)
    # Normalized update.
    u = grad / K.sqrt(v_t)
    # Update clipping: u / max(1, RMS(u) / d), where RMS = sqrt(mean(x**2)).
    if self.clipping_threshold is not None:
        u_rms = K.sqrt(K.mean(K.square(u)))
        u = u / K.maximum(1.0, u_rms / self.clipping_threshold)
    # Optional momentum smoothing of the update.
    if self.beta1 > 0.0:
        m = self.get_slot(var, 'm')
        m_t = self.beta1 * m + (1.0 - self.beta1) * u
        u = K.update(m, m_t)
    # Relative step size: scale by the variable's RMS, floored at epsilon2.
    if self.multiply_by_parameter_scale:
        var_rms = K.sqrt(K.mean(K.square(var)))
        u = u * K.maximum(var_rms, self.epsilon2)
    # Apply the parameter update.
    return K.update(var, var - lr * u)