Example #1
def logits_for_sketch_prediction(decoder_states,
                                 cls_state,
                                 num_channel,
                                 hn=None,
                                 act_name="relu",
                                 wd=0.,
                                 keep_prob=1.0,
                                 is_train=None,
                                 compress_mask=None,
                                 scope=None):
    compressing = not isinstance(compress_mask, type(None))
    hn = hn or get_shape_list(decoder_states)[-1]
    with tf.variable_scope(scope or "logits_for_sketch_index"):
        if compressing:
            new_decoder_states, _, rev_d = compress_seq_wrt_mask(
                decoder_states, compress_mask)
        else:
            new_decoder_states = decoder_states
            rev_d = None
        map_part1 = bn_dense_layer_v2(new_decoder_states, hn, True, 0.,
                                      "map_part1", "linear", False, wd,
                                      keep_prob, is_train)
        map_part2_pre = bn_dense_layer_v2(cls_state, hn, False, 0.,
                                          "map_part2_pre", "linear", False, wd,
                                          keep_prob, is_train)
        map_part2 = tf.tile(tf.expand_dims(map_part2_pre, axis=1),
                            [1, get_shape_list(map_part1)[1], 1])
        map_res = act_name2fn(act_name)(map_part1 + map_part2)

        logits = bn_dense_layer_v2(map_res, num_channel, True, 0., "logits",
                                   "linear", False, wd, keep_prob, is_train)
        if compressing:
            logits = decompress_seq_wrt_mask(logits, rev_d)
        return logits
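A minimal usage sketch for the function above, assuming TF 1.x graph mode (import tensorflow as tf) and that this function and the module helpers it calls (get_shape_list, bn_dense_layer_v2, act_name2fn, compress_seq_wrt_mask, decompress_seq_wrt_mask) are in scope; the shapes, num_channel, and the is_train placeholder are illustrative assumptions, not values from the source.

# hypothetical usage sketch; shapes and hyper-parameters are illustrative
is_train = tf.placeholder(tf.bool, [])
decoder_states = tf.placeholder(tf.float32, [None, 48, 768])  # bs,dsl,hn
cls_state = tf.placeholder(tf.float32, [None, 768])           # bs,hn

sketch_logits = logits_for_sketch_prediction(
    decoder_states, cls_state,
    num_channel=32,        # assumed size of the sketch vocabulary
    act_name="relu",
    keep_prob=0.9,
    is_train=is_train,
    scope="sketch_prediction")
# sketch_logits: bs,dsl,num_channel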
Example #2
def logits_for_sketch_index(
    decoder_states,
    encoder_states,
    hn=None,
    wd=0.,
    keep_prob=1.0,
    is_train=None,
    compress_mask=None,
    scope=None,
):
    compressing = not isinstance(compress_mask, type(None))
    hn = hn or get_shape_list(decoder_states)[-1]
    with tf.variable_scope(scope or "logits_for_sketch_index"):
        if compressing:
            new_decoder_states, _, rev_d = compress_seq_wrt_mask(
                decoder_states, compress_mask)
        else:
            new_decoder_states = decoder_states
            rev_d = None
        with tf.variable_scope("projection"):
            encoder_states_map = bn_dense_layer_v2(encoder_states, hn, True,
                                                   0., "encoder_states_map",
                                                   "linear", False, wd,
                                                   keep_prob, is_train)
            decoder_states_map = bn_dense_layer_v2(new_decoder_states, hn,
                                                   True, 0.,
                                                   "decoder_states_map",
                                                   "linear", False, wd,
                                                   keep_prob, is_train)
        with tf.variable_scope("bi_linear"):
            bilinear_pre = bn_dense_layer_v2(decoder_states_map, hn, False, 0.,
                                             "bilinear_map", "linear", False,
                                             wd, keep_prob, is_train)
            logits = tf.matmul(bilinear_pre,
                               encoder_states_map,
                               transpose_b=True)  # bs,dsl,esl

            if compressing:
                logits = decompress_seq_wrt_mask(logits, rev_d)

            return logits
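A minimal usage sketch under the same assumptions as above (TF 1.x, module helpers in scope); shapes are illustrative. The function scores every decoder step against every encoder position through a bilinear map, so the result can serve as pointer logits over the input sequence.

# hypothetical usage sketch; shapes are illustrative
is_train = tf.placeholder(tf.bool, [])
encoder_states = tf.placeholder(tf.float32, [None, 64, 768])  # bs,esl,hn
decoder_states = tf.placeholder(tf.float32, [None, 48, 768])  # bs,dsl,hn

pointer_logits = logits_for_sketch_index(
    decoder_states, encoder_states,
    keep_prob=0.9, is_train=is_train)
# pointer_logits: bs,dsl,esl -- one score per encoder position for each decoder step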
Example #3
def transformer_seq_decoder(dec_input_emb_mat,
                            decoder_ids,
                            encoder_states,
                            decoder_mask,
                            encoder_mask,
                            n_out_channel,
                            num_layers,
                            decoder_history_inputs=None,
                            hn=768,
                            head_num=12,
                            act_name="gelu",
                            wd=0.,
                            is_training=None,
                            keep_prob_dense=1.,
                            keep_prob_attn=1.,
                            keep_prob_res=1.,
                            scope=None):
    with tf.variable_scope(scope or "transformer_seq_decoder"):
        with tf.variable_scope("decoder_emb"):
            decoder_inputs = tf.nn.embedding_lookup(dec_input_emb_mat,
                                                    decoder_ids)  # bs,sl,hn

        with tf.variable_scope("decoder_recurrence"):
            dec_outputs, new_decoder_history_inputs = transformer_decoder(  # bs,sl,hn
                decoder_inputs,
                encoder_states,
                decoder_mask,
                encoder_mask,
                num_layers,
                decoder_history_inputs,
                hn,
                head_num,
                act_name,
                wd,
                is_training,
                keep_prob_dense,
                keep_prob_attn,
                keep_prob_res,
                scope="transformer_decoder")
            # prediction logits: two layer
            # pre_logits_seq2seq = bn_dense_layer_v2(
            #     dec_outputs, hn, True, 0., "pre_logits_seq2seq", act_name,
            #     False, 0., keep_prob_dense, is_training
            # )
            logits_seq2seq = bn_dense_layer_v2(  # bs,sl,n_out_channel
                dec_outputs, n_out_channel, True, 0., "logits_seq2seq",
                "linear", False, 0., keep_prob_dense, is_training)
            return dec_outputs, logits_seq2seq, new_decoder_history_inputs
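A minimal usage sketch, assuming TF 1.x and that transformer_decoder and bn_dense_layer_v2 are in scope; the vocabulary size, sequence lengths, layer count, and dropout values are illustrative assumptions.

# hypothetical usage sketch; vocabulary size and lengths are illustrative
is_train = tf.placeholder(tf.bool, [])
dec_input_emb_mat = tf.get_variable("dec_emb", [30000, 768], tf.float32)
decoder_ids = tf.placeholder(tf.int32, [None, 48])            # bs,dsl
encoder_states = tf.placeholder(tf.float32, [None, 64, 768])  # bs,esl,hn
decoder_mask = tf.placeholder(tf.bool, [None, 48])            # bs,dsl
encoder_mask = tf.placeholder(tf.bool, [None, 64])            # bs,esl

dec_outputs, logits, new_history = transformer_seq_decoder(
    dec_input_emb_mat, decoder_ids, encoder_states, decoder_mask, encoder_mask,
    n_out_channel=30000, num_layers=6,
    is_training=is_train, keep_prob_dense=0.9, keep_prob_attn=0.9)
# logits: bs,dsl,n_out_channel
# new_history is the updated decoder_history_inputs for incremental decoding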
Example #4
    def _build_network_seq_label_logits(self, encoder_states):
        wp_features = get_word_level_split(  # bs,sl,hn -> bs,asl,pl,hn
            encoder_states, self.input_pos_ids, self.wordpiece_idx, self.input_mask, self.asl, self.pl
        )

        all_token_features = s2t_self_attn(  # bs,asl,hn
            wp_features, self.wordpiece_mask, self.cfg['clf_act_name'], 'multi_dim',
            0., 1.-self.cfg['clf_dropout'], self.is_training, 'all_token_features',
        )
        # get seq_label_token_features  asl -> sll (asl-1)
        seq_label_token_features = mask_v3(  # remove the latest feature
            all_token_features[:, :-1], self.seq_label_mask, high_dim=True
        )

        with tf.variable_scope("output"):
            with tf.variable_scope("seq_labeling"):
                seq_label_logits = bn_dense_layer_v2(  # one "O" class plus (EO, type) combinations, excluding the pad/empty labels
                    seq_label_token_features, 1 + (self.num_EO_labels-2) * (self.num_type_labels-2),
                    True, 0., "seq_labeling_logits", "linear", False,
                    0., 1. - self.cfg['clf_dropout'], self.is_training
                )
        return seq_label_logits
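The width of the sequence-labeling logits above is 1 + (num_EO_labels - 2) * (num_type_labels - 2): one "O" class plus every (EO, type) label combination after dropping what appear to be two reserved labels (pad and empty, per the original comment) from each set. A small numeric illustration, with label counts assumed purely for the example:

# assumed label counts, purely illustrative
num_EO_labels = 4     # e.g. pad, empty, B, I
num_type_labels = 5   # e.g. pad, empty, plus three predicate types
num_classes = 1 + (num_EO_labels - 2) * (num_type_labels - 2)
print(num_classes)    # 1 + 2 * 3 = 7 logits per token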
Example #5
def attn_post_proc(attn_res, inter_hn=None, wd=0., keep_prob=1., residual_keep_prob=1.,
                   is_train=None, activation='relu', sparse_opt=False,
                   scope=None, **kwargs):
    with tf.variable_scope(scope or "attn_res"):
        assert "mask" in kwargs
        if sparse_opt:
            x1, reverse_spec = masked_dense2sparse(attn_res, kwargs.get("mask"))
        else:
            x1 = attn_res

        y = bn_dense_layer_v2(
            x1, get_shape_list(attn_res)[-1], True, 0., "dense_layer", "linear", False,
            wd, keep_prob, is_train
        )
        x2 = residual_connection(x1, y, is_train, residual_keep_prob, "res_con")

        res = residual_connection_with_dense(
            x2, inter_hn or 4*get_shape_list(attn_res)[-1], True, 0., "residual_connection_with_dense",
            activation, False, wd, keep_prob, is_train, residual_keep_prob
        )
        if sparse_opt:
            res = masked_sparse2dense(res, reverse_spec)
        return res
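A minimal usage sketch, assuming TF 1.x and that the module helpers (bn_dense_layer_v2, residual_connection, residual_connection_with_dense, masked_dense2sparse/masked_sparse2dense) are in scope; shapes and dropout values are illustrative. Note that the mask must be passed through **kwargs even when sparse_opt is False.

# hypothetical usage sketch; shapes are illustrative
is_train = tf.placeholder(tf.bool, [])
attn_res = tf.placeholder(tf.float32, [None, 64, 768])  # bs,sl,hn -- output of an attention layer
seq_mask = tf.placeholder(tf.bool, [None, 64])           # bs,sl

block_out = attn_post_proc(
    attn_res, keep_prob=0.9, residual_keep_prob=0.9,
    is_train=is_train, activation="gelu",
    sparse_opt=True,   # run the dense layers only on unmasked positions
    mask=seq_mask)     # required by the assert on kwargs
# block_out: bs,sl,hn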
Example #6
def compatibility_fn_lacacy(  # legacy: does not support arbitrary input rank
        tensor_from, tensor_to, method='dot_product', scope=None, **kwargs):

    def _get_val_from_kwargs(key, default_val):
        if key in kwargs:
            return kwargs[key]
        else:
            return default_val

    with tf.variable_scope(scope or 'compatibility_fn.{}'.format(method)):
        shape_from = get_shape_list(tensor_from)
        ndim_from = len(shape_from)
        shape_to = get_shape_list(tensor_to)
        ndim_to = len(shape_to)

        assert (ndim_from == 2 or ndim_from == 3) and ndim_to == 3

        if ndim_from == 2:
            tensor_from = tf.expand_dims(tensor_from, 1)
            shape_from = get_shape_list(tensor_from)

        slf, slt = shape_from[1], shape_to[1]

        # hparams parsing
        hn = _get_val_from_kwargs('hn', shape_to[-1])
        wd = _get_val_from_kwargs('wd', 0.)
        keep_prob = _get_val_from_kwargs('keep_prob', 1.)
        is_training = _get_val_from_kwargs('is_training', None)
        activation = _get_val_from_kwargs('activation', 'relu')
        head_num = _get_val_from_kwargs('head_num', 12)

        seq_dim_to_remove = 1
        if method == 'dot_product':
            align_scores = tf.matmul(tensor_from, tensor_to, transpose_b=True)  # [bs,slf,hn]*[bs,slt,hn]=>bs,slf,slt
            align_scores = tf.expand_dims(align_scores, -1)  # [bs,slf,slt,1]
        elif method == 'additive':
            tensor_from_branch = bn_dense_layer_v2(
                tensor_from, hn, False, 0., 'tensor_from_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            tensor_to_branch = bn_dense_layer_v2(
                tensor_to, hn, True, 0., 'tensor_to_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,slf,slt,hn]
                tf.expand_dims(tensor_from_branch, 2),  # [bs,slf,1,hn]
                tf.expand_dims(tensor_to_branch, 1)  # [bs,1,slt,hn]
            ))
            align_scores = bn_dense_layer_v2(  # [bs,slf,slt,1]
                align_scores_pre, 1, True, 0., 'align_scores', 'linear', False,
                wd, keep_prob, is_training
            )
        elif method == 'multi_dim':
            logging.warning("No simplified multi-dim technique used in this function!")
            tensor_from_branch = bn_dense_layer_v2(
                tensor_from, hn, False, 0., 'tensor_from_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            tensor_to_branch = bn_dense_layer_v2(
                tensor_to, hn, True, 0., 'tensor_to_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,slf,slt,hn]
                tf.expand_dims(tensor_from_branch, 2),  # [bs,slf,1,hn]
                tf.expand_dims(tensor_to_branch, 1)  # bs,1,slt,hn
            ))
            align_scores = bn_dense_layer_v2(
                align_scores_pre, hn, True, 0., 'align_score', 'linear', False,
                wd, keep_prob, is_training
            )
        elif method == 'multi_head':
            seq_dim_to_remove = 2  # !!! because multi-head dim is on 2nd dim
            assert hn % head_num == 0
            head_dim = hn // head_num

            q_heads = bn_dense_layer_v2(
                tensor_from, head_dim, True, 0., 'q_heads',
                'linear', False, wd, keep_prob, is_training, dup_num=head_num
            )
            k_heads = bn_dense_layer_v2(
                tensor_to, head_dim, True, 0., 'k_heads',
                'linear', False, wd, keep_prob, is_training, dup_num=head_num
            )
            q_heads = split_head(q_heads, head_num)  # bs,hd_num,slf,hd_dim
            k_heads = split_head(k_heads, head_num)  # bs,hd_num,slt,hd_dim

            # alignment score
            align_scores = tf.matmul(q_heads, k_heads, transpose_b=True)  # [bs,hd_num,slf,slt]
            align_scores = align_scores / math.sqrt(1.*head_dim)  # [bs,hd_num,slf,slt]
        elif method == 'multi_dim_head':
            seq_dim_to_remove = 2  # !!! because multi-head dim is on 2nd dim
            assert hn % head_num == 0
            head_dim = hn // head_num

            q_heads = bn_dense_layer_v2(
                tensor_from, head_dim, True, 0., 'q_heads',
                'linear', False, wd, keep_prob, is_training, dup_num=head_num
            )
            k_heads = bn_dense_layer_v2(
                tensor_to, head_dim, True, 0., 'k_heads',
                'linear', False, wd, keep_prob, is_training, dup_num=head_num
            )
            q_heads = split_head(q_heads, head_num)  # bs,hd_num,slf,hd_dim
            k_heads = split_head(k_heads, head_num)  # bs,hd_num,slt,hd_dim

            # MLP
            q_heads_branch = bn_dense_layer_multi_head(
                q_heads, head_dim, False, 0., 'q_heads_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            k_heads_branch = bn_dense_layer_multi_head(
                k_heads, head_dim, True, 0., 'k_heads_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,head,slf,slt,dim]
                tf.expand_dims(q_heads_branch, 3),  # [bs,head,slf,1,dim]
                tf.expand_dims(k_heads_branch, 2)  # bs,head,1,slt,dim
            ))
            align_scores_heads = bn_dense_layer_multi_head(  # [bs,hd_num,slf,slt,hd_dim]
                align_scores_pre, head_dim, True, 0., 'align_scores_heads', 'linear', False,
                wd, keep_prob, is_training
            )
            align_scores = align_scores_heads  # [bs,hd_num,slf,slt,hd_dim]
            # align_scores = combine_head(align_scores_heads)
        elif method == 'bilinear':
            raise NotImplementedError
        else:
            raise AttributeError

        if ndim_from == 2:
            align_scores = tf.squeeze(align_scores, [seq_dim_to_remove])  #

        return align_scores
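A minimal usage sketch for the legacy version, assuming TF 1.x and the module helpers in scope; shapes are illustrative. The legacy restriction is that tensor_from must be rank 2 or 3 and tensor_to rank 3 (the newer compatibility_fn in Example #8 lifts this by working on the trailing axes).

# hypothetical usage sketch; shapes are illustrative
is_train = tf.placeholder(tf.bool, [])
query = tf.placeholder(tf.float32, [None, 768])      # bs,hn -- rank-2 query is allowed
keys = tf.placeholder(tf.float32, [None, 64, 768])   # bs,slt,hn

scores = compatibility_fn_lacacy(query, keys, method="additive",
                                 is_training=is_train, keep_prob=0.9)
# scores: bs,slt,1 -- the singleton slf dim added for the rank-2 query is squeezed back out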
Example #7
def s2t_self_attn(
        tensor_input, tensor_mask, deep_act=None, method='multi_dim',
        wd=0., keep_prob=1., is_training=None,
        scope=None, **kwargs
):
    use_deep = isinstance(deep_act, str)  # use Two layers or Single layer for the alignment score
    with tf.variable_scope(scope or 's2t_self_attn_{}'.format(method)):
        tensor_shape = get_shape_list(tensor_input)
        hn = tensor_shape[-1]  # hidden state number

        if method == 'additive':
            align_scores = bn_dense_layer_v2(  # bs,sl,hn/1
                tensor_input, hn if use_deep else 1, True, 0., 'align_score_1', 'linear', False,
                wd, keep_prob, is_training
            )
            if use_deep:
                align_scores = bn_dense_layer_v2(  # bs,sl,1
                    act_name2fn(deep_act)(align_scores), 1, True, 0., 'align_score_2', 'linear', False,
                    wd, keep_prob, is_training
                )
        elif method == 'multi_dim':
            align_scores = bn_dense_layer_v2(  # bs,sl,hn
                tensor_input, hn, False, 0., 'align_score_1', 'linear', False,
                wd, keep_prob, is_training
            )
            if use_deep:
                align_scores = bn_dense_layer_v2(  # bs,sl,hn
                    act_name2fn(deep_act)(align_scores), hn, True, 0., 'align_score_2', 'linear', False,
                    wd, keep_prob, is_training
                )
        elif method == 'multi_dim_head':
            get_shape_list(tensor_input, expected_rank=3)  # the input should be rank-3
            assert 'head_num' in kwargs and isinstance(kwargs['head_num'], int)
            head_num = kwargs['head_num']
            assert hn % head_num == 0
            head_dim = hn // head_num

            tensor_input_heads = split_head(tensor_input, head_num)  # [bs,hd,sl,hd_dim]

            align_scores_heads = bn_dense_layer_multi_head(  # [bs,hd,sl,hd_dim]
                tensor_input_heads, head_dim, True, 0., 'align_scores_heads_1', 'linear', False,
                wd, keep_prob, is_training
            )
            if use_deep:
                align_scores_heads = bn_dense_layer_multi_head(  # [bs,hd,sl,hd_dim]
                    act_name2fn(deep_act)(align_scores_heads), head_dim,
                    True, 0., 'align_scores_heads_2', 'linear', False,
                    wd, keep_prob, is_training
                )
            align_scores = combine_head(align_scores_heads)  # [bs,sl,dim]
        else:
            raise AttributeError

        # attention procedure align_scores [bs,sl,1/dim]
        align_scores_masked = exp_mask_v3(align_scores, tensor_mask, multi_head=False, high_dim=True)  # bs,sl,hn
        attn_prob = tf.nn.softmax(align_scores_masked, axis=-2)  # bs,sl,hn

        if 'attn_keep_prob' in kwargs and isinstance(kwargs['attn_keep_prob'], float):
            attn_prob = dropout(attn_prob, kwargs['attn_keep_prob'], is_training)  # bs,sl,hn

        attn_res = tf.reduce_sum(  # [bs,sl,hn] -> [bs,dim]
            mask_v3(attn_prob*tensor_input, tensor_mask, high_dim=True), axis=-2
        )

        return attn_res  # [bs,hn]
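A minimal usage sketch of the s2t self-attention pooling above, assuming TF 1.x and the module helpers in scope; shapes are illustrative. With method 'multi_dim' the scorer produces one softmax weight per hidden dimension rather than a single scalar per token.

# hypothetical usage sketch; shapes are illustrative
is_train = tf.placeholder(tf.bool, [])
token_states = tf.placeholder(tf.float32, [None, 64, 768])  # bs,sl,hn
token_mask = tf.placeholder(tf.bool, [None, 64])             # bs,sl

pooled = s2t_self_attn(token_states, token_mask,
                       deep_act="gelu",       # passing a string activates the two-layer scorer
                       method="multi_dim",
                       keep_prob=0.9, is_training=is_train)
# pooled: bs,hn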
Example #8
def compatibility_fn(tensor_from, tensor_to, method='dot_product', scope=None, **kwargs):
    def _get_val_from_kwargs(key, default_val):
        if key in kwargs:
            return kwargs[key]
        else:
            return default_val

    with tf.variable_scope(scope or 'compatibility_fn.{}'.format(method)):
        shape_from = get_shape_list(tensor_from)
        ndim_from = len(shape_from)
        shape_to = get_shape_list(tensor_to)
        ndim_to = len(shape_to)

        assert ndim_from == ndim_to or ndim_from+1 == ndim_to
        need_extra_dim = ndim_from+1 == ndim_to

        if need_extra_dim:
            tensor_from = tf.expand_dims(tensor_from, -2)
            shape_from = get_shape_list(tensor_from)

        slf, slt = shape_from[-2], shape_to[-2]

        # hparams parsing
        hn = _get_val_from_kwargs('hn', shape_to[-1])
        wd = _get_val_from_kwargs('wd', 0.)
        keep_prob = _get_val_from_kwargs('keep_prob', 1.)
        is_training = _get_val_from_kwargs('is_training', None)
        activation = _get_val_from_kwargs('activation', 'relu')
        head_num = _get_val_from_kwargs('head_num', 12)

        seq_dim_to_remove = -3
        if method == 'dot_product':
            align_scores = tf.matmul(tensor_from, tensor_to, transpose_b=True)  # [bs,slf,hn]*[bs,slt,hn]=>bs,slf,slt
            align_scores = tf.expand_dims(align_scores, -1)  # [bs,slf,slt,1]
        elif method == 'additive':
            tensor_from_branch = bn_dense_layer_v2(
                tensor_from, hn, False, 0., 'tensor_from_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            tensor_to_branch = bn_dense_layer_v2(
                tensor_to, hn, True, 0., 'tensor_to_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,slf,slt,hn]
                tf.expand_dims(tensor_from_branch, -2),  # [bs,slf,1,hn]
                tf.expand_dims(tensor_to_branch, -3)  # [bs,1,slt,hn]
            ))
            align_scores = bn_dense_layer_v2(  # [bs,slf,slt,1]
                align_scores_pre, 1, True, 0., 'align_scores', 'linear', False,
                wd, keep_prob, is_training
            )
        elif method == 'multi_dim':
            logging.warning("No simplified multi-dim technique used in this function!")
            tensor_from_branch = bn_dense_layer_v2(
                tensor_from, hn, False, 0., 'tensor_from_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            tensor_to_branch = bn_dense_layer_v2(
                tensor_to, hn, True, 0., 'tensor_to_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,slf,slt,hn]
                tf.expand_dims(tensor_from_branch, -2),  # [bs,slf,1,hn]
                tf.expand_dims(tensor_to_branch, -3)  # bs,1,slt,hn
            ))
            align_scores = bn_dense_layer_v2(  # [bs,slf,slt,hn]
                align_scores_pre, hn, True, 0., 'align_score', 'linear', False,
                wd, keep_prob, is_training
            )
        elif method == 'multi_head':
            seq_dim_to_remove = -2  # !!! because multi-head dim is on 2nd dim
            assert hn % head_num == 0
            head_dim = hn // head_num

            q_heads = bn_dense_layer_v2(
                tensor_from, head_dim, True, 0., 'q_heads',
                'linear', False, wd, keep_prob, is_training, dup_num=head_num
            )
            k_heads = bn_dense_layer_v2(
                tensor_to, head_dim, True, 0., 'k_heads',
                'linear', False, wd, keep_prob, is_training, dup_num=head_num
            )
            q_heads = split_head(q_heads, head_num)  # bs,hd_num,slf,hd_dim
            k_heads = split_head(k_heads, head_num)  # bs,hd_num,slt,hd_dim

            # alignment score
            align_scores = tf.matmul(q_heads, k_heads, transpose_b=True)  # [bs,hd_num,slf,slt]
            align_scores = align_scores / math.sqrt(1.*head_dim)  # [bs,hd_num,slf,slt]
        elif method in ['multi_head_bilinear', 'multi_head_bilinear_shared', 'multi_head_only', 'multi_head_linear']:
            seq_dim_to_remove = -2  # !!! because multi-head dim is on 2nd dim
            assert hn % head_num == 0
            head_dim = hn // head_num

            q_heads = bn_dense_layer_v2(
                tensor_from, head_dim, True, 0., 'q_heads',
                kwargs.get("activation") or activation,
                False, wd, keep_prob, is_training, dup_num=head_num
            )
            k_heads = bn_dense_layer_v2(
                tensor_to, head_dim, True, 0., 'k_heads',
                kwargs.get("activation") or activation,
                False, wd, keep_prob, is_training, dup_num=head_num
            )
            q_heads = split_head(q_heads, head_num)  # bs,hd_num,slf,hd_dim
            k_heads = split_head(k_heads, head_num)  # bs,hd_num,slt,hd_dim

            # alignment score: using bilinear (or linear) maps rather than plain dot product
            # align_scores = tf.matmul(q_heads, k_heads, transpose_b=True)  # [bs,hd_num,slf,slt]
            # align_scores = align_scores / math.sqrt(1. * head_dim)  # [bs,hd_num,slf,slt]
            with tf.variable_scope("bilinear"):
                if method == "multi_head_bilinear":
                    k_heads_map = bn_dense_layer_multi_head(
                        k_heads, head_dim, False, 0., 'k_heads_map', 'linear', False, wd, keep_prob, is_training)
                    align_scores = tf.matmul(q_heads, k_heads_map, transpose_b=True)
                elif method == "multi_head_bilinear_shared":
                    k_heads_map = bn_dense_layer_v2(
                        k_heads, head_dim, False, 0., 'k_heads_map', 'linear', False, wd, keep_prob, is_training)
                    align_scores = tf.matmul(q_heads, k_heads_map, transpose_b=True)
                elif method == "multi_head_only":
                    align_scores = tf.matmul(q_heads, k_heads, transpose_b=True)
                elif method == "multi_head_linear":
                    k_heads_map = bn_dense_layer_v2(
                        k_heads, head_dim, False, 0., 'k_heads_map', 'linear', False, wd, keep_prob, is_training)
                    q_heads_map = bn_dense_layer_v2(
                        q_heads, head_dim, False, 0., 'q_heads_map', 'linear', False, wd, keep_prob, is_training)
                    align_scores = tf.matmul(q_heads_map, k_heads_map, transpose_b=True)
                else:
                    raise AttributeError

                log_specific_params()

        elif method == 'multi_dim_head':
            assert hn % head_num == 0
            head_dim = hn // head_num

            q_heads = bn_dense_layer_v2(
                tensor_from, head_dim, True, 0., 'q_heads',
                'linear', False, wd, keep_prob, is_training, dup_num=head_num
            )
            k_heads = bn_dense_layer_v2(
                tensor_to, head_dim, True, 0., 'k_heads',
                'linear', False, wd, keep_prob, is_training, dup_num=head_num
            )
            q_heads = split_head(q_heads, head_num)  # bs,hd_num,slf,hd_dim
            k_heads = split_head(k_heads, head_num)  # bs,hd_num,slt,hd_dim

            # MLP
            q_heads_branch = bn_dense_layer_multi_head(
                q_heads, head_dim, False, 0., 'q_heads_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            k_heads_branch = bn_dense_layer_multi_head(
                k_heads, head_dim, True, 0., 'k_heads_branch', 'linear', False,
                wd, keep_prob, is_training
            )
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,head,slf,slt,dim]
                tf.expand_dims(q_heads_branch, -2),  # [bs,head,slf,1,dim]
                tf.expand_dims(k_heads_branch, -3)  # bs,head,1,slt,dim
            ))
            align_scores_heads = bn_dense_layer_multi_head(  # [bs,hd_num,slf,slt,hd_dim]
                align_scores_pre, head_dim, True, 0., 'align_scores_heads', 'linear', False,
                wd, keep_prob, is_training
            )
            align_scores = align_scores_heads  # [bs,hd_num,slf,slt,hd_dim]
        elif method == 'bilinear':
            raise NotImplementedError
        else:
            raise AttributeError

        if need_extra_dim:
            align_scores = tf.squeeze(align_scores, [seq_dim_to_remove])  #

        return align_scores
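A minimal usage sketch, assuming TF 1.x and the module helpers in scope; shapes and head count are illustrative. With method 'multi_head' the result is one (slf x slt) score map per head, already scaled by 1/sqrt(head_dim).

# hypothetical usage sketch; shapes are illustrative
is_train = tf.placeholder(tf.bool, [])
queries = tf.placeholder(tf.float32, [None, 48, 768])  # bs,slf,hn
keys = tf.placeholder(tf.float32, [None, 64, 768])     # bs,slt,hn

head_scores = compatibility_fn(queries, keys, method="multi_head",
                               hn=768, head_num=12,
                               is_training=is_train, keep_prob=0.9)
# head_scores: bs,12,48,64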
Example #9
def multihead_attention_decoder(
    tensor_from,
    tensor_to,
    mask_to,
    mask_direction=None,  # [bs,slf,slt]
    act_name="relu",
    hn=768,
    head_num=12,
    wd=0.,
    is_training=None,
    keep_prob_dense=1.,
    keep_prob_attn=1.,
    tensor_to_prev=None,
    mask_prev_to=None,
    scope=None,
):
    head_dim = hn // head_num
    with tf.variable_scope(scope or "multihead_attention_decoder"):
        # if not isinstance(tensor_to_prev, type(None)):  # to print the shape
        #     tensor_from = tf.Print(tensor_from, [
        #         tf.shape(tensor_from), tf.shape(tensor_to),  tf.shape(mask_to),  tf.shape(tensor_to_prev)])

        if isinstance(tensor_to_prev, type(None)):
            tensor_to_all = tensor_to  # bs,sl,hn
            mask_to_all = mask_to  # bs,sl
        else:
            tensor_to_all = tf.concat([tensor_to_prev, tensor_to],
                                      -2)  # bs,psl+1,hn
            if mask_prev_to is None:
                mask_prev_to = tf.cast(
                    tf.ones(get_shape_list(tensor_to_prev, 3)[:2], tf.int32),
                    tf.bool)  # bs,psl
            mask_to_all = tf.concat([mask_prev_to, mask_to], -1)  # bs,psl+1

        attn_scores = compatibility_fn(
            tensor_from,
            tensor_to_all,
            method="multi_head",
            head_num=head_num,
            hn=hn,
            wd=wd,
            is_training=is_training,
            keep_prob=keep_prob_dense,
        )  # [bs,hd_num,slf,slt]
        v_heads = bn_dense_layer_v2(  # bs,slt,hd_dim * hd_num
            tensor_to_all,
            head_dim,
            True,
            0.,
            'v_heads',
            'linear',
            False,
            wd,
            keep_prob_dense,
            is_training,
            dup_num=head_num)
        v_heads = split_head(v_heads, head_num)  # bs,hd_num,slt,hd_dim

        # mask the self-attention scores
        attn_scores_mask = tf.expand_dims(mask_to_all, 1)  # bs,1,tsl
        if (not isinstance(mask_direction, type(None))) and isinstance(
                tensor_to_prev, type(None)):
            attn_scores_mask = tf.logical_and(attn_scores_mask,
                                              mask_direction)  # bs,tsl,tsl
        attn_scores_masked = exp_mask_v3(
            attn_scores, attn_scores_mask,
            multi_head=True)  # [bs,hd_num,slf,slt]
        attn_prob = tf.nn.softmax(attn_scores_masked)
        attn_prob = dropout(attn_prob, keep_prob_attn,
                            is_training)  # [bs,hd_num,slf,slt]

        v_heads_etd = tf.expand_dims(v_heads, 2)  # bs,hd_num,1,slt,hd_dim
        attn_prob_etd = tf.expand_dims(attn_prob, -1)  # bs,hd_num,slf,slt,1

        attn_res = tf.reduce_sum(v_heads_etd * attn_prob_etd,
                                 3)  # bs,hd_num,slf,hd_dim
        out_prev = combine_head(attn_res)  # bs,fsl,hn

        # if mask_direction is not None and tensor_to_prev is None:
        #     attn_scores = exp_mask_v3(attn_scores, mask_direction, multi_head=True)  # [bs,hd_num,slf,slt]
        # attn_scores = dropout(attn_scores, keep_prob_attn, is_training)
        #
        # attn_res = softsel( # [bs,hd_num,slf,dhn]
        #     v_heads, attn_scores, mask_to_all,
        #     mask_add_head_dim_for_scores=True,
        #     input_add_multi_head_dim=False,
        #     score_add_hn_dim=True,
        #     axis=3)
        # out_prev = combine_head(attn_res)
        # dense layer
        out = bn_dense_layer_v2(out_prev, hn, True, 0., "output_transformer",
                                act_name, False, wd, keep_prob_dense,
                                is_training)
        return out
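A minimal usage sketch for masked decoder self-attention with the function above, assuming TF 1.x and the module helpers in scope; the sequence length, head count, and the causal-mask construction are illustrative assumptions (mask entries are True where attention is allowed, matching the logical_and in the function).

# hypothetical usage sketch: decoder self-attention with a causal direction mask
is_train = tf.placeholder(tf.bool, [])
dec_states = tf.placeholder(tf.float32, [None, 48, 768])  # bs,sl,hn
dec_mask = tf.placeholder(tf.bool, [None, 48])             # bs,sl

# lower-triangular mask: position i may attend to positions j <= i
causal = tf.cast(tf.linalg.band_part(tf.ones([48, 48]), -1, 0), tf.bool)              # sl,sl
direction_mask = tf.tile(tf.expand_dims(causal, 0), [tf.shape(dec_states)[0], 1, 1])  # bs,sl,sl

self_attn_out = multihead_attention_decoder(
    dec_states, dec_states, dec_mask,
    mask_direction=direction_mask,
    hn=768, head_num=12,
    is_training=is_train, keep_prob_dense=0.9, keep_prob_attn=0.9)
# self_attn_out: bs,sl,hn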