Example #1
def bn_dense_layer(input_tensors,
                   hn,
                   bias,
                   bias_start=0.0,
                   scope=None,
                   activation='relu',
                   enable_bn=False,
                   wd=0.,
                   keep_prob=1.0,
                   is_train=None):
    tf.logging.warning(
        "Please use \"bn_dense_layer_v2\" rather than \"bn_dense_layer\" for future support! "
    )
    with tf.variable_scope(scope or 'bn_dense_layer'):
        linear_map = linear(input_tensors, hn, bias, bias_start, 'linear_map',
                            False, wd, keep_prob, is_train)
        if enable_bn:
            linear_map = tf.contrib.layers.batch_norm(linear_map,
                                                      center=True,
                                                      scale=True,
                                                      is_training=is_train,
                                                      updates_collections=None,
                                                      decay=0.9,
                                                      scope='bn')
        act_fn = act_name2fn(activation)
        return act_fn(linear_map)
Example #2
def logits_for_sketch_prediction(decoder_states,
                                 cls_state,
                                 num_channel,
                                 hn=None,
                                 act_name="relu",
                                 wd=0.,
                                 keep_prob=1.0,
                                 is_train=None,
                                 compress_mask=None,
                                 scope=None):
    compressing = compress_mask is not None
    hn = hn or get_shape_list(decoder_states)[-1]
    with tf.variable_scope(scope or "logits_for_sketch_index"):
        if compressing:
            new_decoder_states, _, rev_d = compress_seq_wrt_mask(
                decoder_states, compress_mask)
        else:
            new_decoder_states = decoder_states
            rev_d = None
        map_part1 = bn_dense_layer_v2(new_decoder_states, hn, True, 0.,
                                      "map_part1", "linear", False, wd,
                                      keep_prob, is_train)
        map_part2_pre = bn_dense_layer_v2(cls_state, hn, False, 0.,
                                          "map_part2_pre", "linear", False, wd,
                                          keep_prob, is_train)
        map_part2 = tf.tile(tf.expand_dims(map_part2_pre, axis=1),
                            [1, get_shape_list(map_part1)[1], 1])
        map_res = act_name2fn(act_name)(map_part1 + map_part2)

        logits = bn_dense_layer_v2(map_res, num_channel, True, 0., "logits",
                                   "linear", False, wd, keep_prob, is_train)
        if compressing:
            logits = decompress_seq_wrt_mask(logits, rev_d)
        return logits
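The key step above is broadcasting the single cls_state vector over the sequence axis before the element-wise addition with the per-token projection. A standalone NumPy sketch of just that shape logic (hypothetical sizes, identity stand-ins for the bn_dense_layer_v2 projections):

import numpy as np

bs, sl, hn = 2, 5, 8                                    # hypothetical batch size, seq len, hidden size
decoder_states = np.random.randn(bs, sl, hn)            # per-token states
cls_state = np.random.randn(bs, hn)                     # single sentence-level state

map_part1 = decoder_states                              # stands in for bn_dense_layer_v2(new_decoder_states, ...)
map_part2 = np.tile(cls_state[:, None, :], (1, sl, 1))  # expand_dims + tile: [bs,hn] -> [bs,sl,hn]

map_res = np.maximum(map_part1 + map_part2, 0.)         # relu(map_part1 + map_part2)
print(map_res.shape)                                    # (2, 5, 8)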
Example #3
def mlp(x, scope, n_state, train=None, afn='gelu', resid_dropout=0.9):  # 2-layer position-wise MLP
    with tf.variable_scope(scope):
        nx = shape_list(x)[-1]
        act = act_name2fn(afn)
        h = act(conv1d(x, 'c_fc_openai_trans', n_state, 1, train=train))
        h2 = conv1d(h, 'c_proj_openai_trans', nx, 1, train=train)
        h2 = dropout(h2, resid_dropout, train)
        return h2
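For reference, the block above is the standard position-wise feed-forward pattern: expand to n_state, apply the activation, project back to the input width. A self-contained NumPy sketch with hypothetical sizes and a tanh-approximated GELU (the real code goes through conv1d and act_name2fn):

import numpy as np

bs, sl, nx, n_state = 2, 5, 8, 32          # hypothetical sizes; n_state is the expanded width
x = np.random.randn(bs, sl, nx)

w_fc = np.random.randn(nx, n_state)        # stands in for conv1d(x, 'c_fc_openai_trans', ...)
w_proj = np.random.randn(n_state, nx)      # stands in for conv1d(h, 'c_proj_openai_trans', ...)

def gelu(v):                               # tanh approximation of GELU
    return 0.5 * v * (1. + np.tanh(np.sqrt(2. / np.pi) * (v + 0.044715 * v ** 3)))

h = gelu(x @ w_fc)                         # expand and activate: [bs,sl,n_state]
h2 = h @ w_proj                            # project back: [bs,sl,nx]
print(h2.shape)                            # (2, 5, 8)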
Example #4
def qqp_logits_sentence_encoding(s1_rep, s2_rep, afn, n_state, is_train, clf_dropout, highway=False):   # TODO: change this to my style (bn_dense_layer)
    out_rep = tf.concat([tf.abs(s1_rep - s2_rep), s1_rep * s2_rep], -1)
    act = act_name2fn(afn)
    h = act(conv1d(out_rep, 'c_fc', n_state, 1, train=is_train))

    if highway:
        trans = conv1d(h, 'c_trans', n_state, 1, train=is_train)
        gate = tf.nn.sigmoid(conv1d(h, 'c_gate', n_state, 1, train=is_train))
        h = gate * trans + (1 - gate) * h

    h_dp = dropout(h, clf_dropout, is_train)
    return conv1d(h_dp, 'c_logits', 2, 1, train=is_train)
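The pair representation above is the usual symmetric feature construction |s1 - s2| and s1 * s2, optionally passed through a highway-style gate. A minimal NumPy sketch of those two steps, with random stand-ins for the conv1d projections (hypothetical sizes):

import numpy as np

bs, hn = 2, 8                                            # hypothetical batch size and hidden size
s1_rep = np.random.randn(bs, hn)
s2_rep = np.random.randn(bs, hn)

out_rep = np.concatenate([np.abs(s1_rep - s2_rep), s1_rep * s2_rep], axis=-1)  # [bs, 2*hn]

h = np.maximum(np.random.randn(bs, 2 * hn), 0.)          # stands in for act(conv1d(out_rep, 'c_fc', ...))
trans = np.random.randn(bs, 2 * hn)                      # stands in for conv1d(h, 'c_trans', ...)
gate = 1. / (1. + np.exp(-np.random.randn(bs, 2 * hn)))  # sigmoid gate, stands in for conv1d(h, 'c_gate', ...)
h = gate * trans + (1. - gate) * h                       # highway mix of transform and carry paths
print(h.shape)                                           # (2, 16)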
Example #5
def bn_dense_layer_multi_head(input_tensor,
                              hn,
                              bias,
                              bias_start=0.0,
                              scope=None,
                              activation='relu',
                              enable_bn=False,
                              wd=0.,
                              keep_prob=1.0,
                              is_train=None,
                              dup_num=1,
                              merge_var=False):
    """The input can be rank >= 3: dim 0 is batch size, dim 1 is head and the last dim is the hidden size."""
    assert not enable_bn

    act_fn = act_name2fn(activation)

    with tf.variable_scope(scope or 'bn_dense_layer_multi_head'):
        input_tensor = dropout(input_tensor, keep_prob,
                               is_train)  # dropout [bs,hd,sl,dim]
        # the comments below use a 4-d tensor [bs,hd,sl,dim] as an example
        input_shape = get_shape_list(input_tensor)  # [4] for [bs,hd,sl,dim]
        assert len(input_shape) >= 3
        # exchange 1st and 2nd dimension
        perm_t = list(range(len(input_shape)))  # [0,1,2,3]
        perm_t[0], perm_t[1] = perm_t[1], perm_t[0]  # [1,0,2,3]
        input_tensor_t = tf.transpose(input_tensor, perm_t)  # [hd,bs,sl,dim]

        # merge and reshape
        input_shape_t = get_shape_list(
            input_tensor_t)  # [4] for [hd,bs,sl,dim]
        dims_merge = input_shape_t[1:-1]  # [2] for [bs,sl]
        new_dim = reduce(mul, dims_merge)  # bs*sl
        new_shape = [input_shape_t[0], new_dim,
                     input_shape_t[-1]]  # [3] for [hd,bs*sl,dim]
        input_tensor_rsp = tf.reshape(input_tensor_t,
                                      new_shape)  # [hd,bs*sl,dim]

        # dense layer
        hd_num = new_shape[0]  # head num
        hd_dim = new_shape[-1]  # head dim

        if merge_var:
            weight = tf.get_variable('W', shape=[hd_num, hd_dim, hn * dup_num])
        else:
            weight_list = []
            for i in range(hd_num):
                sub_weight_list = []
                for j in range(dup_num):
                    sub_weight_list.append(
                        tf.get_variable('W_%d_%d' % (i, j), shape=[hd_dim,
                                                                   hn]))
                weight_list.append(
                    tf.concat(sub_weight_list, -1
                              ) if dup_num > 1 else sub_weight_list[0])
            weight = tf.stack(weight_list, 0)

        out_rsp = tf.matmul(input_tensor_rsp, weight)  # hd_num, bs*sl, hn
        if bias:
            if merge_var:
                bias_val = tf.get_variable(
                    'bias',
                    shape=[hd_num, 1, hn * dup_num],
                    dtype=tf.float32,
                    initializer=tf.constant_initializer(bias_start))
            else:
                bias_list = []
                for i in range(hd_num):
                    sub_bias_list = []
                    for j in range(dup_num):
                        sub_bias_list.append(
                            tf.get_variable(
                                'bias_%d_%d' % (i, j),
                                shape=[1, hn],
                                dtype=tf.float32,
                                initializer=tf.constant_initializer(
                                    bias_start)))
                    bias_list.append(
                        tf.concat(sub_bias_list, -1
                                  ) if dup_num > 1 else sub_bias_list[0])
                bias_val = tf.stack(bias_list, 0)
            out_rsp = out_rsp + bias_val  # hd_num, bs*sl, hn

        # un-merge
        output_shape_t = [new_shape[0]] + dims_merge + [hn * dup_num]  # [4] for [hd,bs,sl,hn*dup_num]
        output_t = tf.reshape(out_rsp, output_shape_t)  # [hd,bs,sl,hn*dup_num]

        # transpose back to the input layout
        output = tf.transpose(output_t, perm_t)  # [bs,hd,sl,hn*dup_num]

        if wd:
            tf.add_to_collection('reg_vars', weight)

        return act_fn(output)
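The transpose/merge/reshape dance above exists so that every head can be handled by one batched matmul with its own weight matrix. A standalone NumPy sketch of the same shape logic (hypothetical sizes, dup_num=1, no bias):

import numpy as np

bs, hd, sl, dim, hn = 2, 4, 5, 8, 6        # hypothetical sizes
x = np.random.randn(bs, hd, sl, dim)       # [bs,hd,sl,dim]

x_t = x.transpose(1, 0, 2, 3)              # move the head axis first: [hd,bs,sl,dim]
x_rsp = x_t.reshape(hd, bs * sl, dim)      # merge bs and sl: [hd,bs*sl,dim]

w = np.random.randn(hd, dim, hn)           # one independent weight matrix per head
out_rsp = x_rsp @ w                        # batched matmul over heads: [hd,bs*sl,hn]

out = out_rsp.reshape(hd, bs, sl, hn).transpose(1, 0, 2, 3)  # un-merge and transpose back: [bs,hd,sl,hn]
print(out.shape)                           # (2, 4, 5, 6)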
Example #6
def bn_dense_layer_v2(input_tensor,
                      hn,
                      bias,
                      bias_start=0.0,
                      scope=None,
                      activation='relu',
                      enable_bn=False,
                      wd=0.,
                      keep_prob=1.0,
                      is_train=None,
                      dup_num=1,
                      merge_var=False):
    act_fn = act_name2fn(activation)
    with tf.variable_scope(scope or 'bn_dense_layer'):
        input_tensor = dropout(input_tensor, keep_prob, is_train)
        # the comments below use a 3-d tensor [bs,sl,hn] as an example
        input_shape = get_shape_list(input_tensor)  # [3]
        assert len(input_shape) >= 2  # at least [bs,hn]
        # merge
        dims_merge = input_shape[:-1]  # all leading dims, e.g. [bs, sl]
        new_dim = reduce(mul, dims_merge)  # get the merged dim
        new_shape = [new_dim, input_shape[-1]]  # new shape for matmul [2]
        input_tensor_rsp = tf.reshape(input_tensor, new_shape)  #  [xx,dim]

        # dense layer
        input_dim = new_shape[-1]
        if merge_var:
            weight = tf.get_variable('W',
                                     shape=[input_dim, hn * dup_num],
                                     dtype=tf.float32)
        else:
            weight_list = []
            for i in range(dup_num):
                weight_list.append(
                    tf.get_variable('W_%d' % i, shape=[input_dim, hn]))
            weight = tf.concat(weight_list, -1)
        output_rsp = tf.matmul(input_tensor_rsp, weight)

        if bias:
            if merge_var or dup_num == 1:
                bias_val = tf.get_variable(
                    'bias',
                    shape=[hn * dup_num],
                    dtype=tf.float32,
                    initializer=tf.constant_initializer(bias_start))
            else:
                bias_list = []
                for i in range(dup_num):
                    bias_list.append(
                        tf.get_variable(
                            'bias_%d' % i,
                            shape=[hn],
                            dtype=tf.float32,
                            initializer=tf.constant_initializer(bias_start)))
                bias_val = tf.concat(bias_list, -1)
            output_rsp += bias_val

        # output reshape
        output_shape = dims_merge + [hn * dup_num]  # [3] for [bs,sl,new_hn]
        output = tf.reshape(output_rsp, output_shape)  # [bs,sl,new_hn]

        if enable_bn:
            output = tf.contrib.layers.batch_norm(output,
                                                  center=True,
                                                  scale=True,
                                                  is_training=is_train,
                                                  updates_collections=None,
                                                  decay=0.9,
                                                  scope='bn')

        if wd:
            tf.add_to_collection('reg_vars', weight)

        return act_fn(output)
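The reason bn_dense_layer_v2 works for inputs of any rank is the merge-leading-dims trick: flatten everything except the last axis, do a single 2-D matmul, then restore the original leading shape. A self-contained NumPy sketch (hypothetical sizes):

import numpy as np
from functools import reduce
from operator import mul

x = np.random.randn(2, 5, 8)                             # hypothetical [bs,sl,dim] input
hn, dup_num = 6, 2                                       # output size and number of duplicated projections

dims_merge = list(x.shape[:-1])                          # leading dims, here [bs, sl]
x_rsp = x.reshape(reduce(mul, dims_merge), x.shape[-1])  # [bs*sl, dim]

w = np.random.randn(x.shape[-1], hn * dup_num)           # merged weight for dup_num projections
out = (x_rsp @ w).reshape(dims_merge + [hn * dup_num])   # back to [bs,sl,hn*dup_num]
print(out.shape)                                         # (2, 5, 12)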
Example #7
def compatibility_fn_lacacy(  # legacy: does not support arbitrary input ranks
        tensor_from, tensor_to, method='dot_product', scope=None, **kwargs):

    def _get_val_from_kwargs(key, default_val):
        if key in kwargs:
            return kwargs[key]
        else:
            return default_val

    with tf.variable_scope(scope or 'compatibility_fn.{}'.format(method)):
        shape_from = get_shape_list(tensor_from)
        ndim_from = len(shape_from)
        shape_to = get_shape_list(tensor_to)
        ndim_to = len(shape_to)

        assert (ndim_from == 2 or ndim_from == 3) and ndim_to == 3

        if ndim_from == 2:
            tensor_from = tf.expand_dims(tensor_from, 1)
            shape_from = get_shape_list(tensor_from)

        slf, slt = shape_from[1], shape_to[1]

        # hparams parsing
        hn = _get_val_from_kwargs('hn', shape_to[-1])
        wd = _get_val_from_kwargs('wd', 0.)
        keep_prob = _get_val_from_kwargs('keep_prob', 1.)
        is_training = _get_val_from_kwargs('is_training', None)
        activation = _get_val_from_kwargs('activation', 'relu')
        head_num = _get_val_from_kwargs('head_num', 12)

        seq_dim_to_remove = 1
        if method == 'dot_product':
            align_scores = tf.matmul(tensor_from, tensor_to, transpose_b=True)  # [bs,slf,hn]*[bs,slt,hn]=>bs,slf,slt
            align_scores = tf.expand_dims(align_scores, -1)  # [bs,slf,slt,1]
        elif method == 'additive':
            tensor_from_branch = bn_dense_layer_v2(
                tensor_from, hn, False, 0., 'tensor_from_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            tensor_to_branch = bn_dense_layer_v2(
                tensor_to, hn, True, 0., 'tensor_to_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,slf,slt,hn]
                tf.expand_dims(tensor_from_branch, 2),  # [bs,slf,1,hn]
                tf.expand_dims(tensor_to_branch, 1)  # [bs,1,slt,hn]
            ))
            align_scores = bn_dense_layer_v2(  # [bs,slf,slt,1]
                align_scores_pre, 1, True, 0., 'align_scores', 'linear', False,
                wd, keep_prob, is_training
            )
        elif method == 'multi_dim':
            logging.warning("No simplified multi-dim technique used in this function!")
            tensor_from_branch = bn_dense_layer_v2(
                tensor_from, hn, False, 0., 'tensor_from_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            tensor_to_branch = bn_dense_layer_v2(
                tensor_to, hn, True, 0., 'tensor_to_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,slf,slt,hn]
                tf.expand_dims(tensor_from_branch, 2),  # [bs,slf,1,hn]
                tf.expand_dims(tensor_to_branch, 1)  # bs,1,slt,hn
            ))
            align_scores = bn_dense_layer_v2(
                align_scores_pre, hn, True, 0., 'align_score', 'linear', False,
                wd, keep_prob, is_training
            )
        elif method == 'multi_head':
            seq_dim_to_remove = 2  # !!! because multi-head dim is on 2nd dim
            assert hn % head_num == 0
            head_dim = hn // head_num

            q_heads = bn_dense_layer_v2(
                tensor_from, head_dim, True, 0., 'q_heads',
                'linear', False, wd, keep_prob, is_training, dup_num=head_num
            )
            k_heads = bn_dense_layer_v2(
                tensor_to, head_dim, True, 0., 'k_heads',
                'linear', False, wd, keep_prob, is_training, dup_num=head_num
            )
            q_heads = split_head(q_heads, head_num)  # bs,hd_num,slf,hd_dim
            k_heads = split_head(k_heads, head_num)  # bs,hd_num,slt,hd_dim

            # alignment score
            align_scores = tf.matmul(q_heads, k_heads, transpose_b=True)  # [bs,hd_num,slf,slt]
            align_scores = align_scores / math.sqrt(1.*head_dim)  # [bs,hd_num,slf,slt]
        elif method == 'multi_dim_head':
            seq_dim_to_remove = 2  # !!! because multi-head dim is on 2nd dim
            assert hn % head_num == 0
            head_dim = hn // head_num

            q_heads = bn_dense_layer_v2(
                tensor_from, head_dim, True, 0., 'q_heads',
                'linear', False, wd, keep_prob, is_training, dup_num=head_num
            )
            k_heads = bn_dense_layer_v2(
                tensor_to, head_dim, True, 0., 'k_heads',
                'linear', False, wd, keep_prob, is_training, dup_num=head_num
            )
            q_heads = split_head(q_heads, head_num)  # bs,hd_num,slf,hd_dim
            k_heads = split_head(k_heads, head_num)  # bs,hd_num,slt,hd_dim

            # MLP
            q_heads_branch = bn_dense_layer_multi_head(
                q_heads, head_dim, False, 0., 'q_heads_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            k_heads_branch = bn_dense_layer_multi_head(
                k_heads, head_dim, True, 0., 'k_heads_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,head,slf,slt,dim]
                tf.expand_dims(q_heads_branch, 3),  # [bs,head,slf,1,dim]
                tf.expand_dims(k_heads_branch, 2)  # bs,head,1,slt,dim
            ))
            align_scores_heads = bn_dense_layer_multi_head(  # [bs,hd_num,slf,slt,hd_dim]
                align_scores_pre, head_dim, True, 0., 'align_scores_heads', 'linear', False,
                wd, keep_prob, is_training
            )
            align_scores = align_scores_heads  # [bs,hd_num,slf,slt,hd_dim]
            # align_scores = combine_head(align_scores_heads)
        elif method == 'bilinear':
            raise NotImplementedError
        else:
            raise AttributeError

        if ndim_from == 2:
            align_scores = tf.squeeze(align_scores, [seq_dim_to_remove])  #

        return align_scores
Example #8
def s2t_self_attn(
        tensor_input, tensor_mask, deep_act=None, method='multi_dim',
        wd=0., keep_prob=1., is_training=None,
        scope=None, **kwargs
):
    use_deep = isinstance(deep_act, str)  # use two layers or a single layer for the alignment score
    with tf.variable_scope(scope or 's2t_self_attn_{}'.format(method)):
        tensor_shape = get_shape_list(tensor_input)
        hn = tensor_shape[-1]  # hidden state number

        if method == 'additive':
            align_scores = bn_dense_layer_v2(  # bs,sl,hn/1
                tensor_input, hn if use_deep else 1, True, 0., 'align_score_1', 'linear', False,
                wd, keep_prob, is_training
            )
            if use_deep:
                align_scores = bn_dense_layer_v2(  # bs,sl,1
                    act_name2fn(deep_act)(align_scores), 1, True, 0., 'align_score_2', 'linear', False,
                    wd, keep_prob, is_training
                )
        elif method == 'multi_dim':
            align_scores = bn_dense_layer_v2(  # bs,sl,hn
                tensor_input, hn, False, 0., 'align_score_1', 'linear', False,
                wd, keep_prob, is_training
            )
            if use_deep:
                align_scores = bn_dense_layer_v2(  # bs,sl,hn
                    act_name2fn(deep_act)(align_scores), hn, True, 0., 'align_score_2', 'linear', False,
                    wd, keep_prob, is_training
                )
        elif method == 'multi_dim_head':
            get_shape_list(tensor_input, expected_rank=3)  # the input should be rank-3
            assert 'head_num' in kwargs and isinstance(kwargs['head_num'], int)
            head_num = kwargs['head_num']
            assert hn % head_num == 0
            head_dim = hn // head_num

            tensor_input_heads = split_head(tensor_input, head_num)  # [bs,hd,sl,hd_dim]

            align_scores_heads = bn_dense_layer_multi_head(  # [bs,hd,sl,hd_dim]
                tensor_input_heads, head_dim, True, 0., 'align_scores_heads_1', 'linear', False,
                wd, keep_prob, is_training
            )
            if use_deep:
                align_scores_heads = bn_dense_layer_multi_head(  # [bs,hd,sl,hd_dim]
                    act_name2fn(deep_act)(align_scores_heads), head_dim,
                    True, 0., 'align_scores_heads_2', 'linear', False,
                    wd, keep_prob, is_training
                )
            align_scores = combine_head(align_scores_heads)  # [bs,sl,dim]
        else:
            raise AttributeError

        # attention procedure align_scores [bs,sl,1/dim]
        align_scores_masked = exp_mask_v3(align_scores, tensor_mask, multi_head=False, high_dim=True)  # bs,sl,hn
        attn_prob = tf.nn.softmax(align_scores_masked, axis=-2)  # bs,sl,hn

        if 'attn_keep_prob' in kwargs and isinstance(kwargs['attn_keep_prob'], float):
            attn_prob = dropout(attn_prob, kwargs['attn_keep_prob'], is_training)  # bs,sl,hn

        attn_res = tf.reduce_sum(  # [bs,sl,hn] -> [bs,dim]
            mask_v3(attn_prob*tensor_input, tensor_mask, high_dim=True), axis=-2
        )

        return attn_res  # [bs,hn]
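What s2t_self_attn ultimately computes is a masked softmax over the sequence axis, with one attention weight per channel for the multi_dim-style methods, followed by a weighted sum that pools [bs,sl,hn] down to [bs,hn]. A standalone NumPy sketch of that pooling step (hypothetical sizes, random stand-ins for the learned align_scores):

import numpy as np

bs, sl, hn = 2, 5, 4                                     # hypothetical sizes
x = np.random.randn(bs, sl, hn)
mask = np.array([[1, 1, 1, 0, 0],
                 [1, 1, 1, 1, 1]], dtype=bool)           # [bs,sl] validity mask

scores = np.random.randn(bs, sl, hn)                     # stands in for the learned align_scores
scores = np.where(mask[..., None], scores, -1e30)        # exp_mask_v3-style masking of padded tokens

scores -= scores.max(axis=-2, keepdims=True)             # numerically stable softmax over the sl axis
prob = np.exp(scores) / np.exp(scores).sum(axis=-2, keepdims=True)

attn_res = (prob * x * mask[..., None]).sum(axis=-2)     # per-channel weighted sum: [bs,hn]
print(attn_res.shape)                                    # (2, 4)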
Example #9
def compatibility_fn(tensor_from, tensor_to, method='dot_product', scope=None, **kwargs):
    def _get_val_from_kwargs(key, default_val):
        if key in kwargs:
            return kwargs[key]
        else:
            return default_val

    with tf.variable_scope(scope or 'compatibility_fn.{}'.format(method)):
        shape_from = get_shape_list(tensor_from)
        ndim_from = len(shape_from)
        shape_to = get_shape_list(tensor_to)
        ndim_to = len(shape_to)

        assert ndim_from == ndim_to or ndim_from+1 == ndim_to
        need_extra_dim = ndim_from+1 == ndim_to

        if need_extra_dim:
            tensor_from = tf.expand_dims(tensor_from, -2)
            shape_from = get_shape_list(tensor_from)

        slf, slt = shape_from[-2], shape_to[-2]

        # hparams parsing
        hn = _get_val_from_kwargs('hn', shape_to[-1])
        wd = _get_val_from_kwargs('wd', 0.)
        keep_prob = _get_val_from_kwargs('keep_prob', 1.)
        is_training = _get_val_from_kwargs('is_training', None)
        activation = _get_val_from_kwargs('activation', 'relu')
        head_num = _get_val_from_kwargs('head_num', 12)

        seq_dim_to_remove = -3
        if method == 'dot_product':
            align_scores = tf.matmul(tensor_from, tensor_to, transpose_b=True)  # [bs,slf,hn]*[bs,slt,hn]=>bs,slf,slt
            align_scores = tf.expand_dims(align_scores, -1)  # [bs,slf,slt,1]
        elif method == 'additive':
            tensor_from_branch = bn_dense_layer_v2(
                tensor_from, hn, False, 0., 'tensor_from_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            tensor_to_branch = bn_dense_layer_v2(
                tensor_to, hn, True, 0., 'tensor_to_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,slf,slt,hn]
                tf.expand_dims(tensor_from_branch, -2),  # [bs,slf,1,hn]
                tf.expand_dims(tensor_to_branch, -3)  # [bs,1,slt,hn]
            ))
            align_scores = bn_dense_layer_v2(  # [bs,slf,slt,1]
                align_scores_pre, 1, True, 0., 'align_scores', 'linear', False,
                wd, keep_prob, is_training
            )
        elif method == 'multi_dim':
            logging.warning("No simplified multi-dim technique used in this function!")
            tensor_from_branch = bn_dense_layer_v2(
                tensor_from, hn, False, 0., 'tensor_from_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            tensor_to_branch = bn_dense_layer_v2(
                tensor_to, hn, True, 0., 'tensor_to_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,slf,slt,hn]
                tf.expand_dims(tensor_from_branch, -2),  # [bs,slf,1,hn]
                tf.expand_dims(tensor_to_branch, -3)  # bs,1,slt,hn
            ))
            align_scores = bn_dense_layer_v2(  # [bs,slf,slt,hn]
                align_scores_pre, hn, True, 0., 'align_score', 'linear', False,
                wd, keep_prob, is_training
            )
        elif method == 'multi_head':
            seq_dim_to_remove = -2  # !!! because multi-head dim is on 2nd dim
            assert hn % head_num == 0
            head_dim = hn // head_num

            q_heads = bn_dense_layer_v2(
                tensor_from, head_dim, True, 0., 'q_heads',
                'linear', False, wd, keep_prob, is_training, dup_num=head_num
            )
            k_heads = bn_dense_layer_v2(
                tensor_to, head_dim, True, 0., 'k_heads',
                'linear', False, wd, keep_prob, is_training, dup_num=head_num
            )
            q_heads = split_head(q_heads, head_num)  # bs,hd_num,slf,hd_dim
            k_heads = split_head(k_heads, head_num)  # bs,hd_num,slt,hd_dim

            # alignment score
            align_scores = tf.matmul(q_heads, k_heads, transpose_b=True)  # [bs,hd_num,slf,slt]
            align_scores = align_scores / math.sqrt(1.*head_dim)  # [bs,hd_num,slf,slt]
        elif method in ['multi_head_bilinear', 'multi_head_bilinear_shared', 'multi_head_only', 'multi_head_linear']:
            seq_dim_to_remove = -2  # !!! because multi-head dim is on 2nd dim
            assert hn % head_num == 0
            head_dim = hn // head_num

            q_heads = bn_dense_layer_v2(
                tensor_from, head_dim, True, 0., 'q_heads',
                kwargs.get("activation") or activation,
                False, wd, keep_prob, is_training, dup_num=head_num
            )
            k_heads = bn_dense_layer_v2(
                tensor_to, head_dim, True, 0., 'k_heads',
                kwargs.get("activation") or activation,
                False, wd, keep_prob, is_training, dup_num=head_num
            )
            q_heads = split_head(q_heads, head_num)  # bs,hd_num,slf,hd_dim
            k_heads = split_head(k_heads, head_num)  # bs,hd_num,slt,hd_dim

            # alignment score: using bilinear rather than dot product
            # align_scores = tf.matmul(q_heads, k_heads, transpose_b=True)  # [bs,hd_num,slf,slt]
            # align_scores = align_scores / math.sqrt(1. * head_dim)  # [bs,hd_num,slf,slt]
            with tf.variable_scope("bilinear"):
                if method == "multi_head_bilinear":
                    k_heads_map = bn_dense_layer_multi_head(
                        k_heads, head_dim, False, 0., 'k_heads_map', 'linear', False, wd, keep_prob, is_training)
                elif method == "multi_head_bilinear_shared":
                    k_heads_map = bn_dense_layer_v2(
                        k_heads, head_dim, False, 0., 'k_heads_map', 'linear', False, wd, keep_prob, is_training)
                elif method == "multi_head_only":
                    pass
                elif method == "multi_head_linear":
                    k_heads_map = bn_dense_layer_v2(
                        k_heads, head_dim, False, 0., 'k_heads_map', 'linear', False, wd, keep_prob, is_training)
                    q_heads_map = bn_dense_layer_v2(
                        q_heads, head_dim, False, 0., 'q_heads_map', 'linear', False, wd, keep_prob, is_training)
                else:
                    raise AttributeError
                align_scores = tf.matmul(q_heads, k_heads, transpose_b=True)

                log_specific_params()


        elif method == 'multi_dim_head':
            assert hn % head_num == 0
            head_dim = hn // head_num

            q_heads = bn_dense_layer_v2(
                tensor_from, head_dim, True, 0., 'q_heads',
                'linear', False, wd, keep_prob, is_training, dup_num=head_num
            )
            k_heads = bn_dense_layer_v2(
                tensor_to, head_dim, True, 0., 'k_heads',
                'linear', False, wd, keep_prob, is_training, dup_num=head_num
            )
            q_heads = split_head(q_heads, head_num)  # bs,hd_num,slf,hd_dim
            k_heads = split_head(k_heads, head_num)  # bs,hd_num,slt,hd_dim

            # MLP
            q_heads_branch = bn_dense_layer_multi_head(
                q_heads, head_dim, False, 0., 'q_heads_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            k_heads_branch = bn_dense_layer_multi_head(
                k_heads, head_dim, True, 0., 'k_heads_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,head,slf,slt,dim]
                tf.expand_dims(q_heads_branch, -2),  # [bs,head,slf,1,dim]
                tf.expand_dims(k_heads_branch, -3)  # bs,head,1,slt,dim
            ))
            align_scores_heads = bn_dense_layer_multi_head(  # [bs,hd_num,slf,slt,hd_dim]
                align_scores_pre, head_dim, True, 0., 'align_scores_heads', 'linear', False,
                wd, keep_prob, is_training
            )
            align_scores = align_scores_heads  # [bs,hd_num,slf,slt,hd_dim]
        elif method == 'bilinear':
            raise NotImplementedError
        else:
            raise AttributeError

        if need_extra_dim:
            align_scores = tf.squeeze(align_scores, [seq_dim_to_remove])  #

        return align_scores
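For the plain 'multi_head' branch, the alignment score is just scaled dot-product attention computed per head. A standalone NumPy sketch of that branch after split_head (hypothetical sizes):

import numpy as np

bs, head_num, slf, slt, head_dim = 2, 4, 5, 7, 8         # hypothetical sizes
q_heads = np.random.randn(bs, head_num, slf, head_dim)   # result of split_head on the 'from' side
k_heads = np.random.randn(bs, head_num, slt, head_dim)   # result of split_head on the 'to' side

align_scores = q_heads @ k_heads.transpose(0, 1, 3, 2)   # tf.matmul(..., transpose_b=True): [bs,head_num,slf,slt]
align_scores = align_scores / np.sqrt(1. * head_dim)     # scale by sqrt(head_dim)
print(align_scores.shape)                                # (2, 4, 5, 7)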