Exemplo n.º 1
0
def logits_for_sketch_prediction(decoder_states,
                                 cls_state,
                                 num_channel,
                                 hn=None,
                                 act_name="relu",
                                 wd=0.,
                                 keep_prob=1.0,
                                 is_train=None,
                                 compress_mask=None,
                                 scope=None):
    """Per-step classification logits conditioned on a single [CLS] state.

    Each decoder step is projected, summed with a (tiled) projection of
    `cls_state`, passed through an activation and mapped to `num_channel`
    logits.  With `compress_mask`, masked decoder steps are compressed away
    first and the logits are decompressed back so they align with the
    original sequence.

    :param decoder_states: [bs, dsl, d] decoder outputs
    :param cls_state: [bs, d] condition state broadcast over time steps
    :param num_channel: output logit dimension
    :param hn: hidden size of the intermediate map (defaults to input d)
    :param act_name: activation name applied to the fused projection
    :param wd: weight decay
    :param keep_prob: dropout keep probability
    :param is_train: training-phase flag
    :param compress_mask: optional [bs, dsl] bool mask of relevant steps
    :param scope: variable scope name
    :return: [bs, dsl, num_channel] logits
    """
    compressing = compress_mask is not None
    hn = hn or get_shape_list(decoder_states)[-1]
    # NOTE(review): default scope name says "index" although this is the
    # *prediction* head -- looks like a copy/paste; kept byte-identical for
    # checkpoint/variable-name compatibility. Confirm before renaming.
    with tf.variable_scope(scope or "logits_for_sketch_index"):
        if compressing:
            new_decoder_states, _, rev_d = compress_seq_wrt_mask(
                decoder_states, compress_mask)
        else:
            rev_d = None
            new_decoder_states = decoder_states
        # per-step projection of the decoder states
        step_part = bn_dense_layer_v2(new_decoder_states, hn, True, 0.,
                                      "map_part1", "linear", False, wd,
                                      keep_prob, is_train)
        # projection of the single cls state, tiled across all time steps
        cls_part = bn_dense_layer_v2(cls_state, hn, False, 0.,
                                     "map_part2_pre", "linear", False, wd,
                                     keep_prob, is_train)
        n_steps = get_shape_list(step_part)[1]
        cls_part_tiled = tf.tile(tf.expand_dims(cls_part, axis=1),
                                 [1, n_steps, 1])
        fused = act_name2fn(act_name)(step_part + cls_part_tiled)

        logits = bn_dense_layer_v2(fused, num_channel, True, 0., "logits",
                                   "linear", False, wd, keep_prob, is_train)
        if compressing:
            logits = decompress_seq_wrt_mask(logits, rev_d)
        return logits
Exemplo n.º 2
0
def get_word_level_split(params, input_pos_ids, wordpiece_idx, input_mask, sll,
                         pl):
    """Scatter wordpiece-level `params` into a word-level [bs, sll, pl] layout.

    Each wordpiece at (batch b, position s) is routed to coordinate
    [b, input_pos_ids[b, s], wordpiece_idx[b, s]] of the output; masked
    positions are redirected to an extra border row/column that is sliced
    off at the end.

    :param params: [bs, sl] (+ optional trailing dims) values; may be bool
    :param input_pos_ids: [bs, sl] word index of each wordpiece
    :param wordpiece_idx: [bs, sl] offset of the wordpiece within its word
    :param input_mask: [bs, sl] bool mask of valid wordpieces
    :param sll: word-level sequence length of the output
    :param pl: max number of wordpieces per word
    :return: tensor of shape [bs, sll, pl] (+ trailing dims of `params`)
    """
    # bs,sl,pl
    bs, sl = get_shape_list(input_pos_ids)

    # support params with feature dims beyond [bs, sl]
    higher_dim = len(get_shape_list(params)) > 2
    extra_dims = get_shape_list(params)[2:] if higher_dim else []

    # tf.tile(tf.expand_dims(tf.expand_dims(tf.range(bs), 1), 2), [1, sll, pl])
    bs_idxs = tf.tile(tf.expand_dims(tf.range(bs), 1), [1, sl])

    data_coord = tf.stack([bs_idxs, input_pos_ids, wordpiece_idx],
                          -1)  # [bs, sl, 3]
    # mask input_pos_ids and wordpiece_idx for -1
    mask_reversed_int = tf.cast(tf.logical_not(input_mask), tf.int32)
    # masked positions get coordinate [bs, sll, pl]: the throw-away border
    # cell of the enlarged scatter target below
    data_coord = mask_v3(data_coord, input_mask, high_dim=True) + tf.stack(
        [
            mask_reversed_int * bs,
            mask_reversed_int * sll,
            mask_reversed_int * pl,
        ],
        axis=-1)

    # params's dtype check
    # scatter_nd has no bool kernel, so round-trip bool through int32
    is_bool = (params.dtype == tf.bool)
    outputs = tf.scatter_nd(
        indices=data_coord,  # [bs, sl, 3]
        updates=params
        if not is_bool else tf.cast(params, tf.int32),  # [bs,sl]
        shape=[bs + 1, sll + 1, pl + 1] + extra_dims)
    if is_bool:
        outputs = tf.cast(outputs, tf.bool)

    # drop the border cells used as a dumping ground for masked positions
    outputs = outputs[:-1, :-1, :-1]
    return outputs
Exemplo n.º 3
0
def cross_attn_mask_generation(from_mask, to_mask, mutual=True, head_num=None, name=None):
    """Build a cross-attention mask from two 1-D sequence masks.

    :param from_mask: 2-D Tensor, [bs, slf]
    :param to_mask: 2-D Tensor, [bs, slt]
    :param mutual: if True, position (i, j) is valid only when both masks
        are True; otherwise only `to_mask` is broadcast over `slf`
    :param head_num: when an int, insert a head axis and tile the mask
    :param name: name scope
    :return: [bs, slf, slt] bool tensor ([bs, head_num, slf, slt] with heads)
    """
    with tf.name_scope(name or 'attention_mask_generation'):
        from_shape = get_shape_list(from_mask, 2)
        bs, slf = from_shape[0], from_shape[1]
        slt = get_shape_list(to_mask, 2)[1]

        if mutual:
            from_int = tf.cast(from_mask, tf.int32)
            to_int = tf.cast(to_mask, tf.int32)
            # outer product of the two masks -> [bs,slf,slt]
            res_mask = tf.cast(
                tf.expand_dims(from_int, 2) * tf.expand_dims(to_int, 1),
                tf.bool)
        else:
            # broadcast the target mask over every source position
            res_mask = tf.tile(tf.expand_dims(to_mask, 1), [1, slf, 1])  # [bs,slt] -> [bs,slf,slt]

        if isinstance(head_num, int):
            # insert a head axis and replicate the mask per head
            res_mask = tf.expand_dims(res_mask, 1)
            multiples = [1] * len(get_shape_list(res_mask))
            multiples[1] = head_num
            res_mask = tf.tile(res_mask, multiples)

        return res_mask
Exemplo n.º 4
0
    def _build_network_all_sketch_logits(
            self, decoder_states, encoder_states_for_decoder,
            encoder_mask, cls_state_predicate, cls_state_type,
            use_mask=False
    ):
        """Compute logits for all four sketch heads.

        Entity/num heads are pointers over encoder positions
        (`logits_for_sketch_index`); predicate/type heads are classifiers
        conditioned on a [CLS]-style state (`logits_for_sketch_prediction`).
        With `use_mask`, decoder steps irrelevant to a head are compressed
        away inside the head, using masks derived from the `self.sketch_*`
        tensors (presumably gold sketch labels -- confirm against callers).

        :param decoder_states: [bs, dsl, hn] decoder outputs
        :param encoder_states_for_decoder: encoder states attended by pointers
        :param encoder_mask: [bs, esl] bool mask over encoder positions
        :param cls_state_predicate: condition state for the predicate head
        :param cls_state_type: condition state for the type head
        :param use_mask: whether to compress decoder steps per head
        :return: (entity, predicate, type, num) logits tensors
        """
        bs = get_shape_list(decoder_states)[0]
        if use_mask:
            # per-head relevance masks; -1 / 0 appear to be the "unused" labels
            entity_mask = tf.not_equal(self.sketch_entity, -1)
            predicate_mask = tf.not_equal(self.sketch_predicate, 0)
            type_mask = tf.not_equal(self.sketch_type, 0)
            num_mask = tf.not_equal(self.sketch_num, -1)
        else:
            entity_mask = None
            predicate_mask = None
            type_mask = None
            num_mask = None
        # bs,sl -----modify the last token to False
        encoder_wo_cls = tf.concat([  # [bs,sl]
            encoder_mask[:, 1:],  # [bs,sl-1]
            tf.cast(tf.zeros([get_shape_list(encoder_mask)[0], 1], tf.int32), tf.bool)  # [bs, 1]
        ], -1)

        logits_sketch_entity_pre = logits_for_sketch_index(  # bs,dsl,esl
            decoder_states, encoder_states_for_decoder, self.cfg["hn"], 0., 1 - self.cfg["clf_dropout"],
            self.is_training, compress_mask=entity_mask, scope="logits_sketch_entity_pre"
        )
        logits_sketch_entity = mask_v3(
            logits_sketch_entity_pre, encoder_wo_cls, multi_head=True, name="logits_sketch_entity")

        logits_sketch_predicate_pre = logits_for_sketch_prediction(
            decoder_states, cls_state_predicate, self.num_predicate_labels - 3,  self.cfg["hn"],
            self.cfg["clf_act_name"], 0., 1 - self.cfg["clf_dropout"], self.is_training,
            compress_mask=predicate_mask,
            scope="logits_sketch_predicate"
        )
        # prepend 3 blocked channels (very negative logits) so the first 3
        # label ids can never be predicted by this head
        logits_sketch_predicate = tf.concat([
            tf.ones([bs, get_shape_list(logits_sketch_predicate_pre)[1], 3], tf.float32) * VERY_NEGATIVE_NUMBER,
            logits_sketch_predicate_pre,
        ], axis=-1)

        logits_sketch_type_pre = logits_for_sketch_prediction(
            decoder_states, cls_state_type, self.num_type_labels - 3, self.cfg["hn"],
            self.cfg["clf_act_name"], 0., 1 - self.cfg["clf_dropout"], self.is_training,
            compress_mask=type_mask,
            scope="logits_sketch_type"
        )
        # same 3-channel blocking for the type head
        logits_sketch_type = tf.concat([
            tf.ones([bs, get_shape_list(logits_sketch_type_pre)[1], 3], tf.float32) * VERY_NEGATIVE_NUMBER,
            logits_sketch_type_pre,
        ], axis=-1)

        logits_sketch_num_pre = logits_for_sketch_index(
            decoder_states, encoder_states_for_decoder, self.cfg["hn"], 0., 1 - self.cfg["clf_dropout"],
            self.is_training,  compress_mask=num_mask, scope="logits_sketch_num_pre"
        )
        logits_sketch_num = mask_v3(
            logits_sketch_num_pre, encoder_wo_cls, multi_head=True, name="logits_sketch_num")

        return logits_sketch_entity, logits_sketch_predicate, logits_sketch_type, logits_sketch_num
Exemplo n.º 5
0
Arquivo: nn.py Projeto: mukundhs/Code
def masked_sparse2dense(input_tensor, reverse_spec, name=None):
    """Scatter a sparse (compressed) tensor back to its dense layout.

    Inverse of `masked_dense2sparse`: each sparse row is written to the
    coordinate recorded in `reverse_spec['org_coords']`; everything else
    stays zero.

    :param input_tensor: [n_valid, hn] compressed rows
    :param reverse_spec: dict with 'org_input_mask' and 'org_coords'
    :return: dense tensor shaped like the original mask plus [hn]
    """
    mask = reverse_spec['org_input_mask']
    coords = reverse_spec['org_coords']

    with tf.variable_scope(name or "masked_sparse2dense"):
        hidden = get_shape_list(input_tensor)[-1]
        dense_shape = get_shape_list(mask) + [hidden]
        return tf.scatter_nd(coords, input_tensor, dense_shape)  # [xx,hn]
Exemplo n.º 6
0
def get_key_indices(tensor_input, special_token_list):
    """Find the first occurrence of each special token in every row.

    :param tensor_input: [bs, sl] int tensor of token ids
    :param special_token_list: iterable of token ids to locate
    :return: list of [bs] int32 tensors, one per special token, holding the
        column of its first match per row (0 if the token never occurs,
        since argmax of an all-zero row is 0)
    """
    get_shape_list(tensor_input, 2)  # rank check only
    return [
        tf.cast(
            tf.argmax(tf.cast(tf.equal(tensor_input, token), tf.int32), 1),
            tf.int32)
        for token in special_token_list
    ]
Exemplo n.º 7
0
def decompress_2nd_dim_from_batch(input_tensor,
                                  reverse_spec,
                                  name=None):  # [nbs, 1,...] -> [bs,2d,...]
    """Scatter batch-compressed rows back into their 2nd-dim slots.

    Inverse of `compress_2nd_dim_to_batch`: drops the dummy size-1 axis and
    scatters each row to the (batch, slot) coordinate recorded in
    `reverse_spec["org_coords"]`.

    :param input_tensor: [nbs, 1, ...] compressed tensor
    :param reverse_spec: dict with "org_coords" and "org_input_mask"
    :return: [bs, sd, ...] tensor, zero at previously-masked slots
    """
    with tf.name_scope(name or "decompress_2nd_dim_from_batch"):
        squeezed = tf.squeeze(input_tensor, 1)  # drop the dummy 2nd dim
        trailing_shape = get_shape_list(squeezed)[1:]
        coords = reverse_spec["org_coords"]
        mask = reverse_spec["org_input_mask"]
        bs, sd = get_shape_list(mask)
        # scatter_nd wants the shape dtype to match the (int64) coords
        out_shape = [tf.to_int64(dim) for dim in [bs, sd] + trailing_shape]
        return tf.scatter_nd(coords, squeezed, out_shape)
Exemplo n.º 8
0
def decompress_seq_wrt_mask(tensor_input, reverse_dict):
    """Scatter a compressed sequence back to its original (source) length.

    Inverse of `compress_seq_wrt_mask`: values are written to the
    coordinates recorded in `reverse_dict["coord"]`; the scatter target has
    one extra column to absorb coordinates produced for padded steps, and
    that column is sliced away before the final masking.

    :param tensor_input: [bs, tgt_len, hn] compressed sequence
    :param reverse_dict: dict produced by `compress_seq_wrt_mask`
    :return: [bs, src_len, hn] tensor, zeroed at masked positions
    """
    bs, _, hn = get_shape_list(tensor_input)
    src_mask = reverse_dict["src_mask"]
    src_len = get_shape_list(src_mask)[1]

    scattered = tf.scatter_nd(reverse_dict["coord"], tensor_input,
                              [bs, src_len + 1, hn])
    trimmed = scattered[:, :-1]  # bs,src_len,hn

    return mask_v3(trimmed, src_mask, high_dim=True)
Exemplo n.º 9
0
Arquivo: nn.py Projeto: mukundhs/Code
def smoothed_softmax_cross_entropy_with_logits(**kwargs):
    """Softmax cross-entropy with optional label smoothing.

    Keyword args:
        logits: [..., vocab] logits (required)
        labels: [...] int class ids (required)
        label_smoothing: smoothing mass spread over the off classes; 0/None
            falls back to plain sparse cross-entropy
        normalize: if truthy, subtract the minimal achievable cross-entropy
            so a perfect prediction yields ~0 loss

    :raises ValueError: if logits or labels are missing
    :return: per-example cross-entropy tensor
    """
    logits = kwargs.get("logits")
    labels = kwargs.get("labels")
    label_smoothing = kwargs.get("label_smoothing") or 0.0
    normalize = kwargs.get("normalize")

    if logits is None or labels is None:
        raise ValueError("Both logits and labels must be provided")

    if not label_smoothing:
        # no smoothing requested: plain sparse cross-entropy
        return tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels)

    # adaptive for any rank
    vocab_size = get_shape_list(logits)[-1]

    num_off = tf.to_float(vocab_size - 1)
    on_value = 1.0 - label_smoothing
    off_value = label_smoothing / num_off

    soft_targets = tf.one_hot(tf.cast(labels, tf.int32),
                              depth=vocab_size,
                              on_value=on_value,
                              off_value=off_value)

    xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits, labels=soft_targets)

    if not normalize:
        return xentropy

    # entropy of the smoothed target itself = lower bound of the loss
    normalizing = -(on_value * tf.log(on_value) +
                    num_off * off_value * tf.log(off_value + 1e-20))
    return xentropy - normalizing
Exemplo n.º 10
0
def get_slice(tensor_input, start_idxs, end_idxs):
    """Per-row slice of the open interval (start_idxs[b], end_idxs[b]).

    For each batch row b, the elements at positions
    start_idxs[b]+1 .. end_idxs[b]-1 are gathered and left-aligned into an
    output padded to the batch-max slice length, alongside a validity mask.

    :param tensor_input: [bs, sl] (+ optional trailing dims)
    :param start_idxs: [bs] int tensor (exclusive left bound)
    :param end_idxs: [bs] int tensor (exclusive right bound)
    :return: (tensor_output [bs, max_len] (+ trailing dims),
              mask_output [bs, max_len] bool)
    """
    # 1. the size of 1st dim of tensor_input, start_idxs, end_idxs must be equal
    # 2. the idxs is the 2nd dim of tensor input
    # 3. output: 1. a output tensor, 2. a mask
    tensor_shape = get_shape_list(tensor_input)
    bs = tensor_shape[0]
    sl = tensor_shape[1]
    extra_dims = tensor_shape[2:] if len(tensor_shape) > 2 else []
    # number of kept elements per row (both bounds exclusive)
    lens = end_idxs - start_idxs - 1
    max_len = tf.reduce_max(lens)

    # target bool indicator
    indices_input = tf.tile(tf.expand_dims(tf.range(sl, dtype=tf.int32), 0), [bs, 1])  # bs, sl
    # position each kept element will occupy in the output (left-aligned)
    indices_new = indices_input - tf.expand_dims(start_idxs, 1) - 1  # bs, sl
    # True exactly at the input positions strictly inside (start, end)
    tgt_bool_indicator = tf.logical_and(
        tf.greater(indices_input, tf.expand_dims(start_idxs, 1)),
        tf.less(indices_input, tf.expand_dims(end_idxs, 1)),
    )

    coord_in_input = tf.where(tgt_bool_indicator)  # [n_true, 2]
    two_d_indices_new = tf.stack(  # bs,sl,2
        values=[
            tf.tile(tf.expand_dims(tf.range(bs, dtype=tf.int32), 1), [1, sl]),
            indices_new,
        ], axis=-1
    )

    # map each selected input coord to its (batch, new position) output coord
    coord_in_output = tf.gather_nd(two_d_indices_new, coord_in_input)  # [n_true, 2]
    gathered_tensor_input = tf.gather_nd(tensor_input, coord_in_input)  # [n_true]+extra_dims

    tensor_output = tf.scatter_nd(coord_in_output, gathered_tensor_input, [bs, max_len] + extra_dims)
    mask_output = generate_mask_based_on_lens(lens, max_len)
    return tensor_output, mask_output
Exemplo n.º 11
0
def top_k_to_coordinate(top_k_vec,
                        prob_tensor=None,
                        logits=None,
                        dim=None,
                        name=None):
    """Select, per row, the positions whose score is in that row's top k.

    `top_k_vec` gives each row its own k.  Scores are sorted descending and
    the k-th value (0-indexed k in the sorted order) becomes a threshold;
    positions strictly greater than it -- exactly the top k -- are kept and
    converted to sparse coordinates.

    :param top_k_vec: [bs] int32 tensor of per-row k values
    :param prob_tensor: [bs, sll] scores; if None, derived as
        softmax(logits, axis=-1)[..., dim]
    :param logits: logits used to derive `prob_tensor` when it is None
    :param dim: softmax channel used to derive `prob_tensor` when it is None
    :param name: name scope
    :raises ValueError: if `prob_tensor` is None and `logits`/`dim` are missing
    :return: (coords, coord_mask) from `mask_matrix_to_coordinate`
    """
    if prob_tensor is None:
        # BUG FIX: was `isinstance(prob_tensor, type(None))`; also validate
        # the derived-probability inputs explicitly instead of failing with
        # an opaque error inside tf.nn.softmax/indexing.
        if logits is None or dim is None:
            raise ValueError(
                "either prob_tensor or both logits and dim must be given")
        prob_tensor = tf.nn.softmax(logits, axis=-1)[..., dim]

    with tf.name_scope(name or "top_k_to_coordinate"):
        bs, sll = get_shape_list(prob_tensor, expected_rank=2)
        sorted_tensor = tf.contrib.framework.sort(
            prob_tensor, axis=-1, direction='DESCENDING')  # bs,sll

        # pad with -1 so k == sll safely indexes a sentinel below any score
        padded_sorted_tensor = tf.concat(
            [sorted_tensor, -tf.ones([bs, 1], sorted_tensor.dtype)],
            axis=-1)  # [bs,sll+1]
        k_th_scores_indices = tf.stack(  # [bs,2]
            [
                tf.range(bs, dtype=tf.int32),
                top_k_vec,
            ], axis=-1)
        k_th_scores = tf.expand_dims(tf.gather_nd(padded_sorted_tensor,
                                                  k_th_scores_indices),
                                     axis=-1)  # [bs,1]

        # strictly greater than the k-th sorted score -> the top-k positions
        mask_mat = tf.greater(prob_tensor, k_th_scores)  # [bs,sll]
        return mask_matrix_to_coordinate(mask_mat)
Exemplo n.º 12
0
def mask_generation(rep_mask, head_num, use_direction, attn_self, name=None):  # this mask is for self-attention
    """Build a [bs, head_num, sl, sl] self-attention mask.

    Combines the padding mask with per-head positional masks.  With
    `use_direction`, the first half of the heads use a forward (lower
    triangular) mask and the second half a backward one; otherwise every
    head shares a non-directional mask, which must exclude self-attention.

    :param rep_mask: [bs, sl] bool padding mask
    :param head_num: number of heads (even when `use_direction`, so the
        heads split evenly into forward/backward groups)
    :param use_direction: whether to build directional fw/bw masks
    :param attn_self: whether a token may attend to itself
    :raises ValueError: if `use_direction` is False and `attn_self` is True
    :return: [bs, head_num, sl, sl] bool mask
    """
    with tf.name_scope(name or 'mask_generation'):
        rep_shape = get_shape_list(rep_mask, 2)
        bs, sl = rep_shape
        # regular mask: both positions of a pair must be unpadded
        rep_mask_epd1 = tf.expand_dims(rep_mask, 1)  # bs,1,sl
        rep_mask_epd2 = tf.expand_dims(rep_mask, 2)  # bs,sl,1
        rep_mask_mat = tf.logical_and(rep_mask_epd1, rep_mask_epd2)  # bs,sl,sl

        # position mask
        sl_indices = tf.range(sl, dtype=tf.int32)
        sl_col, sl_row = tf.meshgrid(sl_indices, sl_indices)

        if use_direction:
            comp_func = tf.greater_equal if attn_self else tf.greater
            fw_mask = comp_func(sl_row, sl_col)  # sl,sl
            bw_mask = comp_func(sl_col, sl_row)  # sl,sl
            direct_mask = tf.stack([fw_mask, bw_mask], 0)  # 2,sl,sl
            direct_mask = tf.reshape(  # num,sl,sl
                tf.tile(tf.expand_dims(direct_mask, 1),
                        [1, head_num // 2, 1, 1]),  # 2,num/2,sl,sl
                [head_num, sl, sl])
        else:
            if not attn_self:
                direct_mask = tf.tile(tf.expand_dims(tf.not_equal(sl_row, sl_col), 0), [head_num, 1, 1])  # n,sl,sl
            else:
                # BUG FIX: was `raise(ValueError, "...")`, which in Python 3
                # raises TypeError ("exceptions must derive from
                # BaseException"), not the intended ValueError.
                raise ValueError(
                    "A attention overself must be avoided without fw/bw information")

        final_mask = tf.logical_and(  # bs,num,sl,sl
            tf.expand_dims(rep_mask_mat, 1),
            tf.expand_dims(direct_mask, 0))
        return final_mask
Exemplo n.º 13
0
def generate_label_mask(input_pos_ids, input_mask, sll):
    """Build a [bs, sll] word-level validity mask from wordpiece position ids.

    Word slot j of a row is valid iff j < the row's word count, where the
    word count is max(pos_id + 1) over the row's unmasked wordpieces.

    :param input_pos_ids: [bs, sl] word index per wordpiece
    :param input_mask: [bs, sl] bool mask of valid wordpieces
    :param sll: word-level sequence length
    :return: [bs, sll] bool mask
    """
    shifted_pos = mask_v3(input_pos_ids + 1, input_mask)
    bs, _ = get_shape_list(shifted_pos)

    word_counts = tf.reduce_max(shifted_pos, axis=-1, keepdims=True)  # [bs,1]
    slot_idxs = tf.tile(tf.expand_dims(tf.range(sll, dtype=tf.int32), 0),
                        [bs, 1])  # bs,sll
    return tf.less(slot_idxs, word_counts)
Exemplo n.º 14
0
def number_to_index(num_vec,
                    name=None):  # [3, 2, 0, 2, 1] -> [0, 0, 0, 1, 1, 3, 3, 4]
    """Repeat each position i of `num_vec` exactly num_vec[i] times.

    E.g. [3, 2, 0, 2, 1] -> [0, 0, 0, 1, 1, 3, 3, 4].

    :param num_vec: 1-D int tensor of repeat counts
    :return: 1-D int32 tensor of length sum(num_vec)
    """
    with tf.name_scope(name or "number_to_index"):
        n = get_shape_list(num_vec)[0]
        widest = tf.reduce_max(num_vec)  # []

        # row i of idx_mat is filled with the value i
        idx_mat = tf.tile(
            tf.expand_dims(tf.range(n, dtype=tf.int32), -1),
            [1, widest])  # [len,num]
        # keep only the first num_vec[i] cells of row i
        keep = generate_mask_based_on_lens(num_vec, widest)  # [len,num]
        return tf.gather_nd(idx_mat, tf.where(keep))  # [new]
Exemplo n.º 15
0
def extend_batch_for_2nd_dim_compression(input_tensor,
                                         reverse_spec=None,
                                         num_vec=None,
                                         name=None):  # [bs,...] -> [nbs,...]
    """Replicate each batch row to match a 2nd-dim-compressed companion.

    Row i of `input_tensor` is repeated once per valid 2nd-dim slot of
    example i, so the result aligns with the [nbs, ...] layout produced by
    `compress_2nd_dim_to_batch`.  The valid-slot layout comes either from a
    `reverse_spec` (as returned by the compressor) or from a per-row count
    vector `num_vec`.

    :param input_tensor: [bs, ...] tensor to replicate
    :param reverse_spec: optional dict with "org_input_mask"/"org_coords"
    :param num_vec: optional [bs] int tensor of valid-slot counts
    :raises ValueError: if neither `reverse_spec` nor `num_vec` is given
    :return: [nbs, ...] tensor
    """
    with tf.name_scope(name or "extend_batch_for_2nd_dim_compression"):
        # BUG FIX: previously, passing neither argument crashed deep inside
        # tf.reduce_max(None); fail fast with a clear message instead.
        if reverse_spec is None and num_vec is None:
            raise ValueError(
                "one of reverse_spec or num_vec must be provided")

        if reverse_spec is not None:
            org_input_mask = reverse_spec["org_input_mask"]
            org_coords = reverse_spec["org_coords"]
            vec_len = get_shape_list(org_input_mask)[0]
        else:
            org_input_mask = generate_mask_based_on_lens(
                num_vec, tf.reduce_max(num_vec))
            org_coords = tf.where(org_input_mask)
            vec_len = get_shape_list(num_vec)[0]
        max_num = get_shape_list(org_input_mask)[-1]

        # row i of idx_mat holds the value i; gathering it at the valid
        # coords yields the source batch index for every compressed row
        idx_mat = tf.tile(
            tf.expand_dims(tf.range(vec_len, dtype=tf.int32), -1),
            [1, max_num])
        num_indices = tf.expand_dims(tf.gather_nd(idx_mat, org_coords),
                                     -1)  # [nbs,1]
        return tf.gather_nd(input_tensor, num_indices)  # [nbs,...]
Exemplo n.º 16
0
Arquivo: nn.py Projeto: mukundhs/Code
def combine_head(inp_tensor, name=None):
    """Merge the head axis back into the hidden axis.

    Inverse of `split_head`:
    [bs, hd_num, sl, hd_dim] -> [bs, sl, hd_num * hd_dim].
    The head count and per-head dim must be statically known.

    :param inp_tensor: tensor with head axis at position 1 and per-head
        dim last
    :return: tensor with the two axes fused into the final hidden axis
    """
    with tf.name_scope(name or 'combine_head'):
        shape = get_shape_list(inp_tensor)  # e.g. [bs,hd_num,sl,hd_dim]

        # hidden size must be computable at graph-construction time
        assert isinstance(shape[1], int) and isinstance(shape[-1], int)
        hidden = shape[1] * shape[-1]

        # permute the head axis to sit just before the per-head dim:
        # [0,1,2,3] -> [0,2,1,3]
        perm = list(range(len(shape)))
        head_axis = perm.pop(1)
        perm.insert(-1, head_axis)
        transposed = tf.transpose(inp_tensor, perm)  # [bs,sl,hd_num,hd_dim]

        # fuse the trailing two axes into one hidden axis
        out_shape = get_shape_list(transposed)[:-2] + [hidden]  # [bs,sl,hn]
        return tf.reshape(transposed, out_shape)  # [bs,sl,hn]
Exemplo n.º 17
0
def compress_2nd_dim_to_batch(input_tensor,
                              num_vec,
                              name=None):  # [bs,sd,...] -> [nbs,...]
    """Fold the valid slots of the 2nd dim into the batch dim.

    Slots are considered valid according to the length-based mask built
    from `num_vec`.  A dummy size-1 axis is inserted so the result keeps a
    rank compatible with its consumers.

    :param input_tensor: [bs, sd, ...]
    :param num_vec: [bs] int tensor of valid-slot counts per row
    :return: ([nbs, 1, ...] gathered tensor,
              reverse spec for `decompress_2nd_dim_from_batch`)
    """
    with tf.name_scope(name or "compress_2nd_dim_to_batch"):
        sd = get_shape_list(input_tensor)[1]
        valid_mask = generate_mask_based_on_lens(num_vec, sd)  # [bs,sd]
        valid_coords = tf.where(valid_mask)  # [nbs,2]

        gathered = tf.gather_nd(input_tensor, valid_coords)  # [nbs,...]
        gathered = tf.expand_dims(gathered, 1)
        return gathered, {
            "org_coords": valid_coords,
            "org_input_mask": valid_mask,
        }
Exemplo n.º 18
0
def attn_post_proc(attn_res, inter_hn=None, wd=0., keep_prob=1., residual_keep_prob=1.,
                   is_train=None, activation='relu', sparse_opt=False,
                   scope=None, **kwargs):
    """Transformer-style post-processing of an attention output.

    Applies a linear layer with a residual connection, then a dense
    residual FFN block.  With `sparse_opt`, computation runs on the
    compressed (unpadded) view of the sequence and is scattered back to
    dense at the end.

    :param attn_res: attention output, [..., hn]
    :param inter_hn: FFN inner size (defaults to 4 * hn)
    :param wd: weight decay
    :param keep_prob: dropout keep probability for the dense layers
    :param residual_keep_prob: keep probability on the residual branches
    :param is_train: training-phase flag
    :param activation: FFN activation name
    :param sparse_opt: compress padded positions away before the dense layers
    :param kwargs: must contain `mask` when `sparse_opt` is True
    :raises ValueError: if `sparse_opt` is True and no `mask` kwarg is given
    :return: post-processed tensor shaped like `attn_res`
    """
    with tf.variable_scope(scope or "attn_res"):
        if sparse_opt:
            # BUG FIX: was an unconditional `assert "mask" in kwargs`, which
            # is stripped under `python -O` and demanded a mask even when
            # sparse_opt was off (where the mask is never used).
            if "mask" not in kwargs:
                raise ValueError("sparse_opt=True requires a 'mask' kwarg")
            x1, reverse_spec = masked_dense2sparse(attn_res, kwargs["mask"])
        else:
            x1 = attn_res

        out_hn = get_shape_list(attn_res)[-1]
        y = bn_dense_layer_v2(x1, out_hn, True, 0., "dense_layer", "linear",
                              False, wd, keep_prob, is_train)
        x2 = residual_connection(x1, y, is_train, residual_keep_prob, "res_con")

        res = residual_connection_with_dense(
            x2, inter_hn or 4 * out_hn, True, 0.,
            "residual_connection_with_dense", activation, False, wd,
            keep_prob, is_train, residual_keep_prob)
        if sparse_opt:
            res = masked_sparse2dense(res, reverse_spec)
        return res
Exemplo n.º 19
0
def compress_seq_wrt_mask(tensor_input, tensor_mask):
    """Left-align the unmasked steps of a sequence (compress out padding).

    Returns the compressed sequence whose length is the batch-max number of
    valid steps, the corresponding new mask, and a reverse dict consumed by
    `decompress_seq_wrt_mask`.

    :param tensor_input: [bs, sl, hn]
    :param tensor_mask: [bs, sl] bool mask of valid steps
    :return: (masked_gathered_data [bs, max_len, hn],
              new_mask [bs, max_len],
              reverse_dict with keys "src_mask" / "tgt_mask" / "coord")
    """

    bs, sl, hn = get_shape_list(tensor_input)

    seq_lens = tf.reduce_sum(tf.cast(tensor_mask, tf.int32), -1)  # sl
    max_len = tf.reduce_max(seq_lens)  # []
    new_mask = generate_mask_based_on_lens(seq_lens, max_len)

    # ======> to ensure every batch get same elem via padding
    # extra True cells are appended so every row has exactly max_len True
    # entries, which makes tf.where's output rectangular per row
    pad_lens = max_len - seq_lens
    max_pad_len = tf.reduce_max(pad_lens)
    pad_mask = generate_mask_based_on_lens(pad_lens, max_pad_len)

    padded_tensor_mask = tf.concat([tensor_mask, pad_mask],
                                   axis=-1)  # bs,sl+max_pad_len
    # new coord
    bs_idxs = generate_seq_idxs(bs, sl + max_pad_len,
                                transpose=True)  # bs,sl+max_pad_len
    # positions in the appended pad area map to sequence index -1
    sl_idxs = tf.concat(  # bs,sl+max_pad_len
        [
            generate_seq_idxs(bs, sl, transpose=False),  # bs,sl
            -tf.ones([bs, max_pad_len], tf.int32)  # bs, max_pad_len
        ],
        axis=-1)
    data_coord_map = tf.stack([bs_idxs, sl_idxs],
                              axis=-1)  # bs,sl+max_pad_len,2

    padded_coord = tf.where(padded_tensor_mask)  # bs*max_len,2

    # translate "position in padded mask" -> "position in original input"
    mapped_padded_coord_rsp = tf.gather_nd(data_coord_map,
                                           padded_coord)  # bs*max_len,2
    mapped_padded_coord = tf.reshape(mapped_padded_coord_rsp,
                                     [bs, max_len, 2])  # bs,max_len,2

    gathered_data = tf.gather_nd(tensor_input,
                                 mapped_padded_coord)  # bs,max_len,hn
    masked_gathered_data = mask_v3(gathered_data, new_mask, high_dim=True)

    reverse_dict = {
        "src_mask": tensor_mask,
        "tgt_mask": new_mask,
        "coord": mapped_padded_coord,  # bs,max_len,2
    }

    return masked_gathered_data, new_mask, reverse_dict
Exemplo n.º 20
0
Arquivo: nn.py Projeto: mukundhs/Code
def split_head(inp_tensor, head_num, name=None):
    """Split the hidden axis into attention heads.

    [bs, sl, hn] -> [bs, head_num, sl, hn // head_num]; `hn` must be a
    static multiple of `head_num`.

    :param inp_tensor: tensor whose last axis is the hidden axis
    :param head_num: number of heads to split into
    :return: tensor with a head axis inserted at position 1
    """
    with tf.name_scope(name or 'split_head'):
        shape = get_shape_list(inp_tensor)  # e.g. [bs,sl,hn]

        hidden = shape[-1]
        assert hidden % head_num == 0
        per_head = hidden // head_num
        split_shape = shape[:-1] + [head_num, per_head]  # [bs,sl,hd_num,hd_dim]

        # move the head axis in front of the sequence axis:
        # [0,1,2,3] -> [0,2,1,3]
        perm = list(range(len(split_shape)))
        head_axis = perm.pop(-2)
        perm.insert(1, head_axis)

        reshaped = tf.reshape(inp_tensor, split_shape)  # [bs,sl,hd_num,hd_dim]
        return tf.transpose(reshaped, perm)  # [bs,hd_num,sl,hd_dim]
Exemplo n.º 21
0
def transform_pos_ids_to_wordpiece_idx(input_pos_ids, input_mask, sll):
    """Convert per-wordpiece word ids into within-word wordpiece offsets.

    E.g. pos ids 0 0 1 1 1 2 2 become 0 1 0 1 2 0 1: each wordpiece is
    assigned its offset inside its own word.

    :param input_pos_ids: [bs, sl] non-decreasing word index per wordpiece
    :param input_mask: [bs, sl] bool mask of valid wordpieces
    :param sll: word-level sequence length bound (assumed > number of
        words per example -- TODO confirm against callers)
    :return: [bs, sl] int tensor of within-word offsets (0 where masked)
    """
    # 0 0 1 1 1 2 2 2 2 3 3 0 0 0 0 0  # bs,sl
    #
    bs, sl = get_shape_list(input_pos_ids)
    # 1 exactly where a new word starts (pos id increases), 0 elsewhere
    diff_pos = mask_v3(  # bs,sl
        input_pos_ids -
        tf.concat([tf.zeros([bs, 1], dtype=tf.int32), input_pos_ids[:, :-1]],
                  axis=1), input_mask)

    sl_idxs = tf.tile(tf.expand_dims(tf.range(sl, dtype=tf.int32), 0),
                      [bs, 1])  # bs,sl
    # the wordpiece position at which each word starts (0 elsewhere)
    word_start_index = diff_pos * sl_idxs  # bs, sl
    # remove all 0 value
    slx_s = tf.reduce_sum(diff_pos,
                          axis=-1)  # the number of non-zero for each example
    slx = tf.reduce_max(slx_s)  #
    sly_s = slx - slx_s  # the number of non-zero for padding
    sly = tf.reduce_max(sly_s)  #
    # pad each row so tf.where below yields exactly slx hits per row
    padding_seq = tf.cast(generate_mask_based_on_lens(sly_s, sly), tf.int32)
    valid_data_mask = generate_mask_based_on_lens(slx_s, slx)  # bs, slx

    padded_word_start_index = tf.concat([word_start_index, padding_seq],
                                        axis=-1)  # bs,sl+sly

    data_coord = tf.reshape(  # bs, slx
        tf.where(tf.cast(padded_word_start_index, tf.bool)),  # bs*slx,2
        [bs, slx, 2])

    # start position of every word, laid out word-level; word 0 starts at 0
    word_start = tf.concat(  # bs, sll
        [
            tf.zeros([bs, 1], dtype=tf.int32),
            mask_v3(tf.gather_nd(padded_word_start_index, data_coord),
                    valid_data_mask),  # bs,slx
            tf.zeros([bs, sll - slx - 1], dtype=tf.int32)
        ],
        axis=1)

    # look up, per wordpiece, the start position of its word
    bs_idxs = generate_seq_idxs(bs, sl, transpose=True)  # bs,sl
    base_coord = tf.stack([bs_idxs, input_pos_ids], axis=-1)  # bs,sl,2
    base_value = tf.gather_nd(word_start, base_coord)  # bs,sl

    # finally
    # offset = absolute position - own word's start position
    outputs = mask_v3(sl_idxs - base_value, input_mask)  # bs,sl
    return outputs
Exemplo n.º 22
0
def logits_for_sketch_index(
    decoder_states,
    encoder_states,
    hn=None,
    wd=0.,
    keep_prob=1.0,
    is_train=None,
    compress_mask=None,
    scope=None,
):
    """Bilinear pointer logits from decoder steps over encoder steps.

    Both sides are linearly projected to `hn`; the decoder side passes
    through an extra bilinear map before a matmul yields [bs, dsl, esl]
    scores.  With `compress_mask`, the decoder sequence is compressed
    first and the logits are decompressed back to the full length.

    :param decoder_states: [bs, dsl, d] decoder outputs
    :param encoder_states: [bs, esl, d] encoder outputs to point into
    :param hn: projection size (defaults to input d)
    :param wd: weight decay
    :param keep_prob: dropout keep probability
    :param is_train: training-phase flag
    :param compress_mask: optional [bs, dsl] bool mask of relevant steps
    :param scope: variable scope name
    :return: [bs, dsl, esl] pointer logits
    """
    compressing = compress_mask is not None
    hn = hn or get_shape_list(decoder_states)[-1]
    with tf.variable_scope(scope or "logits_for_sketch_index"):
        if compressing:
            new_decoder_states, _, rev_d = compress_seq_wrt_mask(
                decoder_states, compress_mask)
        else:
            rev_d = None
            new_decoder_states = decoder_states
        with tf.variable_scope("projection"):
            enc_proj = bn_dense_layer_v2(
                encoder_states, hn, True, 0., "encoder_states_map",
                "linear", False, wd, keep_prob, is_train)
            dec_proj = bn_dense_layer_v2(
                new_decoder_states, hn, True, 0., "decoder_states_map",
                "linear", False, wd, keep_prob, is_train)
        with tf.variable_scope("bi_linear"):
            dec_bilinear = bn_dense_layer_v2(
                dec_proj, hn, False, 0., "bilinear_map", "linear", False,
                wd, keep_prob, is_train)
            logits = tf.matmul(dec_bilinear, enc_proj,
                               transpose_b=True)  # bs,dsl,esl

            if compressing:
                logits = decompress_seq_wrt_mask(logits, rev_d)

            return logits
Exemplo n.º 23
0
Arquivo: nn.py Projeto: mukundhs/Code
def residual_connection_with_dense(x,
                                   hn,
                                   bias,
                                   bias_start=0.0,
                                   scope=None,
                                   activation='relu',
                                   enable_bn=False,
                                   wd=0.,
                                   keep_prob=1.0,
                                   is_train=None,
                                   residual_keep_prob=1.):
    """Position-wise feed-forward block with a residual connection.

    x -> dense(hn, activation) -> dense(d_x, linear) -> residual add to x.

    :param x: input tensor, [..., d_x]
    :param hn: inner (expanded) hidden size
    :param bias: whether the dense layers use a bias
    :param bias_start: bias initializer value
    :param activation: activation of the first dense layer
    :param enable_bn: enable batch norm inside the dense layers
    :param wd: weight decay
    :param keep_prob: dropout keep probability for the dense layers
    :param is_train: training-phase flag
    :param residual_keep_prob: keep probability on the residual branch
    :return: tensor shaped like `x`
    """
    with tf.variable_scope(scope or 'residual_connection_with_dense'):
        expanded = bn_dense_layer_v2(x, hn, bias, bias_start, "dense_layer_1",
                                     activation, enable_bn, wd, keep_prob,
                                     is_train)
        out_dim = get_shape_list(x)[-1]
        projected = bn_dense_layer_v2(expanded, out_dim, bias, bias_start,
                                      "dense_layer_2", "linear", enable_bn,
                                      wd, keep_prob, is_train)
        return residual_connection(x, projected, is_train, residual_keep_prob,
                                   'residual_connection')
Exemplo n.º 24
0
    def _setup_training(self):
        """Wire up the training graph: logits, losses, train op, the fetch
        dict for session runs, and decode-time probability distributions.
        """
        self.logits_dict = self._build_network()
        self.loss, self.loss_dict = self._build_loss()
        self.prediction_dict = self._build_prediction()

        self.log_num_params()

        # to build train op
        self.train_op = optimization.create_optimizer(
            self.loss,
            self.cfg['learning_rate'],
            self.num_training_steps,
            int(self.num_training_steps * self.cfg['warmup_proportion']),
            use_tpu=False
        )

        # tensors fetched on every training step
        self.run_dict = {
            "loss": self.loss,
            "loss_seq2seq": self.loss_dict["seq2seq"],
            "loss_seq_label": self.loss_dict["seq_label"],
            "train_op": self.train_op,
        }

        # for decoder beam search
        # # 1. for distribution
        # prepend a zero-probability channel at index 0 (presumably the PAD
        # label excluded from the logits -- confirm against _build_network)
        seq2seq_dist_wo_pad = tf.nn.softmax(self.decoder_dict["logits_seq2seq_run"])  # bs,1,nl-1
        self.decoder_dict["distribution_seq2seq_run"] = tf.concat(  # bs,1,nl
            [
                tf.zeros(get_shape_list(seq2seq_dist_wo_pad)[:2] + [1]),  # bs,1,1
                seq2seq_dist_wo_pad,
            ], -1)
        # softmax distributions for each sketch head, consumed at decode time
        self.decoder_dict["distribution_sketch_entity_run"] = tf.nn.softmax(
            self.decoder_dict["logits_sketch_entity_run"])
        self.decoder_dict["distribution_sketch_predicate_run"] = tf.nn.softmax(
            self.decoder_dict["logits_sketch_predicate_run"])
        self.decoder_dict["distribution_sketch_type_run"] = tf.nn.softmax(
            self.decoder_dict["logits_sketch_type_run"])
        self.decoder_dict["distribution_sketch_num_run"] = tf.nn.softmax(
            self.decoder_dict["logits_sketch_num_run"])
Exemplo n.º 25
0
def mask_matrix_to_coordinate(mask_mat, name=None):
    """Convert a bool matrix into per-row coordinates of its True entries.

    Extra True cells are appended past the real columns so that every row
    contributes exactly max_real_len coordinates, making the tf.where
    output reshapeable to [bs, max_real_len, 2]; coordinates that came
    from the padding cells are then zeroed out via `coord_mask`.

    :param mask_mat: [bs, sll] bool tensor
    :return: (coords [bs, max_real_len, 2] int64, zeroed at padded slots,
              coord_mask [bs, max_real_len] bool)
    """
    with tf.name_scope(name or "mask_matrix_to_coordinate"):
        bs, sll = get_shape_list(mask_mat, expected_rank=2)

        # lens
        real_lens = tf.reduce_sum(tf.cast(mask_mat, tf.int32), axis=-1)  # bs
        max_real_len = tf.reduce_max(real_lens, axis=0)  # []
        pad_lens = max_real_len - real_lens
        max_pad_len = tf.reduce_max(pad_lens, axis=0)

        # mask generation
        pad_mask_mat = generate_mask_based_on_lens(pad_lens, max_pad_len)
        coord_mask = generate_mask_based_on_lens(real_lens, max_real_len)

        # coord generation
        padded_mask_mat = tf.concat([mask_mat, pad_mask_mat], axis=-1)

        flat_coords = tf.where(padded_mask_mat)  # [bs*max_real_len,2]
        coords = tf.reshape(flat_coords,
                            [bs, max_real_len, 2])  # [bs,max_real_len]
        # zero out the coordinates contributed by the padding cells
        coords = mask_v3(coords, coord_mask, high_dim=True)
        return coords, coord_mask
Exemplo n.º 26
0
def direct_mask_generation(rep_mask, direct, attn_self, name=None):
    """Build a directional (triangular) attention mask with padding applied.

    Args:
        rep_mask: bool tensor [bs, sl]; True marks real (non-pad) tokens.
        direct: "forward" (position i may attend to j <= i) or "backward"
            (position i may attend to j >= i).
        attn_self: if True, a position may attend to itself (>= / <=
            comparison instead of strict > / <).
        name: optional name-scope string.

    Returns:
        bool tensor [bs, sl, sl]; entry (b, i, j) is True iff both tokens are
        real and j lies on the allowed side of i.

    Raises:
        ValueError: if `direct` is not "forward" or "backward".
    """
    # Explicit validation instead of `assert` (asserts vanish under -O) and
    # instead of the previously unreachable `raise AttributeError` branch.
    if direct not in ("forward", "backward"):
        raise ValueError(
            "direct must be 'forward' or 'backward', got %r" % (direct,))
    with tf.name_scope(name or 'direct_mask_generation'):
        rep_shape = get_shape_list(rep_mask, 2)
        bs, sl = rep_shape

        # Padding mask: pair (i, j) is usable only if both tokens are real.
        rep_mask_epd1 = tf.expand_dims(rep_mask, 1)  # bs,1,sl
        rep_mask_epd2 = tf.expand_dims(rep_mask, 2)  # bs,sl,1
        rep_mask_mat = tf.logical_and(rep_mask_epd1, rep_mask_epd2)  # bs,sl,sl

        # Positional mask: lower- or upper-triangular depending on direction.
        sl_indices = tf.range(sl, dtype=tf.int32)
        sl_col, sl_row = tf.meshgrid(sl_indices, sl_indices)

        comp_func = tf.greater_equal if attn_self else tf.greater
        if direct == "forward":
            direct_mask = comp_func(sl_row, sl_col)  # sl,sl
        else:  # "backward" (validated above)
            direct_mask = comp_func(sl_col, sl_row)
        direct_mask = tf.tile(tf.expand_dims(direct_mask, 0), [bs, 1, 1])

        return tf.logical_and(rep_mask_mat, direct_mask)
# Exemplo n.º 27
# 0
    def __init__(self, cfg, tokenizer, data_type, labels_dict, max_sequence_len, num_training_steps, scope):
        """Set up placeholders, label tensors, masks and decoder state.

        Args:
            cfg: config mapping; reads optional architecture overrides
                ("level_for_dec", "hidden_size_input", ...) plus "hn" and
                "decoder_layer".
            tokenizer: tokenizer exposing a ``vocab`` token-to-id mapping.
            data_type: stored as-is; interpreted by callers, not here.
            labels_dict: label vocabularies containing "EOs", "types",
                "predicates" and "sketch" entries, each with a "labels" list.
            max_sequence_len: stored as-is.
            num_training_steps: stored as-is.
            scope: variable scope forwarded to the BERT template base class.
        """
        # Optional architecture overrides: use the cfg value only when it is
        # present and valid; None lets the base class fall back to its default.
        if "level_for_dec" in cfg and cfg['level_for_dec'] >= 0:
            num_hidden_layers = cfg['level_for_dec'] + 1
        else:
            num_hidden_layers = None

        if "hidden_size_input" in cfg and cfg['hidden_size_input'] > 0:
            hidden_size = cfg['hidden_size_input']
        else:
            hidden_size = None

        if "num_attention_heads_input" in cfg and cfg['num_attention_heads_input'] > 0:
            num_attention_heads = cfg['num_attention_heads_input']
        else:
            num_attention_heads = None

        if "intermediate_size_input" in cfg and cfg['intermediate_size_input'] > 0:
            intermediate_size = cfg['intermediate_size_input']
        else:
            intermediate_size = None

        if "hidden_dropout_prob_input" in cfg and cfg['hidden_dropout_prob_input'] > 0:
            hidden_dropout_prob = cfg['hidden_dropout_prob_input']
        else:
            hidden_dropout_prob = None

        if "attention_probs_dropout_prob_input" in cfg and cfg['attention_probs_dropout_prob_input'] > 0:
            attention_probs_dropout_prob = cfg['attention_probs_dropout_prob_input']
        else:
            attention_probs_dropout_prob = None

        super(ModelBertTemplate, self).__init__(
            cfg, is_paired_data=False, scope=scope, num_hidden_layers=num_hidden_layers,
            hidden_size=hidden_size, num_attention_heads=num_attention_heads, intermediate_size=intermediate_size,
            hidden_dropout_prob=hidden_dropout_prob, attention_probs_dropout_prob=attention_probs_dropout_prob,
        )

        self.data_type = data_type
        self.labels_dict = labels_dict
        self.max_sequence_len = max_sequence_len
        self.num_training_steps = num_training_steps
        self.vocab = tokenizer.vocab
        self.tokenizer = tokenizer

        # NOTE(review): presumably maps each wordpiece to its token-level
        # position index (bs, sl) — confirm against the feed_dict builder.
        self.input_pos_ids = tf.placeholder(tf.int32, [None, None])
        # Per-example loss gain (bs,), applied to the seq2seq loss.
        self.loss_gain_wrt_qt = tf.placeholder(tf.float32, [None])
        # ==== an introduction to lengths =====
        # [prev_q] [sep] [prev_a] [sep1] [cur_q] [cls]
        # sl: seq len, wordpiece-level, ([prev_q] [sep] [prev_a] [sep1] [cur_q] [cls])
        # sll: seq label len, token-level, ([prev_q] [sep] [prev_a] [sep1] [cur_q])
        # asl: all seq len, token_level, ([prev_q] [sep] [prev_a] [sep1] [cur_q] [cls])
        # # others,
        # pl, piece len, the max len of word pieces belonging to a word

        # ====== labels =====
        # 1. EO
        self.num_EO_labels = len(labels_dict["EOs"]["labels"])
        self.num_type_labels = len(labels_dict["types"]["labels"])
        self.EO_label = tf.placeholder(tf.int32, [None, None])  # [bs, sll] with [0,nel)
        self.entity_type_label = tf.placeholder(tf.int32, [None, None])  # [bs, sll] with (1,ntl)

        # 2. Sketches: include sketch itself and leaves labels: entity, predicate, type and num
        self.sos_id = labels_dict["sketch"]["labels"].index(SOS_TOKEN)
        self.eos_id = labels_dict["sketch"]["labels"].index(EOS_TOKEN)
        self.num_predicate_labels = len(labels_dict["predicates"]["labels"])
        self.num_sketch_labels = len(labels_dict["sketch"]["labels"])
        # Sketch sequence with SOS prepended: inputs are [:-1], targets [1:].
        self.sketch_label = tf.placeholder(tf.int32, [None, None])  # bs,dsl+1
        self.sketch_output_ids = self.sketch_label[:, 1:]  # bs,dsl
        # Nonzero target id == real decoding step (0 is padding).
        self.sketch_mask = tf.cast(self.sketch_output_ids, tf.bool)  # bs,dsl
        self.sketch_input_ids = self.sketch_label[:, :-1] * tf.cast(self.sketch_mask, tf.int32)  # bs,dsl
        self.sketch_entity = tf.placeholder(tf.int32, [None, None])  # bs,dsl
        self.sketch_predicate = tf.placeholder(tf.int32, [None, None])  # bs,dsl
        self.sketch_type = tf.placeholder(tf.int32, [None, None])  # bs,dsl
        self.sketch_num = tf.placeholder(tf.int32, [None, None])  # bs,dsl
        # # 2.1 masks
        # entity/num use -1 as the "no label" marker; predicate/type use 0.
        self.sketch_entity_mask = tf.not_equal(self.sketch_entity, -1)
        self.sketch_predicate_mask = tf.not_equal(self.sketch_predicate, 0)
        self.sketch_type_mask = tf.not_equal(self.sketch_type, 0)
        self.sketch_num_mask = tf.not_equal(self.sketch_num, -1)

        # lens
        self.asl = tf.reduce_max(self.input_pos_ids) + 1  # all sequence length (token-level)
        self.sll = get_shape_list(self.EO_label)[-1]  # sequence labeling length (token-level)
        self.wordpiece_idx = transform_pos_ids_to_wordpiece_idx(  # bs,asl
            self.input_pos_ids, self.input_mask, self.asl)
        self.pl = tf.reduce_max(self.wordpiece_idx) + 1

        # masks
        # Nonzero EO label marks a real (non-pad) token position.
        self.seq_label_mask = tf.cast(self.EO_label, bool)  # bs,sll
        self.wordpiece_mask = get_word_level_split(  # bs,asl,pl
            self.input_mask, self.input_pos_ids, self.wordpiece_idx, self.input_mask, self.asl, self.pl
        )

        # special token indices
        self.unk_id, self.cls_id, self.sep_id, self.empty_id, self.sep1_id, self.pad_id = convert_tokens_to_ids(
            self.vocab,
            [
                SPECIAL_TOKENS["UNK"], SPECIAL_TOKENS["CLS"], SPECIAL_TOKENS["SEP"],
                SPECIAL_TOKENS["EMPTY"], SPECIAL_TOKENS["SEP1"], SPECIAL_TOKENS["PAD"]])

        # for the decoding
        # Learned input embedding table for sketch tokens fed to the decoder.
        self.dec_input_emb_mat = tf.get_variable(
            "dec_input_emb_mat", [self.num_sketch_labels, self.cfg["hn"]],
            initializer=tf.truncated_normal_initializer(0, 0.05)
        )

        # for the key indices
        first_ids = get_word_level_split(  # bs,sl -> bs,asl,pl -> bs,asl
            self.input_ids, self.input_pos_ids, self.wordpiece_idx, self.input_mask, self.asl, self.pl
        )[..., 0]  # get the 1st id in each wordpieces
        self.sep_indices = tf.stack(
            get_key_indices(first_ids, [self.sep_id, self.sep1_id, self.cls_id]), axis=-1)

        # Slots used at beam-search time: "*_placeholder" entries are fed
        # externally, "*_run" entries are filled in by the graph builder.
        self.decoder_dict = {
            # placeholders: fed during step-by-step decoding
            "encoder_states_placeholder": tf.placeholder(tf.float32, [None, None, cfg["hn"]]),  # bs,sl,hn
            "encoder_output_for_predicate_placeholder": tf.placeholder(tf.float32, [None, cfg["hn"]]),
            "encoder_output_for_type_placeholder": tf.placeholder(tf.float32, [None, cfg["hn"]]),
            "encoder_ids_placeholder": tf.placeholder(tf.float32, [None, None]),  # bs,sl
            "decoder_history_placeholder": tf.placeholder(tf.float32, [None, cfg["decoder_layer"], None, cfg["hn"]]),
            # bs,t,hn
            "decoder_ids_placeholder": tf.placeholder(tf.int32, [None, 1]),
            "is_training_placeholder": self.is_training,
            # intermediate tensor
            "encoder_states_run": None,
            "encoder_output_for_predicate_run": None,
            "encoder_output_for_type_run": None,
            "decoder_history_run": None,
            "logits_seq2seq_run": None,
            "logits_sketch_entity_run": None,
            "logits_sketch_predicate_run": None,
            "logits_sketch_type_run": None,
            "logits_sketch_num_run": None,
        }
        self.decoder_dict["encoder_mask"] = tf.cast(self.decoder_dict["encoder_ids_placeholder"], tf.bool)

        # Populated later by the graph/loss builders.
        self.logits_dict = None
        self.loss_dict = None
        self.prediction_dict = None

        self.loss = None
        self.train_op = None

        self.run_dict = None

        self._setup_training()
# Exemplo n.º 28
# 0
    def _build_loss(self):
        """Assemble the total training loss.

        Combines (1) a sequence-labeling loss over joint EO/entity-type
        classes and (2) a seq2seq sketch loss whose per-step losses also
        absorb the leaf losses (entity/predicate/type/num), then mixes the
        two with cfg-defined weights.

        Returns:
            A pair ``(opt_loss, loss_dict)`` where ``loss_dict`` exposes the
            "seq_label" and "seq2seq" components separately.
        """
        # for seq label
        # Fold EO and entity-type labels into one joint class id; 0 is
        # reserved for empty/pad. NOTE(review): the factor 4 implies 4 valid
        # EO classes beyond the two reserved ids — confirm against labels_dict.
        joint_label = tf.where(  # 0 for empty or pad
            tf.logical_and(tf.greater_equal(self.EO_label, 2), tf.greater_equal(self.entity_type_label, 2)),
            (self.EO_label - 2) + 4 * (self.entity_type_label - 2) + 1,
            tf.zeros(get_shape_list(self.EO_label), tf.int32)
        )
        joint_label_rsp = tf.reshape(joint_label, [self.bs * self.sll])
        logits_seq_label_rsp = tf.reshape(self.logits_dict["seq_label"],
                                          [-1, get_shape_list(self.logits_dict["seq_label"])[-1]])

        losses_seq_label_rsp = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=joint_label_rsp, logits=logits_seq_label_rsp
        )
        losses_seq_label = tf.reshape(losses_seq_label_rsp, [self.bs, self.sll])
        seq_label_mask_tf = tf.cast(self.seq_label_mask, tf.float32)
        # Up-weight positive (non-empty) labels by cfg["pos_gain"], then zero
        # out padding via the sequence-label mask.
        seq_label_weights = tf.where(
            tf.greater(joint_label, 0),
            tf.ones_like(losses_seq_label) * self.cfg["pos_gain"],
            tf.ones_like(losses_seq_label)
        ) * seq_label_mask_tf
        loss_seq_label = \
            tf.reduce_sum(losses_seq_label * seq_label_weights) / tf.reduce_sum(seq_label_weights)

        # for sequence to sequence
        # # 1. sketch loss
        # Valid target ids are shifted down by 1 because the logits exclude
        # the padding class (see the concat of a zero column elsewhere).
        label_seq2seq = tf.where(
            self.sketch_mask,
            self.sketch_output_ids - 1,  # for valid token - 1
            self.sketch_output_ids
        )
        losses_sketch = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=label_seq2seq,
            logits=self.logits_dict["seq2seq"]
        )

        # # 2. leaves losses
        # # # 2.1 entity
        # Entity/num use -1 as "no label": multiply by the mask so the label
        # fed to the op is a valid non-negative id, then zero the loss.
        losses_sketch_entity = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=self.sketch_entity * tf.cast(self.sketch_entity_mask, tf.int32),
            logits=self.logits_dict["sketch_entity"]
        ) * tf.cast(self.sketch_entity_mask, tf.float32)

        # # # 2.2 predicate
        losses_sketch_predicate = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=self.sketch_predicate,
            logits=self.logits_dict["sketch_predicate"]
        ) * tf.cast(self.sketch_predicate_mask, tf.float32)
        # # # 2.3 type
        losses_sketch_type = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=self.sketch_type,
            logits=self.logits_dict["sketch_type"]
        ) * tf.cast(self.sketch_type_mask, tf.float32)
        # # # 2.4 num
        losses_sketch_num = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=self.sketch_num * tf.cast(self.sketch_num_mask, tf.int32),
            logits=self.logits_dict["sketch_num"]
        ) * tf.cast(self.sketch_num_mask, tf.float32)
        # # 3 combine leaves' losses
        losses_sketch_leaves = \
            (losses_sketch_entity + losses_sketch_predicate + losses_sketch_type + losses_sketch_num) * \
            tf.cast(self.sketch_mask, tf.float32)
        # # 4. combine to the sketch loss
        losses_seq2seq = losses_sketch + losses_sketch_leaves
        # # 5. calc final loss
        sketch_mask_ft = tf.cast(self.sketch_mask, tf.float32)  # bs,sl
        sketch_mask_int = tf.cast(self.sketch_mask, tf.int32)  # bs,sl

        # Example-level mask: True iff the example has at least one sketch step.
        sketch_ex_mask = tf.cast(tf.reduce_sum(sketch_mask_int, -1), tf.bool)  # bs
        sketch_ex_mask_ft = tf.cast(sketch_ex_mask, tf.float32)  # bs

        # Per-example mean over valid steps; 1e-6 guards division by zero
        # for examples with no sketch.
        seq_deno = tf.reduce_sum(sketch_mask_ft, -1)
        seq_deno = tf.where(
            tf.greater(seq_deno, 0.),
            seq_deno,
            tf.ones_like(seq_deno) * 1e-6,
        )
        loss_seq2seq_example = tf.reduce_sum(sketch_mask_ft * losses_seq2seq, -1) / seq_deno  # bs
        loss_seq2seq_example = loss_seq2seq_example * self.loss_gain_wrt_qt

        # Batch mean over examples that actually have a sketch, weighted by
        # the per-example gain; same 1e-6 zero-division guard.
        batch_deno = tf.reduce_sum(sketch_ex_mask_ft * self.loss_gain_wrt_qt)
        batch_deno = tf.where(
            tf.greater(batch_deno, 0.),
            batch_deno,
            tf.ones_like(batch_deno) * 1e-6,
        )
        loss_seq2seq = tf.reduce_sum(sketch_ex_mask_ft * loss_seq2seq_example) / batch_deno

        opt_loss = self.cfg["seq_label_loss_weight"]*loss_seq_label + \
                   self.cfg["seq2seq_loss_weight"] * loss_seq2seq
        return opt_loss, {
            "seq_label": loss_seq_label,
            "seq2seq": loss_seq2seq,
        }
# Exemplo n.º 29
# 0
# Arquivo: nn.py  Projeto: mukundhs/Code
def bn_dense_layer_multi_head(input_tensor,
                              hn,
                              bias,
                              bias_start=0.0,
                              scope=None,
                              activation='relu',
                              enable_bn=False,
                              wd=0.,
                              keep_prob=1.0,
                              is_train=None,
                              dup_num=1,
                              merge_var=False):
    """Per-head dense layer over the last axis of a >=3-d tensor.

    The input layout is [bs, head, ..., dim]; every head owns an independent
    weight matrix, applied with a single batched matmul after transposing the
    head axis to the front.

    Args:
        input_tensor: float tensor of rank >= 3, e.g. [bs, hd, sl, dim].
        hn: output feature size per duplicate.
        bias: whether to add a learned bias.
        bias_start: constant initializer value for the bias.
        scope: variable scope name.
        activation: activation name resolved via ``act_name2fn``.
        enable_bn: batch norm is NOT supported here; must be False.
        wd: if truthy, register ``weight`` in the 'reg_vars' collection.
        keep_prob: dropout keep probability applied to the input.
        is_train: training-phase flag for dropout.
        dup_num: number of duplicated projections concatenated on the last axis.
        merge_var: create one merged weight/bias variable per head instead of
            per-(head, dup) variables (changes checkpoint layout, not math).

    Returns:
        Activated tensor shaped like the input with the last dim replaced by
        ``hn * dup_num``.
    """
    # NOTE: this was previously a misplaced (dead) string after the assert;
    # it is now the real docstring above. Batch norm is unimplemented here.
    assert not enable_bn

    act_fn = act_name2fn(activation)

    with tf.variable_scope(scope or 'bn_dense_layer_multi_head'):
        input_tensor = dropout(input_tensor, keep_prob,
                               is_train)  # dropout [bs,hd,sl,dim]
        # the comments using 4d [bs,hd,sl,dim] for example
        input_shape = get_shape_list(input_tensor)  # [4] for [bs,hd,sl,dim]
        assert len(input_shape) >= 3
        # exchange 1st and 2nd dimension so the head axis leads
        perm_t = list(range(len(input_shape)))  # [0,1,2,3]
        perm_t[0], perm_t[1] = perm_t[1], perm_t[0]  # [1,0,2,3]
        input_tensor_t = tf.transpose(input_tensor, perm_t)  # [hd,bs,sl,dim]

        # merge all middle dims and reshape for a single batched matmul
        input_shape_t = get_shape_list(
            input_tensor_t)  # [4] for [hd,bs,sl,dim]
        dims_merge = input_shape_t[1:-1]  # [2] for [bs,sl]
        new_dim = reduce(mul, dims_merge)  # bs*sl
        new_shape = [input_shape_t[0], new_dim,
                     input_shape_t[-1]]  # [3] for [hd,bs*sl,dim]
        input_tensor_rsp = tf.reshape(input_tensor_t,
                                      new_shape)  # [hd,bs*sl,dim]

        # dense layer: one [hd_dim, hn*dup_num] matrix per head
        hd_num = new_shape[0]  # head num
        hd_dim = new_shape[-1]  # head dim

        if merge_var:
            weight = tf.get_variable('W', shape=[hd_num, hd_dim, hn * dup_num])
        else:
            weight_list = []
            for i in range(hd_num):
                sub_weight_list = []
                for j in range(dup_num):
                    sub_weight_list.append(
                        tf.get_variable('W_%d_%d' % (i, j), shape=[hd_dim,
                                                                   hn]))
                weight_list.append(
                    tf.concat(sub_weight_list, -1
                              ) if dup_num > 1 else sub_weight_list[0])
            weight = tf.stack(weight_list, 0)

        out_rsp = tf.matmul(input_tensor_rsp, weight)  # hd_num, bs*sl, hn*dup
        if bias:
            if merge_var:
                bias_val = tf.get_variable(
                    'bias',
                    shape=[hd_num, 1, hn],
                    dtype=tf.float32,
                    initializer=tf.constant_initializer(bias_start))
            else:
                bias_list = []
                for i in range(hd_num):
                    sub_bias_list = []
                    for j in range(dup_num):
                        sub_bias_list.append(
                            tf.get_variable(
                                'bias_%d_%d' % (i, j),
                                shape=[1, hn],
                                dtype=tf.float32,
                                initializer=tf.constant_initializer(
                                    bias_start)))
                    bias_list.append(
                        tf.concat(sub_bias_list, -1
                                  ) if dup_num > 1 else sub_bias_list[0])
                bias_val = tf.stack(bias_list, 0)
            out_rsp = out_rsp + bias_val  # hd_num, bs*sl, hn*dup

        # un-merge: last dim is hn * dup_num (was `hn`, which crashed the
        # reshape whenever dup_num > 1; matches bn_dense_layer_v2).
        output_shape_t = [new_shape[0]
                          ] + dims_merge + [hn * dup_num]  # [hd,bs,sl,hn*dup]
        output_t = tf.reshape(out_rsp, output_shape_t)  # [hd,bs,sl,hn*dup]

        # transpose the head axis back to position 1
        output = tf.transpose(output_t, perm_t)  # [bs,hd,sl,hn*dup]

        if wd:
            tf.add_to_collection('reg_vars', weight)

        return act_fn(output)
# Exemplo n.º 30
# 0
# Arquivo: nn.py  Projeto: mukundhs/Code
def bn_dense_layer_v2(input_tensor,
                      hn,
                      bias,
                      bias_start=0.0,
                      scope=None,
                      activation='relu',
                      enable_bn=False,
                      wd=0.,
                      keep_prob=1.0,
                      is_train=None,
                      dup_num=1,
                      merge_var=False):
    """Dense layer over the last axis of an arbitrary-rank (>=2) tensor.

    All leading dims are flattened into one, a single matmul is applied, and
    the result is reshaped back with the last dim replaced by hn * dup_num.

    Args:
        input_tensor: float tensor of rank >= 2, e.g. [bs, sl, dim].
        hn: output feature size per duplicate.
        bias: whether to add a learned bias.
        bias_start: constant initializer value for the bias.
        scope: variable scope name (defaults to 'bn_dense_layer').
        activation: activation name resolved via ``act_name2fn``.
        enable_bn: apply tf.contrib batch norm before the activation.
        wd: if truthy, register ``weight`` in the 'reg_vars' collection.
        keep_prob: dropout keep probability applied to the input.
        is_train: training-phase flag for dropout / batch norm.
        dup_num: number of duplicated projections concatenated on the last axis.
        merge_var: one merged weight variable vs. per-dup variables
            (checkpoint layout only; the math is identical).

    Returns:
        Activated tensor shaped like the input with last dim hn * dup_num.
    """
    activation_fn = act_name2fn(activation)
    with tf.variable_scope(scope or 'bn_dense_layer'):
        dropped = dropout(input_tensor, keep_prob, is_train)

        # Example shapes below use a 3-d input [bs, sl, dim].
        shape = get_shape_list(dropped)
        assert len(shape) >= 2  # need at least [bs, dim]

        # Collapse every leading dim into a single row dimension so one
        # matmul handles any rank.
        lead_dims = shape[:-1]           # e.g. [bs, sl]
        flat_rows = reduce(mul, lead_dims)
        in_dim = shape[-1]
        flat_in = tf.reshape(dropped, [flat_rows, in_dim])  # [bs*sl, dim]

        # Weight: either one merged variable or dup_num concatenated ones.
        if merge_var:
            weight = tf.get_variable('W',
                                     shape=[in_dim, hn * dup_num],
                                     dtype=tf.float32)
        else:
            weight = tf.concat(
                [tf.get_variable('W_%d' % i, shape=[in_dim, hn])
                 for i in range(dup_num)], -1)
        flat_out = tf.matmul(flat_in, weight)

        if bias:
            if merge_var or dup_num == 1:
                bias_val = tf.get_variable(
                    'bias',
                    shape=[hn * dup_num],
                    dtype=tf.float32,
                    initializer=tf.constant_initializer(bias_start))
            else:
                bias_val = tf.concat(
                    [tf.get_variable(
                        'bias_%d' % i,
                        shape=[hn],
                        dtype=tf.float32,
                        initializer=tf.constant_initializer(bias_start))
                     for i in range(dup_num)], -1)
            flat_out += bias_val

        # Restore the leading dims; last dim becomes hn * dup_num.
        output = tf.reshape(flat_out, lead_dims + [hn * dup_num])

        if enable_bn:
            output = tf.contrib.layers.batch_norm(output,
                                                  center=True,
                                                  scale=True,
                                                  is_training=is_train,
                                                  updates_collections=None,
                                                  decay=0.9,
                                                  scope='bn')

        if wd:
            tf.add_to_collection('reg_vars', weight)

        return activation_fn(output)