import math

import tensorflow as tf


def double_linear_logits(args, size, bias, bias_start=0.0, scope=None, mask=None,
                         wd=0.0, input_keep_prob=1.0, is_train=None):
    """Two-layer logit projection: tanh(linear(args)) followed by a linear map to one logit."""
    with tf.variable_scope(scope or "Double_Linear_Logits"):
        first = tf.tanh(linear(
            args, size, bias, bias_start=bias_start, scope='first',
            wd=wd, input_keep_prob=input_keep_prob, is_train=is_train))
        second = linear(
            first, 1, bias, bias_start=bias_start, squeeze=True, scope='second',
            wd=wd, input_keep_prob=input_keep_prob, is_train=is_train)
        if mask is not None:
            second = exp_mask(second, mask)
        return second
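# `exp_mask` above is assumed to be the repo's usual masking helper, defined
# elsewhere in this module. The sketch below (hypothetical name `exp_mask_sketch`;
# the constant is an assumption) shows the expected behavior: push masked-out
# positions to a very negative value so a following softmax gives them ~zero probability.
def exp_mask_sketch(val, mask, very_negative_number=-1e30):
    # mask is boolean (or 0/1): True keeps the value, False drives it toward -inf.
    return val + (1.0 - tf.cast(mask, tf.float32)) * very_negative_number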
def linear_logits(args, bias, bias_start=0.0, scope=None, mask=None, wd=0.0,
                  input_keep_prob=1.0, is_train=None):
    """Single linear projection of `args` to one logit per position, with optional masking."""
    with tf.variable_scope(scope or "Linear_Logits"):
        logits = linear(
            args, 1, bias, bias_start=bias_start, squeeze=True, scope='first',
            wd=wd, input_keep_prob=input_keep_prob, is_train=is_train)
        if mask is not None:
            logits = exp_mask(logits, mask)
        return logits
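# Hypothetical usage sketch for the two logit helpers above. It assumes the
# repo's `linear` helper accepts a [batch, seq_len, dim] tensor (or a list of
# them); all tensor names and sizes below are illustrative placeholders.
def _linear_logits_usage_sketch():
    context = tf.random_normal([4, 20, 128])                                  # batch, seq_len, dim
    context_mask = tf.sequence_mask(tf.constant([20, 15, 9, 3]), maxlen=20)   # batch, seq_len (bool)
    start_logits = linear_logits([context], bias=True, mask=context_mask, scope='start')
    end_logits = double_linear_logits([context], size=100, bias=True,
                                      mask=context_mask, scope='end')
    return start_logits, end_logits  # each [4, 20]; masked positions are suppressed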
def sum_logits(args, mask=None, name=None):
    """Sums each tensor in `args` over its last axis and adds the results into one logit per position."""
    with tf.name_scope(name or "sum_logits"):
        if args is None or (isinstance(args, (tuple, list)) and not args):
            raise ValueError("`args` must be specified")
        if not isinstance(args, (tuple, list)):
            args = [args]
        rank = len(args[0].get_shape())
        logits = sum(tf.reduce_sum(arg, rank - 1) for arg in args)
        if mask is not None:
            logits = exp_mask(logits, mask)
        return logits
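# Hypothetical usage sketch for sum_logits (tensor names and shapes are
# illustrative): each input is reduced over its last axis and the reductions
# are summed, yielding a single masked logit per position.
def _sum_logits_usage_sketch():
    a = tf.random_normal([4, 10, 32])
    b = tf.random_normal([4, 10, 32])
    mask = tf.sequence_mask(tf.constant([10, 8, 6, 2]), maxlen=10)  # [4, 10] bool
    return sum_logits([a, b], mask=mask)  # [4, 10] masked logits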
def softmax(logits, mask=None, scope=None):
    """Softmax over the last axis; masked positions receive ~zero probability."""
    with tf.name_scope(scope or "Softmax"):
        if mask is not None:
            logits = exp_mask(logits, mask)
        out = tf.nn.softmax(logits, -1)
        return out
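# Hypothetical usage sketch for the masked softmax above (names, shapes, and
# lengths are illustrative): padding positions are masked out before the
# normalization, so each row of the result sums to 1 over the valid positions.
def _softmax_usage_sketch():
    scores = tf.random_normal([4, 10])
    seq_mask = tf.sequence_mask(tf.constant([10, 7, 3, 5]), maxlen=10)  # [4, 10] bool
    return softmax(scores, mask=seq_mask)  # [4, 10]; ~0 probability where mask is False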
def multi_head_attention(rep_tensor, rep_mask, head_num=8, hidden_units_num=64,
                         scope=None, is_train=None, keep_prob=1., wd=0.):
    """Transformer-style multi-head self-attention with sinusoidal positional encoding."""
    bs, sl, vec = tf.shape(rep_tensor)[0], tf.shape(rep_tensor)[1], tf.shape(rep_tensor)[2]
    ivec = rep_tensor.get_shape().as_list()[2]
    with tf.variable_scope(scope or 'multi_head_attention'):

        with tf.variable_scope('positional_encoding'):
            seq_idxs = tf.tile(tf.expand_dims(tf.range(sl), 1), [1, ivec])  # sl, ivec
            feature_idxs = tf.tile(tf.expand_dims(tf.range(ivec), 0), [sl, 1])  # sl, ivec
            pos_enc = tf.where(
                tf.equal(tf.mod(feature_idxs, 2), 0),
                tf.sin(tf.cast(seq_idxs, tf.float32) /
                       tf.pow(10000., 2.0 * tf.cast(feature_idxs, tf.float32) / (1.0 * ivec))),
                tf.cos(tf.cast(seq_idxs, tf.float32) /
                       tf.pow(10000., 2.0 * tf.cast(feature_idxs - 1, tf.float32) / (1.0 * ivec))),
            )
            rep_tensor_pos = mask_for_high_rank(rep_tensor + pos_enc, rep_mask)  # bs, sl, ivec

        with tf.variable_scope('multi_head_attention'):
            W = tf.get_variable('W', [3, head_num, ivec, hidden_units_num], tf.float32)
            rep_tile = tf.tile(
                tf.expand_dims(tf.expand_dims(rep_tensor_pos, 0), 0),
                [3, head_num, 1, 1, 1])  # 3,head_num,bs,sl,ivec
            rep_tile_reshape = tf.reshape(
                rep_tile, [3, head_num, bs * sl, ivec])  # 3,head_num,bs*sl,ivec

            maps = tf.reshape(  # 3,head_num,bs*sl,hn -> 3,head_num,bs,sl,hn
                tf.matmul(dropout(rep_tile_reshape, keep_prob, is_train), W),
                [3, head_num, bs, sl, hidden_units_num])
            Q_map, K_map, V_map = tf.split(maps, 3, 0)
            Q_map = tf.squeeze(Q_map, [0])  # head_num,bs,sl,hn
            K_map = tf.squeeze(K_map, [0])  # head_num,bs,sl,hn
            V_map = tf.squeeze(V_map, [0])  # head_num,bs,sl,hn

            # scaled dot-product similarity: head_num,bs,sl,sl
            # similarity_mat = tf.reduce_sum(Q_map_tile * K_map_tile, -1) / math.sqrt(1. * hidden_units_num)
            similarity_mat = tf.matmul(
                Q_map, tf.transpose(K_map, [0, 1, 3, 2])) / math.sqrt(1. * hidden_units_num)

            # mask: bs,sl -> head_num,bs,sl
            multi_mask = tf.tile(tf.expand_dims(rep_mask, 0), [head_num, 1, 1])  # head_num,bs,sl
            multi_mask_tile_1 = tf.expand_dims(multi_mask, 2)  # head_num,bs,1,sl
            multi_mask_tile_2 = tf.expand_dims(multi_mask, 3)  # head_num,bs,sl,1
            multi_mask_tile = tf.logical_and(multi_mask_tile_1, multi_mask_tile_2)  # head_num,bs,sl,sl

            similarity_mat_masked = exp_mask(similarity_mat, multi_mask_tile)  # head_num,bs,sl,sl
            prob_dist = tf.nn.softmax(similarity_mat_masked)  # head_num,bs,sl,sl
            prob_dist_dp = dropout(prob_dist, keep_prob, is_train)

            attn_res = tf.matmul(prob_dist_dp, V_map)  # head_num,bs,sl,hn
            attn_res_tran = tf.transpose(attn_res, [1, 2, 0, 3])  # bs,sl,head_num,hn
            output = tf.reshape(attn_res_tran, [bs, sl, head_num * hidden_units_num])

            if wd > 0.:
                add_reg_without_bias()

            return output
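# Hypothetical usage sketch for multi_head_attention (all tensor names, shapes,
# and hyperparameters are illustrative assumptions; `is_train` is assumed to be
# a scalar boolean tensor as expected by the repo's `dropout` helper). The
# result has shape [batch, seq_len, head_num * hidden_units_num].
def _multi_head_attention_usage_sketch():
    rep = tf.random_normal([4, 16, 128])                                 # bs, sl, ivec
    rep_mask = tf.sequence_mask(tf.constant([16, 12, 9, 5]), maxlen=16)  # bs, sl (bool)
    # -> [4, 16, 8 * 64]
    return multi_head_attention(rep, rep_mask, head_num=8, hidden_units_num=64,
                                scope='mha_demo', is_train=tf.constant(True),
                                keep_prob=0.9)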