Example 1
0
    def _build_network_all_sketch_logits(
            self, decoder_states, encoder_states_for_decoder,
            encoder_mask, cls_state_predicate, cls_state_type,
            use_mask=False
    ):
        """Build the four sketch-slot logit tensors: entity, predicate, type, num.

        Entity and num logits are produced by ``logits_for_sketch_index`` over
        ``encoder_states_for_decoder`` and then masked with a left-shifted copy
        of ``encoder_mask``. Predicate and type logits are produced by
        ``logits_for_sketch_prediction`` over ``num_*_labels - 3`` classes.
        Three extra leading columns filled with ``VERY_NEGATIVE_NUMBER`` are
        prepended, so the first three label ids can never be selected.

        Args:
            decoder_states: decoder hidden states; dim 0 is the batch size.
            encoder_states_for_decoder: encoder states scored by the index heads.
            encoder_mask: boolean encoder mask, presumably [bs, sl] — TODO confirm.
            cls_state_predicate: state fed to the predicate head.
            cls_state_type: state fed to the type head.
            use_mask: if True, compress masks are derived from the gold
                ``self.sketch_*`` tensors. Entity/num slots equal to -1 and
                predicate/type slots equal to 0 are masked out. If False, no
                compress mask is passed.

        Returns:
            (logits_sketch_entity, logits_sketch_predicate,
             logits_sketch_type, logits_sketch_num)
        """
        batch_size = get_shape_list(decoder_states)[0]
        keep_prob = 1 - self.cfg["clf_dropout"]

        if not use_mask:
            entity_mask = predicate_mask = type_mask = num_mask = None
        else:
            entity_mask = tf.not_equal(self.sketch_entity, -1)
            predicate_mask = tf.not_equal(self.sketch_predicate, 0)
            type_mask = tf.not_equal(self.sketch_type, 0)
            num_mask = tf.not_equal(self.sketch_num, -1)

        # Shift the mask left by one and append a False column: position i takes
        # the bit of position i + 1, so each row's last valid token turns False.
        all_false_tail = tf.cast(
            tf.zeros([get_shape_list(encoder_mask)[0], 1], tf.int32), tf.bool)  # [bs, 1]
        encoder_wo_cls = tf.concat([encoder_mask[:, 1:], all_false_tail], -1)  # [bs,sl]

        def _with_blocked_reserved_labels(raw_logits):
            # Prepend 3 columns that can never win, so reserved ids 0..2 are unreachable.
            blocked = tf.ones(
                [batch_size, get_shape_list(raw_logits)[1], 3], tf.float32
            ) * VERY_NEGATIVE_NUMBER
            return tf.concat([blocked, raw_logits], axis=-1)

        entity_raw = logits_for_sketch_index(  # bs,dsl,esl
            decoder_states, encoder_states_for_decoder, self.cfg["hn"], 0., keep_prob,
            self.is_training, compress_mask=entity_mask, scope="logits_sketch_entity_pre"
        )
        logits_sketch_entity = mask_v3(
            entity_raw, encoder_wo_cls, multi_head=True, name="logits_sketch_entity")

        logits_sketch_predicate = _with_blocked_reserved_labels(logits_for_sketch_prediction(
            decoder_states, cls_state_predicate, self.num_predicate_labels - 3, self.cfg["hn"],
            self.cfg["clf_act_name"], 0., keep_prob, self.is_training,
            compress_mask=predicate_mask,
            scope="logits_sketch_predicate"
        ))

        logits_sketch_type = _with_blocked_reserved_labels(logits_for_sketch_prediction(
            decoder_states, cls_state_type, self.num_type_labels - 3, self.cfg["hn"],
            self.cfg["clf_act_name"], 0., keep_prob, self.is_training,
            compress_mask=type_mask,
            scope="logits_sketch_type"
        ))

        num_raw = logits_for_sketch_index(
            decoder_states, encoder_states_for_decoder, self.cfg["hn"], 0., keep_prob,
            self.is_training, compress_mask=num_mask, scope="logits_sketch_num_pre"
        )
        logits_sketch_num = mask_v3(
            num_raw, encoder_wo_cls, multi_head=True, name="logits_sketch_num")

        return logits_sketch_entity, logits_sketch_predicate, logits_sketch_type, logits_sketch_num
Example 2
0
def get_word_level_split(params, input_pos_ids, wordpiece_idx, input_mask, sll,
                         pl):
    """Scatter token-level ``params`` into a word-by-wordpiece grid.

    Each valid token (``input_mask`` True) at ``[b, t]`` is written to
    ``[b, input_pos_ids[b, t], wordpiece_idx[b, t]]`` of an output shaped
    ``[bs, sll, pl] + params.shape[2:]``. Masked tokens are routed to an
    extra trailing slot on every axis, and that slot is sliced off before
    returning. Boolean ``params`` are scattered as int32 and cast back to
    bool. As with ``tf.scatter_nd``, colliding writes are summed.

    Args:
        params: [bs, sl, ...] token features.
        input_pos_ids: [bs, sl] word index of every token.
        wordpiece_idx: [bs, sl] position of every token inside its word.
        input_mask: [bs, sl] boolean validity mask.
        sll: size of the word axis of the output.
        pl: size of the wordpiece axis of the output.

    Returns:
        [bs, sll, pl, ...] tensor with the dtype of ``params``.
    """
    batch_size, seq_len = get_shape_list(input_pos_ids)
    params_shape = get_shape_list(params)
    trailing_dims = params_shape[2:] if len(params_shape) > 2 else []

    row_ids = tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, seq_len])  # bs,sl
    scatter_coord = tf.stack([row_ids, input_pos_ids, wordpiece_idx], -1)  # bs,sl,3

    # Masked tokens: zero their coordinates, then add (bs, sll, pl) so they land
    # in the overflow slot that is dropped at the end.
    invalid = tf.cast(tf.logical_not(input_mask), tf.int32)
    overflow_shift = tf.stack(
        [invalid * batch_size, invalid * sll, invalid * pl], axis=-1)
    scatter_coord = mask_v3(scatter_coord, input_mask, high_dim=True) + overflow_shift

    as_bool = (params.dtype == tf.bool)
    updates = tf.cast(params, tf.int32) if as_bool else params  # scatter_nd can't sum bools
    grid = tf.scatter_nd(
        indices=scatter_coord,
        updates=updates,
        shape=[batch_size + 1, sll + 1, pl + 1] + trailing_dims)
    if as_bool:
        grid = tf.cast(grid, tf.bool)

    return grid[:-1, :-1, :-1]
Example 3
0
def transform_pos_ids_to_wordpiece_idx(input_pos_ids, input_mask, sll):
    """Compute each token's offset from the first token of its word.

    ``input_pos_ids`` gives, for every token, the index of the word it belongs
    to. For example, ids ``0 0 1 1 1 2 2 2 2 3 3`` followed by padding yield
    offsets ``0 1 0 1 2 0 1 2 3 0 1``. Masked positions are masked with
    ``mask_v3``.

    NOTE(review): the algorithm treats any position where the word id
    increases as a word start and counts word starts with ``reduce_sum`` of
    the differences. It therefore appears to assume word ids start at 0 and
    increase by exactly 1 — TODO confirm against the caller.

    Args:
        input_pos_ids: int32 [bs, sl] word index per token.
        input_mask: [bs, sl] validity mask.
        sll: length of the intermediate per-word start table. The code builds
            ``tf.zeros([bs, sll - slx - 1])``, so it requires ``sll`` to be
            greater than the largest number of word starts in a row, and
            every word id must be ``< sll``.

    Returns:
        int32 [bs, sl] wordpiece offset within each token's word.
    """
    # 0 0 1 1 1 2 2 2 2 3 3 0 0 0 0 0  # bs,sl
    #
    bs, sl = get_shape_list(input_pos_ids)
    # Difference with the previous word id; non-zero where a new word starts.
    diff_pos = mask_v3(  # bs,sl
        input_pos_ids -
        tf.concat([tf.zeros([bs, 1], dtype=tf.int32), input_pos_ids[:, :-1]],
                  axis=1), input_mask)

    sl_idxs = tf.tile(tf.expand_dims(tf.range(sl, dtype=tf.int32), 0),
                      [bs, 1])  # bs,sl
    # Token index at each word start, 0 elsewhere (word 0 at index 0 stays 0).
    word_start_index = diff_pos * sl_idxs  # bs, sl
    # Pad every row with extra non-zero entries so that all rows contain the
    # same number (slx) of non-zeros and tf.where's output can be reshaped.
    slx_s = tf.reduce_sum(diff_pos,
                          axis=-1)  # the number of non-zero for each example
    slx = tf.reduce_max(slx_s)  # max number of word starts in the batch
    sly_s = slx - slx_s  # the number of non-zero for padding
    sly = tf.reduce_max(sly_s)  # max padding length in the batch
    padding_seq = tf.cast(generate_mask_based_on_lens(sly_s, sly), tf.int32)
    valid_data_mask = generate_mask_based_on_lens(slx_s, slx)  # bs, slx

    padded_word_start_index = tf.concat([word_start_index, padding_seq],
                                        axis=-1)  # bs,sl+sly

    # Coordinates of the non-zero entries, left-packed per row.
    data_coord = tf.reshape(  # bs, slx, 2
        tf.where(tf.cast(padded_word_start_index, tf.bool)),  # bs*slx,2
        [bs, slx, 2])

    # Per-word start index table: word 0 starts at 0, then the gathered word
    # starts (padding entries masked by valid_data_mask), then zero fill to sll.
    word_start = tf.concat(  # bs, sll
        [
            tf.zeros([bs, 1], dtype=tf.int32),
            mask_v3(tf.gather_nd(padded_word_start_index, data_coord),
                    valid_data_mask),  # bs,slx
            tf.zeros([bs, sll - slx - 1], dtype=tf.int32)
        ],
        axis=1)

    # Look up the start index of each token's word.
    bs_idxs = generate_seq_idxs(bs, sl, transpose=True)  # bs,sl
    base_coord = tf.stack([bs_idxs, input_pos_ids], axis=-1)  # bs,sl,2
    base_value = tf.gather_nd(word_start, base_coord)  # bs,sl

    # finally: offset = token index - start index of its word
    outputs = mask_v3(sl_idxs - base_value, input_mask)  # bs,sl
    return outputs
Example 4
0
def generate_label_mask(input_pos_ids, input_mask, sll):
    """Build a [bs, sll] boolean mask covering each example's word slots.

    Word ids are shifted by +1 and masked outside ``input_mask`` with
    ``mask_v3``. Slot k of row b is then True iff k is less than the row's
    maximum shifted id, i.e. iff k <= the largest valid word id in the row.
    """
    shifted_ids = mask_v3(input_pos_ids + 1, input_mask)
    batch_size, _ = get_shape_list(shifted_ids)

    word_count = tf.reduce_max(shifted_ids, axis=-1, keepdims=True)  # bs,1
    slot_idxs = tf.tile(
        tf.expand_dims(tf.range(sll, dtype=tf.int32), 0), [batch_size, 1]
    )  # bs,sll
    return tf.less(slot_idxs, word_count)
Example 5
0
def decompress_seq_wrt_mask(tensor_input, reverse_dict):
    """Scatter a compressed sequence back to its source positions.

    This is the inverse of ``compress_seq_wrt_mask``.

    Args:
        tensor_input: [bs, tgt_len, hn] compressed tensor.
        reverse_dict: dict from ``compress_seq_wrt_mask`` with
            ``"src_mask"`` ([bs, src_len] mask of the original positions) and
            ``"coord"`` ([bs, tgt_len, 2] (batch, source position) pairs,
            where padding rows carry source position -1).

    Returns:
        [bs, src_len, hn] tensor, masked with ``src_mask`` via ``mask_v3``.
    """
    bs, _tgt_len, hn = get_shape_list(tensor_input)
    src_len = get_shape_list(reverse_dict["src_mask"])[1]

    coord = reverse_dict["coord"]  # bs,tgt_len,2
    # Padding rows use source position -1. tf.scatter_nd raises on negative
    # indices on CPU (and ignores them on GPU). Batch indices are never
    # negative, so redirecting every negative entry to src_len sends those rows
    # into the extra dump slot allocated below, which is then sliced off.
    dump_index = tf.zeros_like(coord) + tf.cast(src_len, coord.dtype)
    safe_coord = tf.where(tf.less(coord, 0), dump_index, coord)

    padded_tensor = tf.scatter_nd(safe_coord, tensor_input,
                                  [bs, src_len + 1, hn])
    out_tensor = padded_tensor[:, :-1]  # bs,src_len,hn

    masked_out_tensor = mask_v3(out_tensor,
                                reverse_dict["src_mask"],
                                high_dim=True)
    return masked_out_tensor
Example 6
0
def compress_seq_wrt_mask(tensor_input, tensor_mask):
    """Left-pack the positions of ``tensor_input`` selected by ``tensor_mask``.

    For each example, the positions where ``tensor_mask`` is True are gathered
    in order into the front of a [max_len, hn] slab. max_len is the largest
    number of True positions in the batch. The remaining rows are masked with
    ``mask_v3``.

    Args:
        tensor_input: [bs, sl, hn] tensor.
        tensor_mask: [bs, sl] boolean mask.

    Returns:
        (compressed [bs, max_len, hn], new_mask [bs, max_len], reverse_dict).
        ``reverse_dict`` holds ``"src_mask"`` (the input mask),
        ``"tgt_mask"`` (new_mask) and ``"coord"`` ([bs, max_len, 2]
        (batch, source position) pairs, with source position -1 for padding
        rows). It is consumed by ``decompress_seq_wrt_mask``.
    """
    bs, sl, hn = get_shape_list(tensor_input)

    seq_lens = tf.reduce_sum(tf.cast(tensor_mask, tf.int32), -1)  # bs
    max_len = tf.reduce_max(seq_lens)  # []
    new_mask = generate_mask_based_on_lens(seq_lens, max_len)  # bs,max_len

    # ======> to ensure every batch get same elem via padding: append
    # pad_lens[b] extra True entries so each row has exactly max_len Trues
    # and the tf.where result below can be reshaped to [bs, max_len, 2].
    pad_lens = max_len - seq_lens
    max_pad_len = tf.reduce_max(pad_lens)
    pad_mask = generate_mask_based_on_lens(pad_lens, max_pad_len)

    padded_tensor_mask = tf.concat([tensor_mask, pad_mask],
                                   axis=-1)  # bs,sl+max_pad_len
    # Lookup table: padded position -> (batch, source position); the appended
    # padding positions map to source position -1.
    bs_idxs = generate_seq_idxs(bs, sl + max_pad_len,
                                transpose=True)  # bs,sl+max_pad_len
    sl_idxs = tf.concat(  # bs,sl+max_pad_len
        [
            generate_seq_idxs(bs, sl, transpose=False),  # bs,sl
            -tf.ones([bs, max_pad_len], tf.int32)  # bs, max_pad_len
        ],
        axis=-1)
    data_coord_map = tf.stack([bs_idxs, sl_idxs],
                              axis=-1)  # bs,sl+max_pad_len,2

    padded_coord = tf.where(padded_tensor_mask)  # bs*max_len,2

    mapped_padded_coord_rsp = tf.gather_nd(data_coord_map,
                                           padded_coord)  # bs*max_len,2
    mapped_padded_coord = tf.reshape(mapped_padded_coord_rsp,
                                     [bs, max_len, 2])  # bs,max_len,2

    # tf.gather_nd raises on out-of-range indices such as -1 on CPU, while on
    # GPU it returns zeros. Clamp padding coords to 0 for the gather only; those
    # rows are masked by new_mask right after, and reverse_dict keeps the -1s.
    gather_coord = tf.maximum(mapped_padded_coord, 0)
    gathered_data = tf.gather_nd(tensor_input,
                                 gather_coord)  # bs,max_len,hn
    masked_gathered_data = mask_v3(gathered_data, new_mask, high_dim=True)

    reverse_dict = {
        "src_mask": tensor_mask,
        "tgt_mask": new_mask,
        "coord": mapped_padded_coord,  # bs,max_len,2
    }

    return masked_gathered_data, new_mask, reverse_dict
Example 7
0
File: nn.py Project: mukundhs/Code
def pooling_with_mask(rep_tensor, rep_mask, method='max', scope=None):
    """Pool ``rep_tensor`` over its second-to-last axis, ignoring masked positions.

    Args:
        rep_tensor: [..., sl, hn] tensor; one rank higher than ``rep_mask``.
        rep_mask: [..., sl] mask of valid positions.
        method: 'max' (masked via ``exp_mask_v3``) or 'mean' (masked sum
            divided by the number of valid positions).
        scope: optional name scope; defaults to ``'<method>_pooling'``.

    Returns:
        [..., hn] pooled tensor. For 'mean', a row with no valid position
        divides its masked sum by 1 instead of 0.

    Raises:
        AttributeError: for an unknown ``method``.
    """
    # rep_tensor have one more rank than rep_mask
    with tf.name_scope(scope or '%s_pooling' % method):
        if method == 'max':
            rep_tensor_masked = exp_mask_v3(rep_tensor,
                                            rep_mask,
                                            high_dim=True)
            output = tf.reduce_max(rep_tensor_masked, -2)
        elif method == 'mean':
            rep_tensor_masked = mask_v3(rep_tensor, rep_mask,
                                        high_dim=True)  # [...,sl,hn]
            rep_sum = tf.reduce_sum(rep_tensor_masked, -2)  # [..., hn]
            denominator = tf.reduce_sum(tf.cast(rep_mask, tf.int32), -1,
                                        True)  # [..., 1]
            # Counts are >= 0, so this only replaces 0 with 1 (no division by 0).
            denominator = tf.maximum(denominator, 1)
            # Cast to the data dtype (not a fixed float32) so float16/float64
            # inputs work too; identical for float32 inputs.
            output = rep_sum / tf.cast(denominator, rep_sum.dtype)
        else:
            raise AttributeError('No Pooling method name as %s' % method)
        return output
Example 8
0
    def _build_network_seq_label_logits(self, encoder_states):
        """Produce per-word sequence-labeling logits from wordpiece encoder states.

        The steps are:
        1. Regroup token states per word with ``get_word_level_split``.
        2. Pool each word's wordpieces with multi-dim ``s2t_self_attn``.
        3. Drop the last word slot and mask with ``self.seq_label_mask``.
        4. Map every word to ``1 + (num_EO_labels - 2) * (num_type_labels - 2)``
           logits with a linear layer.
        """
        keep_prob = 1. - self.cfg['clf_dropout']
        num_seq_labels = 1 + (self.num_EO_labels - 2) * (self.num_type_labels - 2)

        word_pieces = get_word_level_split(  # bs,sl,hn -> bs,asl,pl,hn
            encoder_states, self.input_pos_ids, self.wordpiece_idx,
            self.input_mask, self.asl, self.pl
        )
        word_features = s2t_self_attn(  # bs,asl,hn
            word_pieces, self.wordpiece_mask, self.cfg['clf_act_name'], 'multi_dim',
            0., keep_prob, self.is_training, 'all_token_features',
        )
        # asl -> sll (= asl - 1): drop the final word slot, then mask
        label_features = mask_v3(
            word_features[:, :-1], self.seq_label_mask, high_dim=True
        )

        with tf.variable_scope("output"), tf.variable_scope("seq_labeling"):
            return bn_dense_layer_v2(
                label_features, num_seq_labels,
                True, 0., "seq_labeling_logits", "linear", False,
                0., keep_prob, self.is_training
            )
Example 9
0
def mask_matrix_to_coordinate(mask_mat, name=None):
    """Convert a [bs, sll] boolean mask into left-packed coordinates of its True entries.

    Returns:
        coords: [bs, max_real_len, 2] (row, column) pairs of each row's True
            entries, in order. Entries beyond a row's count are masked with
            ``mask_v3`` using ``coord_mask``.
        coord_mask: [bs, max_real_len] mask marking the valid coordinates.
    """
    with tf.name_scope(name or "mask_matrix_to_coordinate"):
        bs, _ = get_shape_list(mask_mat, expected_rank=2)

        row_counts = tf.reduce_sum(tf.cast(mask_mat, tf.int32), axis=-1)  # bs
        max_count = tf.reduce_max(row_counts, axis=0)  # []
        fill_counts = max_count - row_counts  # bs

        # Append fill_counts[b] extra True entries so every row has max_count
        # Trues and tf.where's flat result reshapes cleanly.
        filler = generate_mask_based_on_lens(fill_counts, tf.reduce_max(fill_counts, axis=0))
        coord_mask = generate_mask_based_on_lens(row_counts, max_count)

        flat_coords = tf.where(tf.concat([mask_mat, filler], axis=-1))  # bs*max_count,2
        coords = tf.reshape(flat_coords, [bs, max_count, 2])  # bs,max_count,2
        return mask_v3(coords, coord_mask, high_dim=True), coord_mask
Example 10
0
    def _build_prediction(self):
        """Turn ``self.logits_dict`` into int32 argmax predictions.

        Sequence labeling: the joint label ``l`` is decoded into an EO tag and
        an entity-type id. For ``l >= 1``, ``EO = (l - 1) % 4 + 2`` and
        ``type = (l - 1) // 4 + 2``; ``l == 0`` maps to 1 for both. Both
        results are masked with ``self.seq_label_mask`` via ``mask_v3``.

        Semantic parsing: each sketch head is argmaxed, and positions outside
        its mask are filled with 0 (sketch, sketch_predicate, sketch_type) or
        -1 (sketch_entity, sketch_num). The sketch token prediction is also
        shifted by +1 inside ``self.sketch_mask``.

        Returns:
            dict with keys "EOs", "entity_types", "seq_label_mask", "sketch",
            "sketch_entity", "sketch_predicate", "sketch_type", "sketch_num"
            and "sep_indices".
        """
        # NOTE(review): presumably equals self.num_EO_labels - 2, which matches a
        # seq-label head of size 1 + (num_EO_labels - 2) * (num_type_labels - 2);
        # confirm before changing the label inventory.
        num_eo_tags = 4

        def masked_argmax(logits, mask, fill_value, offset=0):
            # int32 argmax over the last axis (+ offset); fill_value where mask is False
            pred = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
            return tf.where(mask, pred + offset, tf.zeros_like(pred) + fill_value)

        # # for NER sequence labeling
        seq_label = tf.cast(tf.argmax(self.logits_dict["seq_label"], axis=-1), tf.int32)
        is_labeled = tf.greater_equal(seq_label, 1)
        label_offset = seq_label - 1
        predictions_ner = mask_v3(
            tf.where(is_labeled, tf.mod(label_offset, num_eo_tags) + 2, tf.ones_like(seq_label)),
            self.seq_label_mask)
        # # for predicate
        # Integer floor division instead of `/` (float) + truncating cast; both
        # agree on the non-negative offsets that `is_labeled` selects.
        predictions_entity_type = mask_v3(
            tf.where(is_labeled, tf.floordiv(label_offset, num_eo_tags) + 2, tf.ones_like(seq_label)),
            self.seq_label_mask)

        # # for semantic parsing
        return {
            "EOs": predictions_ner,
            "entity_types": predictions_entity_type,
            "seq_label_mask": self.seq_label_mask,
            "sketch": masked_argmax(
                self.logits_dict["seq2seq"], self.sketch_mask, 0, offset=1),  # bs,sl
            "sketch_entity": masked_argmax(
                self.logits_dict["sketch_entity"], self.sketch_entity_mask, -1),
            "sketch_predicate": masked_argmax(
                self.logits_dict["sketch_predicate"], self.sketch_predicate_mask, 0),
            "sketch_type": masked_argmax(
                self.logits_dict["sketch_type"], self.sketch_type_mask, 0),
            "sketch_num": masked_argmax(
                self.logits_dict["sketch_num"], self.sketch_num_mask, -1),

            # aux
            "sep_indices": self.sep_indices,
        }
Example 11
0
def s2t_self_attn(
        tensor_input, tensor_mask, deep_act=None, method='multi_dim',
        wd=0., keep_prob=1., is_training=None,
        scope=None, **kwargs
):
    """Source-to-token self-attention pooling over the second-to-last axis.

    Alignment scores are computed from ``tensor_input`` by one dense layer
    (two when ``deep_act`` is a string naming an activation). They are masked
    with ``tensor_mask`` via ``exp_mask_v3``, softmax-normalised over axis -2,
    and used to take a weighted sum of ``tensor_input`` over that axis.

    Args:
        tensor_input: [..., sl, hn] tensor; must be rank 3 for 'multi_dim_head'.
        tensor_mask: [..., sl] mask of valid positions.
        deep_act: activation name. If it is a str, a second dense layer is
            applied after this activation.
        method: one of
            - 'additive': one score per position;
            - 'multi_dim': one score per position and feature;
            - 'multi_dim_head': per-head multi-dim scores. This requires an
              int ``head_num`` in kwargs with ``hn % head_num == 0``.
        wd, keep_prob, is_training: forwarded to the dense layers.
        scope: variable scope; defaults to ``'s2t_self_attn_<method>'``.
        **kwargs: ``head_num`` (see above). ``attn_keep_prob`` (float), if
            given, applies dropout to the attention probabilities.

    Returns:
        [..., hn] tensor, i.e. ``tensor_input`` with axis -2 reduced.

    Raises:
        AttributeError: if ``method`` is not a supported name.
        AssertionError: if 'multi_dim_head' preconditions on ``head_num`` fail.
    """
    use_deep = isinstance(deep_act, str)  # use Two layers or Single layer for the alignment score
    with tf.variable_scope(scope or 's2t_self_attn_{}'.format(method)):
        tensor_shape = get_shape_list(tensor_input)
        hn = tensor_shape[-1]  # hidden state number

        if method == 'additive':
            align_scores = bn_dense_layer_v2(  # bs,sl,hn/1
                tensor_input, hn if use_deep else 1, True, 0., 'align_score_1', 'linear', False,
                wd, keep_prob, is_training
            )
            if use_deep:
                align_scores = bn_dense_layer_v2(  # bs,sl,1
                    act_name2fn(deep_act)(align_scores), 1, True, 0., 'align_score_2', 'linear', False,
                    wd, keep_prob, is_training
                )
        elif method == 'multi_dim':
            align_scores = bn_dense_layer_v2(  # bs,sl,hn
                tensor_input, hn, False, 0., 'align_score_1', 'linear', False,
                wd, keep_prob, is_training
            )
            if use_deep:
                align_scores = bn_dense_layer_v2(  # bs,sl,hn
                    act_name2fn(deep_act)(align_scores), hn, True, 0., 'align_score_2', 'linear', False,
                    wd, keep_prob, is_training
                )
        elif method == 'multi_dim_head':
            get_shape_list(tensor_input, expected_rank=3)  # the input should be rank-3
            assert 'head_num' in kwargs and isinstance(kwargs['head_num'], int)
            head_num = kwargs['head_num']
            assert hn % head_num == 0
            head_dim = hn // head_num

            tensor_input_heads = split_head(tensor_input, head_num)  # [bs,hd,sl,hd_dim]

            align_scores_heads = bn_dense_layer_multi_head(  # [bs,hd,sl,hd_dim]
                tensor_input_heads, head_dim, True, 0., 'align_scores_heads_1', 'linear', False,
                wd, keep_prob, is_training
            )
            if use_deep:
                align_scores_heads = bn_dense_layer_multi_head(  # [bs,hd,sl,hd_dim]
                    act_name2fn(deep_act)(align_scores_heads), head_dim,
                    True, 0., 'align_scores_heads_2', 'linear', False,
                    wd, keep_prob, is_training
                )
            align_scores = combine_head(align_scores_heads)  # [bs,sl,dim]
        else:
            # Previously a bare AttributeError with no hint of the bad value.
            raise AttributeError(
                "Unknown s2t_self_attn method '{}'; expected 'additive', "
                "'multi_dim' or 'multi_dim_head'".format(method))

        # attention procedure align_scores [bs,sl,1/dim]
        align_scores_masked = exp_mask_v3(align_scores, tensor_mask, multi_head=False, high_dim=True)  # bs,sl,hn
        attn_prob = tf.nn.softmax(align_scores_masked, axis=-2)  # bs,sl,hn

        if isinstance(kwargs.get('attn_keep_prob'), float):
            attn_prob = dropout(attn_prob, kwargs['attn_keep_prob'], is_training)  # bs,sl,hn

        attn_res = tf.reduce_sum(  # [bs,sl,hn] -> [bs,dim]
            mask_v3(attn_prob*tensor_input, tensor_mask, high_dim=True), axis=-2
        )

        return attn_res  # [bs,hn]
Example 12
0
def cond_attn(
        pairwise_scores, featurewise_scores, value_features, from_mask, to_mask,
        attn_keep_prob=1., is_training=None,
        extra_pairwise_mask=None, name=None
):
    """Hybrid attention combining pairwise and feature-wise scores.

    With P = pairwise_scores, F = featurewise_scores and V = value_features,
    every from-position i and feature d gets

        sum_j exp(P[i, j]) * exp(F[j, d]) * V[j, d]
        -------------------------------------------
            sum_j exp(P[i, j]) * exp(F[j, d])

    where j runs over the target positions allowed by the masks. A
    non-positive denominator is replaced by 1. Dropout with keep probability
    ``sqrt(attn_keep_prob)`` is applied to both exp-factors of the numerator.

    :param pairwise_scores: [bs,[head],slf,slt]
    :param featurewise_scores:  [bs,[head],slt,hn]
    :param value_features:  [bs,[head],slt,hn]
    :param from_mask: source-side mask, passed to cross_attn_mask_generation
        (presumably [bs,slf] — TODO confirm)
    :param to_mask: target-side mask (presumably [bs,slt])
    :param attn_keep_prob: keep probability of the attention dropout
    :param is_training: forwarded to dropout
    :param extra_pairwise_mask: optional [bs,slf,slt] or [bs,head,slf,slt] mask
        AND-ed into the pairwise mask
    :param name: name scope (default 'cond_attn')
    :return: [bs,slf,hn], or [bs,slf,hd_num*hd_dim] when a head axis is present
    """
    with tf.name_scope(name or 'cond_attn'):
        # sanity check
        pairwise_shape = get_shape_list(pairwise_scores)
        featurewise_shape = get_shape_list(featurewise_scores)
        value_shape = get_shape_list(value_features)

        pairwise_ndim = len(pairwise_shape)
        featurewise_ndim = len(featurewise_shape)
        value_ndim = len(value_shape)

        assert featurewise_shape[-1] == value_shape[-1]
        assert pairwise_ndim in [3, 4] and pairwise_ndim == featurewise_ndim and featurewise_ndim == value_ndim

        multi_head = pairwise_ndim == 4  # whether a head axis is included

        cross_attn_mask = cross_attn_mask_generation(  # [bs,slf,slt]
            from_mask, to_mask, mutual=True
        )

        if multi_head:  # add the multi-head dim
            cross_attn_mask = tf.expand_dims(cross_attn_mask, 1)  # [bs,[1],slf,slt]

        if extra_pairwise_mask is not None:
            # the extra_pairwise_mask could be include the multi-head
            extra_pairwise_mask_shape = get_shape_list(extra_pairwise_mask)
            assert len(extra_pairwise_mask_shape) in [3, 4]

            assert multi_head or len(extra_pairwise_mask_shape) == 3  # if multi_head=False, shape must be 3-D

            if multi_head and len(extra_pairwise_mask_shape) == 3:
                # Bug fix: expand the *extra* mask. The old code expanded
                # cross_attn_mask (already 4-D) again, which silently discarded
                # extra_pairwise_mask and produced a 5-D mask.
                extra_pairwise_mask = tf.expand_dims(extra_pairwise_mask, 1)  # [bs,[1],slf,slt]

            cross_attn_mask = tf.logical_and(cross_attn_mask, extra_pairwise_mask)  # [bs,[1],slf,slt]

        e_dot_logits = mask_v3(  # bs,head,sl1,sl2
            tf.exp(pairwise_scores), cross_attn_mask, multi_head=False, high_dim=False)  # the multi-head has been add

        e_multi_logits = mask_v3(
            tf.exp(featurewise_scores), to_mask, multi_head=multi_head, high_dim=True
        )

        with tf.name_scope("hybrid_attn"):
            # Z: softmax normalization term in attention probabilities calculation
            accum_z_deno = tf.matmul(e_dot_logits, e_multi_logits)  # num,bs,sl,dim
            accum_z_deno = tf.where(  # in case of NaN and Inf
                tf.greater(accum_z_deno, tf.zeros_like(accum_z_deno)),
                accum_z_deno,
                tf.ones_like(accum_z_deno)
            )
            # attention dropout: sqrt on each factor so the product keeps attn_keep_prob
            e_dot_logits = dropout(e_dot_logits, math.sqrt(attn_keep_prob), is_training)
            e_multi_logits = dropout(e_multi_logits, math.sqrt(attn_keep_prob), is_training)
            # sum of exp(logits) \multiply attention target sequence
            rep_mul_score = value_features * e_multi_logits
            accum_rep_mul_score = tf.matmul(e_dot_logits, rep_mul_score)
            # calculate the final attention results
            attn_res = accum_rep_mul_score / accum_z_deno

        if multi_head:
            attn_res = combine_head(attn_res)  # [bs,slf,hd_num*hd_dim]

    return attn_res  # [bs,slf,hn/hd_num*hd_dim]