Esempio n. 1
0
def double_linear_logits(args,
                         size,
                         bias,
                         bias_start=0.0,
                         scope=None,
                         mask=None,
                         wd=0.0,
                         input_keep_prob=1.0,
                         is_train=None):
    """Compute logits through two stacked `linear` layers with a tanh between.

    The first `linear` call (scope 'first') is given `size` as its second
    argument; its result goes through `tf.tanh` and then into a second
    `linear` call (scope 'second') that is given 1 and `squeeze=True`.

    Args:
        args: input(s) forwarded to the first `linear` call.
        size: second positional argument of the first `linear` call
            (presumably the hidden width -- confirm against `linear`).
        bias: forwarded to both `linear` calls.
        bias_start: forwarded to both `linear` calls.
        scope: variable scope; defaults to "Double_Linear_Logits".
        mask: optional mask; when not None the logits are passed through
            `exp_mask(logits, mask)` before being returned.
        wd: forwarded to both `linear` calls (presumably weight decay).
        input_keep_prob: forwarded to both `linear` calls.
        is_train: forwarded to both `linear` calls.

    Returns:
        The output of the second `linear` call, masked when `mask` is given.
    """
    # Keyword arguments that both projections receive unchanged.
    shared_kwargs = dict(bias_start=bias_start,
                         wd=wd,
                         input_keep_prob=input_keep_prob,
                         is_train=is_train)

    with tf.variable_scope(scope or "Double_Linear_Logits"):
        hidden = tf.tanh(
            linear(args, size, bias, scope='first', **shared_kwargs))
        logits = linear(hidden, 1, bias, squeeze=True, scope='second',
                        **shared_kwargs)
        if mask is None:
            return logits
        return exp_mask(logits, mask)
Esempio n. 2
0
def linear_logits(args, bias, bias_start=0.0, scope=None, mask=None, wd=0.0, input_keep_prob=1.0, is_train=None):
    """Compute logits with a single `linear` call, optionally masked.

    `linear` is called inside the variable scope (default "Linear_Logits")
    with 1 as its second argument, `squeeze=True` and scope 'first'.

    Args:
        args: input(s) forwarded to `linear`.
        bias: forwarded to `linear`.
        bias_start: forwarded to `linear`.
        scope: variable scope; defaults to "Linear_Logits".
        mask: optional mask; when not None the logits are passed through
            `exp_mask(logits, mask)`.
        wd: forwarded to `linear` (presumably weight decay).
        input_keep_prob: forwarded to `linear`.
        is_train: forwarded to `linear`.

    Returns:
        The `linear` output, masked when `mask` is given.
    """
    with tf.variable_scope(scope or "Linear_Logits"):
        projected = linear(
            args,
            1,
            bias,
            bias_start=bias_start,
            squeeze=True,
            scope='first',
            wd=wd,
            input_keep_prob=input_keep_prob,
            is_train=is_train,
        )
        return projected if mask is None else exp_mask(projected, mask)
Esempio n. 3
0
def sum_logits(args, mask=None, name=None):
    """Sum each input over its last axis and add the results together.

    Args:
        args: a single tensor or a non-empty tuple/list of tensors.
        mask: optional mask; when not None the summed logits are passed
            through `exp_mask(logits, mask)`.
        name: name scope; defaults to "sum_logits".

    Returns:
        The element-wise sum of every input reduced over axis `rank - 1`,
        where `rank` is taken from the static shape of the first input.

    Raises:
        ValueError: if `args` is None or an empty tuple/list.
    """
    with tf.name_scope(name or "sum_logits"):
        if args is None:
            raise ValueError("`args` must be specified")
        if isinstance(args, (tuple, list)):
            if not args:
                raise ValueError("`args` must be specified")
            tensors = args
        else:
            # A lone tensor is treated as a one-element collection.
            tensors = [args]

        # The reduction axis is derived from the first input only.
        last_axis = len(tensors[0].get_shape()) - 1

        # Start from 0 exactly as the built-in sum() does.
        total = 0
        for tensor in tensors:
            total = total + tf.reduce_sum(tensor, last_axis)

        if mask is None:
            return total
        return exp_mask(total, mask)
Esempio n. 4
0
def softmax(logits, mask=None, scope=None):
    """Apply softmax over the last axis of `logits`, optionally masking first.

    Args:
        logits: tensor of unnormalized scores.
        mask: optional mask; when not None, `exp_mask(logits, mask)` is
            applied before the softmax.
        scope: name scope; defaults to "Softmax".

    Returns:
        `tf.nn.softmax` of the (possibly masked) logits along axis -1.
    """
    with tf.name_scope(scope or "Softmax"):
        scores = logits if mask is None else exp_mask(logits, mask)
        return tf.nn.softmax(scores, -1)
Esempio n. 5
0
def multi_head_attention(rep_tensor,
                         rep_mask,
                         head_num=8,
                         hidden_units_num=64,
                         scope=None,
                         is_train=None,
                         keep_prob=1.,
                         wd=0.):
    """Multi-head scaled dot-product self-attention with positional encoding.

    A sinusoidal positional encoding is added to `rep_tensor`, the result is
    projected into per-head query/key/value maps with one shared variable
    `W`, and each head attends over the sequence. Head outputs are
    concatenated along the last axis.

    Args:
        rep_tensor: rank-3 float tensor indexed as (bs, sl, ivec). The last
            dimension is read from the static shape and sizes the variable
            `W`, so it presumably must be statically known -- confirm.
        rep_mask: boolean tensor, (bs, sl) per the inline shape comments; it
            is combined with `tf.logical_and`.
        head_num: number of attention heads.
        hidden_units_num: per-head width of the query/key/value projections.
        scope: variable scope; defaults to 'multi_head_attention'.
        is_train: forwarded to the `dropout` helper.
        keep_prob: forwarded to the `dropout` helper, for both the projection
            input and the attention probabilities.
        wd: when > 0, `add_reg_without_bias()` is called (presumably registers
            weight-decay terms -- confirm against that helper).

    Returns:
        Tensor reshaped to (bs, sl, head_num * hidden_units_num).
    """
    # Dynamic batch size and sequence length are used for tiling/reshaping.
    # NOTE(review): `vec` is assigned but never read below; the static `ivec`
    # is used instead.
    bs, sl, vec = tf.shape(rep_tensor)[0], tf.shape(rep_tensor)[1], tf.shape(
        rep_tensor)[2]
    ivec = rep_tensor.get_shape().as_list()[2]

    with tf.variable_scope(scope or 'multi_head_attention'):

        with tf.variable_scope('positional_encoding'):
            # Index grids: sequence position per row, feature index per column.
            seq_idxs = tf.tile(tf.expand_dims(tf.range(sl), 1),
                               [1, ivec])  # sl, ivec
            feature_idxs = tf.tile(tf.expand_dims(tf.range(ivec), 0),
                                   [sl, 1])  # sl, ivec
            # Even feature indices take the sin branch. Odd ones take the cos
            # branch, computed with (index - 1) so that each odd feature
            # shares the frequency of the even feature before it.
            # NOTE(review): the exponent here is 2 * idx / ivec, which looks
            # like twice the exponent of the usual Transformer formula
            # (idx / d_model for even idx) -- confirm this is intended.
            pos_enc = tf.where(
                tf.equal(tf.mod(feature_idxs, 2), 0),
                tf.sin(
                    tf.cast(seq_idxs, tf.float32) / tf.pow(
                        10000., 2.0 * tf.cast(feature_idxs, tf.float32) /
                        (1.0 * ivec))),
                tf.cos(
                    tf.cast(seq_idxs, tf.float32) / tf.pow(
                        10000., 2.0 * tf.cast(feature_idxs - 1, tf.float32) /
                        (1.0 * ivec))),
            )
            # pos_enc (sl, ivec) broadcasts over the batch axis.
            # `mask_for_high_rank` presumably zeroes masked positions -- confirm.
            rep_tensor_pos = mask_for_high_rank(rep_tensor + pos_enc,
                                                rep_mask)  # bs, sl, ivec

        with tf.variable_scope('multi_head_attention'):
            # One variable holds the Q, K and V projections (leading axis 3)
            # for every head.
            W = tf.get_variable('W', [3, head_num, ivec, hidden_units_num],
                                tf.float32)
            # Replicate the input once per (projection, head) pair so one
            # batched matmul produces all maps; dropout is applied to the
            # replicated tensor.
            rep_tile = tf.tile(
                tf.expand_dims(tf.expand_dims(rep_tensor_pos, 0), 0),
                [3, head_num, 1, 1, 1])  # 3,head_num,bs,sl,ivec
            rep_tile_reshape = tf.reshape(
                rep_tile, [3, head_num, bs * sl, ivec])  # 3,head_num,bs*sl,ivec

            maps = tf.reshape(  # 3,head_num,bs*sl,hn ->  3,head_num,bs,sl,hn
                tf.matmul(dropout(rep_tile_reshape, keep_prob, is_train), W),
                [3, head_num, bs, sl, hidden_units_num])
            # Split the leading axis into the query, key and value maps.
            Q_map, K_map, V_map = tf.split(maps, 3, 0)
            Q_map = tf.squeeze(Q_map, [0])  # head_num,bs,sl,hn
            K_map = tf.squeeze(K_map, [0])  # head_num,bs,sl,hn
            V_map = tf.squeeze(V_map, [0])  # head_num,bs,sl,hn

            # head_num,bs,sl,sl
            # Scaled dot product: Q @ K^T divided by sqrt(hidden_units_num).
            similarity_mat = tf.matmul(Q_map, tf.transpose(
                K_map, [0, 1, 3, 2])) / math.sqrt(1. * hidden_units_num)

            # mask: bs,sl -> head_num,bs,sl
            multi_mask = tf.tile(tf.expand_dims(rep_mask, 0),
                                 [head_num, 1, 1])  # head_num,bs,sl
            multi_mask_tile_1 = tf.expand_dims(multi_mask,
                                               2)  # head_num,bs,1,sl
            multi_mask_tile_2 = tf.expand_dims(multi_mask,
                                               3)  # head_num,bs,sl,1
            # A (query, key) pair is kept only when both positions are
            # unmasked.
            multi_mask_tile = tf.logical_and(
                multi_mask_tile_1, multi_mask_tile_2)  # head_num,bs,sl,sl
            # `exp_mask` presumably pushes masked logits towards -inf so the
            # softmax gives them ~0 probability -- confirm against the helper.
            similarity_mat_masked = exp_mask(
                similarity_mat, multi_mask_tile)  # head_num,bs,sl,sl
            prob_dist = tf.nn.softmax(
                similarity_mat_masked)  # head_num,bs,sl,sl
            prob_dist_dp = dropout(prob_dist, keep_prob, is_train)

            attn_res = tf.matmul(prob_dist_dp, V_map)  # head_num,bs,sl,hn

            # Move the head axis next to the feature axis, then merge the two
            # so the heads are concatenated: bs,sl,head_num,hn ->
            # bs,sl,head_num*hn.
            attn_res_tran = tf.transpose(attn_res, [1, 2, 0, 3])
            output = tf.reshape(attn_res_tran,
                                [bs, sl, head_num * hidden_units_num])

            if wd > 0.:
                add_reg_without_bias()

            return output