def compute_position_ids(self, inputs):
    """T5-style relative position bucketing (translated directly from the official T5 source).
      i-j : 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 ...
    f(i-j): 0 1 2 3 4 5 6 7 8 8  8  8  9  9  9 ...
    """
    q, v = inputs
    # position differences between query positions and key/value positions
    q_idxs = K.arange(0, K.shape(q)[1], dtype='int32')
    q_idxs = K.expand_dims(q_idxs, 1)
    v_idxs = K.arange(0, K.shape(v)[1], dtype='int32')
    v_idxs = K.expand_dims(v_idxs, 0)
    pos_ids = v_idxs - q_idxs
    # bucket the position differences
    num_buckets, max_distance = self.input_dim, self.max_distance
    ret = 0
    n = -pos_ids
    if self.bidirectional:
        num_buckets //= 2
        ret += K.cast(K.less(n, 0), 'int32') * num_buckets
        n = K.abs(n)
    else:
        n = K.maximum(n, 0)
    # now n is in the range [0, inf)
    max_exact = num_buckets // 2
    is_small = K.less(n, max_exact)
    val_if_large = max_exact + K.cast(
        K.log(K.cast(n, K.floatx()) / max_exact) /
        np.log(max_distance / max_exact) * (num_buckets - max_exact),
        'int32',
    )
    val_if_large = K.minimum(val_if_large, num_buckets - 1)
    ret += K.switch(is_small, n, val_if_large)
    return ret
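# A minimal NumPy sketch of the same bucketing rule, handy for checking the table
# in the docstring. It is not part of the layer above; the function name and the
# defaults num_buckets=32, max_distance=128 are illustrative assumptions.
import numpy as np

def relative_position_bucket(relative_position, bidirectional=True,
                             num_buckets=32, max_distance=128):
    ret = 0
    n = -relative_position
    if bidirectional:
        num_buckets //= 2
        ret += (n < 0).astype('int32') * num_buckets
        n = np.abs(n)
    else:
        n = np.maximum(n, 0)
    max_exact = num_buckets // 2
    is_small = n < max_exact
    # np.maximum(n, 1) only guards the log on the branch discarded for small n
    val_if_large = max_exact + (
        np.log(np.maximum(n, 1) / max_exact) /
        np.log(max_distance / max_exact) * (num_buckets - max_exact)
    ).astype('int32')
    val_if_large = np.minimum(val_if_large, num_buckets - 1)
    return ret + np.where(is_small, n, val_if_large)

# relative_position_bucket(-np.arange(15))
# -> [0 1 2 3 4 5 6 7 8 8 8 8 9 9 9], i.e. the f(i-j) row in the docstring.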
def call(self, inputs):
    # PE_2i(p)   = sin(p / 10000^(2i/d_pos))
    # PE_2i+1(p) = cos(p / 10000^(2i/d_pos))
    batch_size, seq_len, word_emb_dim = (
        K.shape(inputs)[0], K.shape(inputs)[1], K.shape(inputs)[2])
    if not self.embedding_dim or self.method == 'add':
        self.embedding_dim = word_emb_dim
    t = 2 * K.arange(self.embedding_dim / 2, dtype='float32') / K.cast(
        self.embedding_dim, dtype='float32')
    embedding_wise_pos = 1. / K.pow(10000., t)  # 1/10000^(2i/d_pos), shape = (p_dim/2,)
    embedding_wise_pos = K.expand_dims(embedding_wise_pos, 0)  # (1, p_dim/2)
    word_wise_pos = K.cumsum(K.ones_like(inputs[:, :, 0]), axis=1)  # (batch_size, seq_len)
    word_wise_pos = K.expand_dims(word_wise_pos, 2)  # (batch_size, seq_len, 1)
    position_embedding = K.dot(word_wise_pos, embedding_wise_pos)  # (batch_size, seq_len, p_dim/2)
    position_embedding = K.expand_dims(position_embedding, 3)
    position_embedding = K.reshape(
        K.concatenate([K.sin(position_embedding), K.cos(position_embedding)], axis=-1),
        shape=(batch_size, seq_len, -1))
    if self.method == 'add':
        return inputs + position_embedding
    return K.concatenate([inputs, position_embedding], axis=-1)
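# A standalone NumPy sketch of the same encoding (illustrative names, not the
# layer's API): PE[p, 2i] = sin(p / 10000**(2i/d)), PE[p, 2i+1] = cos(...),
# with positions counted from 1 because the layer uses a cumsum of ones.
import numpy as np

def sinusoidal_position_embedding(seq_len, dim):
    pos = np.arange(1, seq_len + 1, dtype='float32')[:, None]   # (seq_len, 1)
    i = np.arange(dim // 2, dtype='float32')[None, :]            # (1, dim/2)
    angle = pos / np.power(10000.0, 2 * i / dim)                 # (seq_len, dim/2)
    # interleave sin/cos per frequency, mirroring the concatenate + reshape above
    return np.stack([np.sin(angle), np.cos(angle)], axis=-1).reshape(seq_len, dim)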
def log_norm_step(self, inputs, states):
    """One step of the recursion that computes the normalization factor (log partition)."""
    inputs, mask = inputs[:, :-1], inputs[:, -1:]
    states = K.expand_dims(states[0], 2)  # (batch_size, output_dim, 1)
    trans = K.expand_dims(self.trans, 0)  # (1, output_dim, output_dim)
    outputs = K.logsumexp(states + trans, 1)
    outputs += inputs
    outputs = mask * outputs + (1 - mask) * states[:, :, 0]
    return outputs, [outputs]
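# Hedged note: this is one step of the CRF forward recursion for the log partition
# function, normally iterated over the sequence with K.rnn. In NumPy terms
# (alpha, trans and emission are illustrative names):
#   alpha_new[j] = logsumexp_i(alpha[i] + trans[i, j]) + emission[j]
import numpy as np

def log_norm_step_np(alpha, trans, emission, mask=1.0):
    # alpha: (batch, tags), trans: (tags, tags), emission: (batch, tags)
    scores = alpha[:, :, None] + trans[None, :, :]             # (batch, i, j)
    alpha_new = np.log(np.exp(scores).sum(axis=1)) + emission  # naive logsumexp
    return mask * alpha_new + (1 - mask) * alpha               # keep alpha on padded steps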
def compute_position_idx(self, inputs):
    q, v = inputs
    q_idx = K.arange(0, K.shape(q)[1], dtype='int32')
    q_idx = K.expand_dims(q_idx, 1)
    v_idx = K.arange(0, K.shape(v)[1], dtype='int32')
    v_idx = K.expand_dims(v_idx, 0)
    # relative position differences, clipped and shifted into [0, input_dim - 1]
    position_idx = v_idx - q_idx
    max_position = (self.input_dim - 1) // 2
    position_idx = K.clip(position_idx, -max_position, max_position)
    position_idx = position_idx + max_position
    return position_idx
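# Tiny NumPy illustration with an assumed input_dim=5 (max relative distance 2):
# offsets are clipped to [-2, 2] and shifted into the embedding index range [0, 4].
import numpy as np
q_len, v_len, input_dim = 4, 6, 5
max_position = (input_dim - 1) // 2
rel = np.arange(v_len)[None, :] - np.arange(q_len)[:, None]
idx = np.clip(rel, -max_position, max_position) + max_position
# idx[0] == [2 3 4 4 4 4] and idx[3] == [0 0 1 2 3 4]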
def call(self, inputs):
    """In the conditional case, the condition follows the inputs: [inputs, condition]."""
    if self.conditional:
        inputs, cond = inputs
        if self.condition_hidden_units is not None:
            cond = self.condition_hidden_dense(cond)
        # broadcast cond to the same rank as inputs
        for _ in range(K.ndim(inputs) - K.ndim(cond)):
            cond = K.expand_dims(cond, 1)
        if self.center:
            beta = self.beta_dense(cond) + self.beta
        if self.scale:
            gamma = self.gamma_dense(cond) + self.gamma
    else:
        beta = self.beta
        gamma = self.gamma
    output = inputs
    if self.center:
        mean = K.mean(inputs, axis=-1, keepdims=True)
        output = output - mean
    if self.scale:
        var = K.mean(K.square(output), axis=-1, keepdims=True)
        std = K.sqrt(var + self.epsilon)
        output = output / std
        output = output * gamma
    if self.center:
        output = output + beta
    return output
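# A NumPy sketch of the conditional branch above, with the two Dense layers
# replaced by plain weight matrices (all names here are illustrative):
import numpy as np

def conditional_layer_norm(x, cond, gamma, beta, w_gamma, w_beta, eps=1e-12):
    # x: (batch, seq_len, dim); cond: (batch, cond_dim)
    gamma = (cond @ w_gamma + gamma)[:, None, :]   # condition-dependent scale
    beta = (cond @ w_beta + beta)[:, None, :]      # condition-dependent shift
    mean = x.mean(axis=-1, keepdims=True)
    std = np.sqrt(((x - mean) ** 2).mean(axis=-1, keepdims=True) + eps)
    return (x - mean) / std * gamma + beta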
def compute_mask(self, inputs, mask=None):
    if self.conditional:
        masks = mask if mask is not None else []
        masks = [K.expand_dims(m, 0) for m in masks if m is not None]
        if len(masks) == 0:
            return None
        else:
            return K.all(K.concatenate(masks, axis=0), axis=0)
    return mask
def call(self, x, mask=None):
    x0 = x
    x = self.k_dense(x0)
    x = self.o_dense(x)
    if mask is not None:
        mask = K.cast(mask, K.floatx())
        mask = K.expand_dims(mask, 2)
        x = x - (1 - mask) * 1e12
    x = K.softmax(x, 1)
    x = K.sum(x0 * x, 1)
    return x
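# Hedged note: o_dense presumably projects each timestep to a single score, so the
# softmax over axis 1 yields attention weights of shape (batch, seq_len, 1) and the
# final K.sum returns an attention-weighted average of x0 with shape (batch, dim);
# padded positions receive -1e12 before the softmax and thus ~0 weight.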
def call(self, inputs):
    input_shape = K.shape(inputs)
    batch_size, seq_length = input_shape[0], input_shape[1]
    pos_embedding = self.embeddings[:seq_length]
    pos_embedding = K.expand_dims(pos_embedding, 0)
    if self.merge_mode != 'add':
        pos_embedding = K.tile(pos_embedding, [batch_size, 1, 1])
    if self.merge_mode == 'add':
        return inputs + pos_embedding
    return K.concatenate([inputs, pos_embedding], axis=-1)
def call(self, x, mask=None):
    x0 = x
    if mask is not None:
        mask = K.cast(mask, K.floatx())
        mask = K.expand_dims(mask, 2)
    # x = x0 * mask if mask is not None else x0
    # wrap in an identity Lambda to drop the Keras mask, so it is not passed into conv1d
    x0 = Lambda(lambda x_: x_, output_shape=lambda s: s)(x0)
    x = self.conv1d(x0)
    x, g = x[:, :, :self.o_dim], x[:, :, self.o_dim:]
    if self.dropout_rate is not None:
        g = K.in_train_phase(K.dropout(g, self.dropout_rate), g)
    g = K.sigmoid(g)
    # if no mask was given, treat every position as valid
    mask = mask if mask is not None else K.ones_like(x)
    if self.skip_connection:
        if K.int_shape(x0)[-1] != self.o_dim:
            x0 = self.conv1d_1x1(x0)
        return (x0 * (1 - g) + x * g) * mask
    return x * g * mask
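# Hedged note: the convolution produces 2 * o_dim channels that are split into a
# value half x and a gate half g; with skip_connection the layer returns a
# highway-style mix x0 * (1 - sigmoid(g)) + x * sigmoid(g), zeroed on padding.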
def call(self, inputs, **kwargs):
    logits, token_seq = inputs[:2]
    seq_shape = K.shape(token_seq)
    batch_size, seq_length = seq_shape[0], seq_shape[1]
    if self.pad_token_id is None:
        # no pad token: assume every sequence uses its last position
        sequence_lengths = tf.fill([batch_size], seq_length - 1)
    else:
        # index of the last non-pad token in each sequence
        sequence_lengths = (
            K.sum(
                K.cast(K.not_equal(token_seq, self.pad_token_id), dtype='int32'),
                -1,
                keepdims=False,
            ) - 1
        )
    # tf2 only:
    # return tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
    indices = K.expand_dims(sequence_lengths, -1)
    return tf.gather_nd(logits, indices, batch_dims=1)
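# Quick check of the gather step with illustrative values: assuming
# pad_token_id == 0, the sequence [5, 6, 7, 0, 0] has 3 non-pad tokens, so the
# layer picks the logits at index 2, i.e. the last real token.
import tensorflow as tf
token_seq = tf.constant([[5, 6, 7, 0, 0]])
logits = tf.reshape(tf.range(10, dtype=tf.float32), (1, 5, 2))
lengths = tf.reduce_sum(tf.cast(tf.not_equal(token_seq, 0), tf.int32), -1) - 1  # [2]
last = tf.gather_nd(logits, tf.expand_dims(lengths, -1), batch_dims=1)          # [[4., 5.]]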
def call(self, inputs, mask=None):
    if mask is not None:
        mask = K.cast(mask, K.floatx())
        mask = K.expand_dims(mask, 2)
        inputs = inputs - (1.0 - mask) * 1e12
    return K.softmax(inputs, 1)
def call(self, x):
    seq, vec = x
    vec = K.expand_dims(vec, 1)
    vec = K.tile(vec, [1, K.shape(seq)[1], 1])
    return K.concatenate([seq, vec], 2)