def compute_position_ids(self, inputs):
    """Relative-position bucketing as in T5 (ported from the official T5 code).

    Args:
        inputs: a pair ``(q, v)`` of tensors; only the size of axis 1 of
            each is used (taken as the query / value sequence length).

    Returns:
        An int32 tensor of bucket ids with one row per query position and
        one column per value position.

    Reads ``self.input_dim`` (total number of buckets),
    ``self.max_distance`` and ``self.bidirectional``.
    """
    q, v = inputs

    # Pairwise position difference: a column of query indices against a
    # row of value indices broadcasts to a (q_len, v_len) grid.
    query_pos = K.expand_dims(K.arange(0, K.shape(q)[1], dtype='int32'), 1)
    value_pos = K.expand_dims(K.arange(0, K.shape(v)[1], dtype='int32'), 0)
    offsets = value_pos - query_pos

    # Map the raw differences to bucket ids.
    num_buckets = self.input_dim
    max_distance = self.max_distance
    distance = -offsets
    if self.bidirectional:
        # Half of the buckets serve each sign; negative distances are
        # shifted into the upper half.
        num_buckets //= 2
        sign_offset = K.cast(K.less(distance, 0), 'int32') * num_buckets
        distance = K.abs(distance)
    else:
        sign_offset = 0
        distance = K.maximum(distance, 0)
    # From here on `distance` is non-negative.

    # Distances below `max_exact` each get their own bucket; larger ones
    # share logarithmically spaced buckets, clipped to the last bucket.
    max_exact = num_buckets // 2
    below_exact = K.less(distance, max_exact)
    scaled = K.cast(distance, K.floatx()) / max_exact
    log_fraction = K.log(scaled) / np.log(max_distance / max_exact)
    log_bucket = max_exact + K.cast(
        log_fraction * (num_buckets - max_exact),
        'int32',
    )
    log_bucket = K.minimum(log_bucket, num_buckets - 1)

    return sign_offset + K.switch(below_exact, distance, log_bucket)
def learning_rate(self):
    """Return the learning rate to use for the current step.

    If ``self._learning_rate`` is ``None``, the rate is computed from the
    step counter as ``min(1 / sqrt(iterations + 1), 0.01)`` and multiplied
    by ``0.05`` unless ``self.multiply_by_parameter_scale`` is truthy.

    Otherwise the fixed ``self._learning_rate`` is wrapped in a backend
    variable named ``'learning_rate'`` (created under a name scope of the
    class name) the first time it is requested, and that same variable is
    returned on later calls.
    """
    if self._learning_rate is None:
        # +1 so that the first step does not divide by sqrt(0).
        iterations = K.cast(self.iterations + 1, K.floatx())
        learning_rate = K.minimum(1.0 / K.sqrt(iterations), 0.01)
        if self.multiply_by_parameter_scale:
            return learning_rate
        return learning_rate * 0.05

    # The cache lookup goes through the attribute itself instead of
    # hasattr(self, '__learning_rate'): when this function is defined
    # inside a class body, `self.__learning_rate` is name-mangled, so the
    # literal string never matches and a new variable would be created on
    # every access. Reading and writing via the same expression is correct
    # whether or not mangling applies.
    try:
        return self.__learning_rate
    except AttributeError:
        with K.name_scope(self.__class__.__name__):
            self.__learning_rate = K.variable(
                self._learning_rate, name='learning_rate'
            )
        return self.__learning_rate