def _build_network_all_sketch_logits(
        self, decoder_states, encoder_states_for_decoder, encoder_mask,
        cls_state_predicate, cls_state_type, use_mask=False
):
    """Build the entity / predicate / type / num logits for every sketch step.

    :param decoder_states: decoder output states; the first dim is read as batch size
    :param encoder_states_for_decoder: encoder states handed to `logits_for_sketch_index`
    :param encoder_mask: boolean encoder mask (original comment: [bs,sl])
    :param cls_state_predicate: state handed to the predicate classifier
    :param cls_state_type: state handed to the type classifier
    :param use_mask: when True, `compress_mask`s are derived from the gold sketch
        tensors on `self` (entity/num != -1, predicate/type != 0); otherwise None
    :return: (logits_sketch_entity, logits_sketch_predicate,
        logits_sketch_type, logits_sketch_num)
    """
    batch_size = get_shape_list(decoder_states)[0]
    hidden_size = self.cfg["hn"]
    keep_prob = 1 - self.cfg["clf_dropout"]

    if use_mask:
        entity_mask = tf.not_equal(self.sketch_entity, -1)
        predicate_mask = tf.not_equal(self.sketch_predicate, 0)
        type_mask = tf.not_equal(self.sketch_type, 0)
        num_mask = tf.not_equal(self.sketch_num, -1)
    else:
        entity_mask = predicate_mask = type_mask = num_mask = None

    # Shift the encoder mask left by one (dropping column 0) and append a False
    # column, so the width is unchanged and the last valid token becomes False.
    enc_batch_size = get_shape_list(encoder_mask)[0]
    trailing_false = tf.cast(tf.zeros([enc_batch_size, 1], tf.int32), tf.bool)  # [bs,1]
    encoder_wo_cls = tf.concat([encoder_mask[:, 1:], trailing_false], -1)  # [bs,sl]

    def _index_logits(compress_mask, pre_scope, out_name):
        # Scores over encoder positions (original comment: bs,dsl,esl), then masked.
        raw_logits = logits_for_sketch_index(
            decoder_states, encoder_states_for_decoder, hidden_size, 0., keep_prob,
            self.is_training, compress_mask=compress_mask, scope=pre_scope
        )
        return mask_v3(raw_logits, encoder_wo_cls, multi_head=True, name=out_name)

    def _label_logits(cls_state, num_labels, compress_mask, scope):
        # The classifier is asked for `num_labels - 3` outputs; three columns of
        # VERY_NEGATIVE_NUMBER are prepended so that label ids 0..2 are never chosen.
        raw_logits = logits_for_sketch_prediction(
            decoder_states, cls_state, num_labels - 3, hidden_size,
            self.cfg["clf_act_name"], 0., keep_prob, self.is_training,
            compress_mask=compress_mask, scope=scope
        )
        blocked = tf.ones(
            [batch_size, get_shape_list(raw_logits)[1], 3], tf.float32
        ) * VERY_NEGATIVE_NUMBER
        return tf.concat([blocked, raw_logits], axis=-1)

    logits_sketch_entity = _index_logits(
        entity_mask, "logits_sketch_entity_pre", "logits_sketch_entity")
    logits_sketch_predicate = _label_logits(
        cls_state_predicate, self.num_predicate_labels, predicate_mask,
        "logits_sketch_predicate")
    logits_sketch_type = _label_logits(
        cls_state_type, self.num_type_labels, type_mask, "logits_sketch_type")
    logits_sketch_num = _index_logits(
        num_mask, "logits_sketch_num_pre", "logits_sketch_num")

    return logits_sketch_entity, logits_sketch_predicate, logits_sketch_type, logits_sketch_num
def get_word_level_split(params, input_pos_ids, wordpiece_idx, input_mask, sll, pl):
    """Scatter per-token values into a [bs, sll, pl, ...] word/word-piece grid.

    Each token (b, t) is written to cell (b, input_pos_ids[b, t], wordpiece_idx[b, t]).
    Tokens with a False `input_mask` are redirected to the extra cell
    (bs, sll, pl), which is sliced away before returning.

    :param params: values to scatter, [bs, sl] or [bs, sl, ...]; bool is supported
        by round-tripping through int32
    :param input_pos_ids: int tensor [bs, sl], second coordinate of each token
    :param wordpiece_idx: int tensor [bs, sl], third coordinate of each token
    :param input_mask: bool tensor [bs, sl]
    :param sll: size of the second output dim
    :param pl: size of the third output dim
    :return: tensor of shape [bs, sll, pl] + params.shape[2:]
    """
    bs, sl = get_shape_list(input_pos_ids)
    trailing_dims = get_shape_list(params)[2:]  # empty when params is rank 2

    batch_ids = tf.tile(tf.expand_dims(tf.range(bs), 1), [1, sl])  # [bs, sl]
    coords = tf.stack([batch_ids, input_pos_ids, wordpiece_idx], -1)  # [bs, sl, 3]

    # Masked-out tokens: zero their coordinate via mask_v3 (assumed to zero the
    # masked entries — confirm), then add (bs, sll, pl) to land in the spare cell.
    invalid_int = tf.cast(tf.logical_not(input_mask), tf.int32)
    spare_cell = tf.stack(
        [invalid_int * bs, invalid_int * sll, invalid_int * pl], axis=-1)
    coords = mask_v3(coords, input_mask, high_dim=True) + spare_cell

    # scatter_nd is fed int32 in place of bool; cast back afterwards.
    bool_input = (params.dtype == tf.bool)
    updates = tf.cast(params, tf.int32) if bool_input else params
    scattered = tf.scatter_nd(
        indices=coords,  # [bs, sl, 3]
        updates=updates,
        shape=[bs + 1, sll + 1, pl + 1] + trailing_dims)
    if bool_input:
        scattered = tf.cast(scattered, tf.bool)

    # Drop the spare row/column/slot that collected the masked-out tokens.
    return scattered[:-1, :-1, :-1]
def transform_pos_ids_to_wordpiece_idx(input_pos_ids, input_mask, sll):
    """Turn word ids per token into the token's offset inside its own word.

    Example row of `input_pos_ids` (from the original comment):
    0 0 1 1 1 2 2 2 2 3 3 0 0 0 0 0

    :param input_pos_ids: int32 tensor [bs, sl]
    :param input_mask: bool tensor [bs, sl]
    :param sll: width of the per-word start table that is built internally
    :return: int32 tensor [bs, sl]: token index minus the start index of its
        word, passed through mask_v3 with `input_mask`
    """
    bs, sl = get_shape_list(input_pos_ids)

    # pos_id[t] - pos_id[t-1] (with 0 prepended); non-zero where a new word begins.
    shifted_pos = tf.concat(
        [tf.zeros([bs, 1], dtype=tf.int32), input_pos_ids[:, :-1]], axis=1)
    word_boundary = mask_v3(input_pos_ids - shifted_pos, input_mask)  # bs,sl
    token_ids = tf.tile(tf.expand_dims(tf.range(sl, dtype=tf.int32), 0), [bs, 1])  # bs,sl
    start_positions = word_boundary * token_ids  # bs,sl; token index at boundaries, else 0

    # Pad every row with ones so that each row has the same number (n_max) of
    # non-zero entries, which lets tf.where's output be reshaped to [bs, n_max, 2].
    n_starts = tf.reduce_sum(word_boundary, axis=-1)  # non-zero count per example
    n_max = tf.reduce_max(n_starts)
    n_missing = n_max - n_starts  # how many filler entries each row needs
    n_missing_max = tf.reduce_max(n_missing)
    filler = tf.cast(generate_mask_based_on_lens(n_missing, n_missing_max), tf.int32)
    real_start_mask = generate_mask_based_on_lens(n_starts, n_max)  # bs,n_max
    padded_starts = tf.concat([start_positions, filler], axis=-1)  # bs,sl+n_missing_max

    nonzero_coords = tf.reshape(
        tf.where(tf.cast(padded_starts, tf.bool)),  # bs*n_max,2
        [bs, n_max, 2])
    compact_starts = mask_v3(
        tf.gather_nd(padded_starts, nonzero_coords), real_start_mask)  # bs,n_max

    # Start table indexed by word id: word 0 starts at 0, then the gathered
    # starts, then zero fill up to width sll.
    word_start = tf.concat(
        [
            tf.zeros([bs, 1], dtype=tf.int32),
            compact_starts,
            tf.zeros([bs, sll - n_max - 1], dtype=tf.int32),
        ], axis=1)  # bs,sll

    # Look up each token's word start and subtract it from the token index.
    batch_ids = generate_seq_idxs(bs, sl, transpose=True)  # bs,sl
    lookup_coords = tf.stack([batch_ids, input_pos_ids], axis=-1)  # bs,sl,2
    token_word_start = tf.gather_nd(word_start, lookup_coords)  # bs,sl

    return mask_v3(token_ids - token_word_start, input_mask)  # bs,sl
def generate_label_mask(input_pos_ids, input_mask, sll):
    """Build a [bs, sll] boolean mask that is True for the first `max_id + 1` slots.

    :param input_pos_ids: int tensor [bs, sl]
    :param input_mask: mask over `input_pos_ids`, applied through mask_v3
    :param sll: width of the returned mask
    :return: bool tensor [bs, sll]; entry (b, i) is True iff i is smaller than
        the row-wise maximum of the masked `input_pos_ids + 1`
    """
    # Shift ids by one before masking so a valid id 0 still counts as one slot.
    shifted_ids = mask_v3(input_pos_ids + 1, input_mask)
    bs, _ = get_shape_list(shifted_ids)

    slot_ids = tf.tile(
        tf.expand_dims(tf.range(sll, dtype=tf.int32), 0), [bs, 1])  # [bs,sll]
    slots_per_row = tf.reduce_max(shifted_ids, axis=-1, keepdims=True)  # [bs,1]
    return tf.less(slot_ids, slots_per_row)
def decompress_seq_wrt_mask(tensor_input, reverse_dict):
    """Scatter a compressed sequence back to its original (source) positions.

    :param tensor_input: compressed tensor [bs, tgt_len, hn]
    :param reverse_dict: dict with
        "src_mask": mask of the source sequence, [bs, src_len]
        "coord": int coordinates [bs, tgt_len, 2] as (batch index, source index);
            padded slots carry a negative source index
    :return: tensor [bs, src_len, hn], passed through mask_v3 with "src_mask"
    """
    bs, tgt_len, hn = get_shape_list(tensor_input)
    src_len = get_shape_list(reverse_dict["src_mask"])[1]

    # Padded slots are marked with a negative source index. tf.scatter_nd raises
    # on out-of-bound indices on CPU (and only silently ignores them on GPU), so
    # send those slots to the spare position `src_len`, which is sliced off below.
    coord = reverse_dict["coord"]
    batch_idx = coord[..., 0]
    seq_idx = coord[..., 1]
    spare_idx = tf.ones_like(seq_idx) * tf.cast(src_len, seq_idx.dtype)
    safe_seq_idx = tf.where(tf.less(seq_idx, 0), spare_idx, seq_idx)
    safe_coord = tf.stack([batch_idx, safe_seq_idx], axis=-1)  # bs,tgt_len,2

    padded_tensor = tf.scatter_nd(safe_coord, tensor_input, [bs, src_len + 1, hn])
    out_tensor = padded_tensor[:, :-1]  # bs,src_len,hn
    masked_out_tensor = mask_v3(out_tensor, reverse_dict["src_mask"], high_dim=True)
    return masked_out_tensor
def compress_seq_wrt_mask(tensor_input, tensor_mask):
    """Pack the masked-in positions of each sequence to the front.

    :param tensor_input: tensor [bs, sl, hn]
    :param tensor_mask: bool tensor [bs, sl]; True marks positions to keep
    :return: tuple of
        - compressed tensor [bs, max_len, hn], where max_len is the largest
          number of True entries in any row; passed through mask_v3 with the new mask
        - new mask [bs, max_len]
        - reverse_dict with "src_mask", "tgt_mask" and "coord" ([bs, max_len, 2],
          (batch index, source index), source index -1 on padded slots)
    """
    bs, sl, hn = get_shape_list(tensor_input)

    seq_lens = tf.reduce_sum(tf.cast(tensor_mask, tf.int32), -1)  # bs
    max_len = tf.reduce_max(seq_lens)  # []
    new_mask = generate_mask_based_on_lens(seq_lens, max_len)

    # Append extra True entries so every row has exactly max_len of them; that
    # makes the output of tf.where reshapeable to [bs, max_len, 2].
    pad_lens = max_len - seq_lens
    max_pad_len = tf.reduce_max(pad_lens)
    pad_mask = generate_mask_based_on_lens(pad_lens, max_pad_len)
    padded_tensor_mask = tf.concat([tensor_mask, pad_mask], axis=-1)  # bs,sl+max_pad_len

    # Coordinate lookup table: real positions map to (b, s), padded ones to (b, -1).
    bs_idxs = generate_seq_idxs(bs, sl + max_pad_len, transpose=True)  # bs,sl+max_pad_len
    sl_idxs = tf.concat(  # bs,sl+max_pad_len
        [
            generate_seq_idxs(bs, sl, transpose=False),  # bs,sl
            -tf.ones([bs, max_pad_len], tf.int32),  # bs,max_pad_len
        ], axis=-1)
    data_coord_map = tf.stack([bs_idxs, sl_idxs], axis=-1)  # bs,sl+max_pad_len,2

    padded_coord = tf.where(padded_tensor_mask)  # bs*max_len,2
    mapped_padded_coord_rsp = tf.gather_nd(data_coord_map, padded_coord)  # bs*max_len,2
    mapped_padded_coord = tf.reshape(mapped_padded_coord_rsp, [bs, max_len, 2])  # bs,max_len,2

    # tf.gather_nd raises on out-of-bound indices on CPU (it only yields zeros on
    # GPU), so the -1 markers are clamped to 0 for the gather. Whatever is
    # gathered for those padded slots is masked right below.
    safe_coord = tf.maximum(mapped_padded_coord, 0)
    gathered_data = tf.gather_nd(tensor_input, safe_coord)  # bs,max_len,hn
    masked_gathered_data = mask_v3(gathered_data, new_mask, high_dim=True)

    reverse_dict = {
        "src_mask": tensor_mask,
        "tgt_mask": new_mask,
        "coord": mapped_padded_coord,  # bs,max_len,2 (keeps -1 on padded slots)
    }
    return masked_gathered_data, new_mask, reverse_dict
def pooling_with_mask(rep_tensor, rep_mask, method='max', scope=None):
    """Pool `rep_tensor` over its second-to-last axis while honouring `rep_mask`.

    :param rep_tensor: tensor [..., sl, hn]; has one more rank than `rep_mask`
    :param rep_mask: mask [..., sl]
    :param method: 'max' or 'mean'
    :param scope: optional name scope; defaults to '<method>_pooling'
    :return: tensor [..., hn]
    :raises AttributeError: if `method` is neither 'max' nor 'mean'
    """
    with tf.name_scope(scope or '%s_pooling' % method):
        if method == 'max':
            masked_rep = exp_mask_v3(rep_tensor, rep_mask, high_dim=True)
            return tf.reduce_max(masked_rep, -2)

        if method == 'mean':
            masked_rep = mask_v3(rep_tensor, rep_mask, high_dim=True)  # [...,sl,hn]
            summed = tf.reduce_sum(masked_rep, -2)  # [...,hn]
            valid_count = tf.reduce_sum(tf.cast(rep_mask, tf.int32), -1, True)  # [...,1]
            # Replace a zero count by one so fully masked rows do not divide by zero.
            count_is_zero = tf.equal(valid_count, tf.zeros_like(valid_count, tf.int32))
            valid_count = tf.where(
                count_is_zero, tf.ones_like(valid_count, tf.int32), valid_count)
            return summed / tf.cast(valid_count, tf.float32)

        raise AttributeError('No Pooling method name as %s' % method)
def _build_network_seq_label_logits(self, encoder_states):
    """Build the sequence-labeling logits from word-piece level encoder states.

    :param encoder_states: encoder output (original comment: bs,sl,hn)
    :return: logits with `1 + (num_EO_labels - 2) * (num_type_labels - 2)` classes
        per labeled token
    """
    keep_prob = 1. - self.cfg['clf_dropout']
    num_classes = 1 + (self.num_EO_labels - 2) * (self.num_type_labels - 2)

    # Regroup word pieces by the word they belong to: bs,sl,hn -> bs,asl,pl,hn
    word_piece_grid = get_word_level_split(
        encoder_states, self.input_pos_ids, self.wordpiece_idx,
        self.input_mask, self.asl, self.pl
    )
    # Attend over the pieces of each word to get one vector per word: bs,asl,hn
    word_features = s2t_self_attn(
        word_piece_grid, self.wordpiece_mask, self.cfg['clf_act_name'], 'multi_dim',
        0., keep_prob, self.is_training, 'all_token_features',
    )
    # Drop the last word slot (asl -> asl-1) and apply the label mask.
    labeled_features = mask_v3(
        word_features[:, :-1], self.seq_label_mask, high_dim=True
    )

    with tf.variable_scope("output"), tf.variable_scope("seq_labeling"):
        seq_label_logits = bn_dense_layer_v2(
            labeled_features, num_classes, True, 0.,
            "seq_labeling_logits", "linear", False,
            0., keep_prob, self.is_training
        )
    return seq_label_logits
def mask_matrix_to_coordinate(mask_mat, name=None):
    """List the coordinates of the True entries of a rank-2 mask, row by row.

    :param mask_mat: bool tensor [bs, sll]
    :param name: optional name scope
    :return: tuple of
        - coords [bs, max_true, 2] as produced by tf.where, passed through
          mask_v3 with the coordinate mask
        - coord_mask [bs, max_true], marking which coordinate slots are real
    """
    with tf.name_scope(name or "mask_matrix_to_coordinate"):
        bs, sll = get_shape_list(mask_mat, expected_rank=2)

        # Number of True entries per row and how many each row lacks versus the max.
        true_counts = tf.reduce_sum(tf.cast(mask_mat, tf.int32), axis=-1)  # bs
        max_true = tf.reduce_max(true_counts, axis=0)  # []
        missing_counts = max_true - true_counts
        max_missing = tf.reduce_max(missing_counts, axis=0)

        filler_mask = generate_mask_based_on_lens(missing_counts, max_missing)
        coord_mask = generate_mask_based_on_lens(true_counts, max_true)

        # Top up every row to exactly max_true True entries so that the flat
        # tf.where output can be reshaped per example.
        topped_up_mask = tf.concat([mask_mat, filler_mask], axis=-1)
        per_row_coords = tf.reshape(
            tf.where(topped_up_mask),  # [bs*max_true,2]
            [bs, max_true, 2])
        per_row_coords = mask_v3(per_row_coords, coord_mask, high_dim=True)
        return per_row_coords, coord_mask
def _build_prediction(self):
    """Decode every logits tensor in `self.logits_dict` into int32 predictions.

    The joint sequence label k is split into two ids: for k >= 1 the EO id is
    (k - 1) mod 4 + 2 and the entity-type id is (k - 1) / 4 + 2; for k == 0
    both are 1. Sketch outputs are the argmax, replaced outside their mask by
    0 (sketch, predicate, type) or -1 (entity, num).

    :return: dict of prediction tensors plus "seq_label_mask" and "sep_indices"
    """
    def _argmax_int32(logits_key):
        return tf.cast(tf.argmax(self.logits_dict[logits_key], axis=-1), tf.int32)

    # --- NER sequence labeling ---
    seq_label_ids = _argmax_int32("seq_label")
    is_labeled = tf.greater_equal(seq_label_ids, 1)
    fallback_ids = tf.ones_like(seq_label_ids)

    predictions_ner = mask_v3(
        tf.where(is_labeled, tf.mod(seq_label_ids - 1, 4) + 2, fallback_ids),
        self.seq_label_mask)
    predictions_entity_type = mask_v3(
        tf.where(
            is_labeled,
            tf.cast((seq_label_ids - 1) / 4, tf.int32) + 2,
            fallback_ids),
        self.seq_label_mask)

    # --- semantic parsing (sketch) ---
    sketch_ids = _argmax_int32("seq2seq")  # bs,sl
    predicted_seq2seq = tf.where(
        self.sketch_mask, sketch_ids + 1, tf.zeros_like(sketch_ids))

    entity_ids = _argmax_int32("sketch_entity")  # bs,sl
    predicted_sketch_entity = tf.where(
        self.sketch_entity_mask, entity_ids, -tf.ones_like(entity_ids))

    predicate_ids = _argmax_int32("sketch_predicate")  # bs,sl
    predicted_sketch_predicate = tf.where(
        self.sketch_predicate_mask, predicate_ids, tf.zeros_like(predicate_ids))

    type_ids = _argmax_int32("sketch_type")  # bs,sl
    predicted_sketch_type = tf.where(
        self.sketch_type_mask, type_ids, tf.zeros_like(type_ids))

    num_ids = _argmax_int32("sketch_num")  # bs,sl
    predicted_sketch_num = tf.where(
        self.sketch_num_mask, num_ids, -tf.ones_like(num_ids))

    return {
        "EOs": predictions_ner,
        "entity_types": predictions_entity_type,
        "seq_label_mask": self.seq_label_mask,
        "sketch": predicted_seq2seq,
        "sketch_entity": predicted_sketch_entity,
        "sketch_predicate": predicted_sketch_predicate,
        "sketch_type": predicted_sketch_type,
        "sketch_num": predicted_sketch_num,
        # aux
        "sep_indices": self.sep_indices,
    }
def s2t_self_attn(
        tensor_input, tensor_mask, deep_act=None, method='multi_dim',
        wd=0., keep_prob=1., is_training=None, scope=None, **kwargs
):
    """Source2token self-attention: collapse the sequence axis into one vector.

    :param tensor_input: tensor [..., sl, hn] (rank 3 is required for 'multi_dim_head')
    :param tensor_mask: mask over the sequence axis
    :param deep_act: activation name; when it is a str, a second scoring layer
        is stacked on top of the activated first one
    :param method: 'additive' (one score per position), 'multi_dim' (one score
        per feature) or 'multi_dim_head' (per-feature scores computed per head;
        needs an int kwarg `head_num` that divides hn)
    :param wd: forwarded to the dense layers
    :param keep_prob: forwarded to the dense layers
    :param is_training: forwarded to the dense layers and dropout
    :param scope: variable scope name; defaults to 's2t_self_attn_<method>'
    :param kwargs: `head_num` (see above); `attn_keep_prob` (float) enables
        dropout on the attention probabilities
    :return: tensor [..., hn]
    :raises AttributeError: on an unknown `method`
    """
    two_layers = isinstance(deep_act, str)

    with tf.variable_scope(scope or 's2t_self_attn_{}'.format(method)):
        hidden_size = get_shape_list(tensor_input)[-1]

        if method == 'additive':
            scores = bn_dense_layer_v2(  # bs,sl,hn (two layers) or bs,sl,1
                tensor_input, hidden_size if two_layers else 1, True, 0.,
                'align_score_1', 'linear', False, wd, keep_prob, is_training
            )
            if two_layers:
                scores = bn_dense_layer_v2(  # bs,sl,1
                    act_name2fn(deep_act)(scores), 1, True, 0.,
                    'align_score_2', 'linear', False, wd, keep_prob, is_training
                )
        elif method == 'multi_dim':
            scores = bn_dense_layer_v2(  # bs,sl,hn
                tensor_input, hidden_size, False, 0.,
                'align_score_1', 'linear', False, wd, keep_prob, is_training
            )
            if two_layers:
                scores = bn_dense_layer_v2(  # bs,sl,hn
                    act_name2fn(deep_act)(scores), hidden_size, True, 0.,
                    'align_score_2', 'linear', False, wd, keep_prob, is_training
                )
        elif method == 'multi_dim_head':
            get_shape_list(tensor_input, expected_rank=3)  # input must be rank 3
            assert isinstance(kwargs.get('head_num'), int)
            num_heads = kwargs['head_num']
            assert hidden_size % num_heads == 0
            per_head_size = hidden_size // num_heads

            input_heads = split_head(tensor_input, num_heads)  # [bs,hd,sl,hd_dim]
            head_scores = bn_dense_layer_multi_head(  # [bs,hd,sl,hd_dim]
                input_heads, per_head_size, True, 0.,
                'align_scores_heads_1', 'linear', False, wd, keep_prob, is_training
            )
            if two_layers:
                head_scores = bn_dense_layer_multi_head(  # [bs,hd,sl,hd_dim]
                    act_name2fn(deep_act)(head_scores), per_head_size, True, 0.,
                    'align_scores_heads_2', 'linear', False, wd, keep_prob, is_training
                )
            scores = combine_head(head_scores)  # [bs,sl,dim]
        else:
            raise AttributeError

        # Softmax over the sequence axis (-2), so each feature gets its own
        # distribution over positions.
        masked_scores = exp_mask_v3(
            scores, tensor_mask, multi_head=False, high_dim=True)  # bs,sl,hn
        attn_prob = tf.nn.softmax(masked_scores, axis=-2)  # bs,sl,hn

        if isinstance(kwargs.get('attn_keep_prob'), float):
            attn_prob = dropout(attn_prob, kwargs['attn_keep_prob'], is_training)  # bs,sl,hn

        weighted_input = mask_v3(attn_prob * tensor_input, tensor_mask, high_dim=True)
        attn_res = tf.reduce_sum(weighted_input, axis=-2)  # [bs,sl,hn] -> [bs,hn]
        return attn_res
def cond_attn(
        pairwise_scores, featurewise_scores, value_features, from_mask, to_mask,
        attn_keep_prob=1., is_training=None,
        extra_pairwise_mask=None, name=None
):
    """Attention that combines pairwise (token-to-token) and feature-wise scores.

    The result is sum_j exp(pair[i, j]) * exp(feat[j]) * value[j], divided by
    sum_j exp(pair[i, j]) * exp(feat[j]), computed with two matmuls.

    :param pairwise_scores: [bs,[head],slf,slt]
    :param featurewise_scores: [bs,[head],slt,hn]
    :param value_features: [bs,[head],slt,hn]
    :param from_mask: mask over the "from" sequence (slf)
    :param to_mask: mask over the "to" sequence (slt)
    :param attn_keep_prob: overall keep probability; each of the two exp-score
        tensors is dropped out with sqrt(attn_keep_prob)
    :param is_training: forwarded to `dropout`
    :param extra_pairwise_mask: optional extra mask ANDed with the from/to cross
        mask; [bs,slf,slt], or [bs,head,slf,slt] when the inputs are multi-head
    :param name: optional name scope
    :return: [bs,slf,hn], or [bs,slf,hd_num*hd_dim] for multi-head inputs
    """
    with tf.name_scope(name or 'cond_attn'):
        # sanity check
        pairwise_shape = get_shape_list(pairwise_scores)
        featurewise_shape = get_shape_list(featurewise_scores)
        value_shape = get_shape_list(value_features)

        pairwise_ndim = len(pairwise_shape)
        featurewise_ndim = len(featurewise_shape)
        value_ndim = len(value_shape)

        assert featurewise_shape[-1] == value_shape[-1]
        assert pairwise_ndim in [3, 4] and pairwise_ndim == featurewise_ndim and featurewise_ndim == value_ndim

        multi_head = pairwise_ndim == 4  # rank 4 means a head dim is present

        cross_attn_mask = cross_attn_mask_generation(  # [bs,slf,slt]
            from_mask, to_mask, mutual=True
        )
        if multi_head:  # add the multi-head dim
            cross_attn_mask = tf.expand_dims(cross_attn_mask, 1)  # [bs,[1],slf,slt]

        if extra_pairwise_mask is not None:
            # The extra mask may or may not already carry the head dim.
            extra_pairwise_mask_shape = get_shape_list(extra_pairwise_mask)
            assert len(extra_pairwise_mask_shape) in [3, 4]
            assert multi_head or len(extra_pairwise_mask_shape) == 3  # without heads it must be 3-D
            if multi_head and len(extra_pairwise_mask_shape) == 3:
                # Give the caller's mask a head dim so it broadcasts against the
                # 4-D cross mask. (Previously `cross_attn_mask` was expanded here,
                # which discarded the extra mask altogether.)
                extra_pairwise_mask = tf.expand_dims(extra_pairwise_mask, 1)  # [bs,[1],slf,slt]
            cross_attn_mask = tf.logical_and(cross_attn_mask, extra_pairwise_mask)  # [bs,[1],slf,slt]

        e_dot_logits = mask_v3(  # bs,[head],slf,slt
            tf.exp(pairwise_scores), cross_attn_mask,
            multi_head=False, high_dim=False)  # the head dim was already added to the mask
        e_multi_logits = mask_v3(
            tf.exp(featurewise_scores), to_mask, multi_head=multi_head, high_dim=True
        )

        with tf.name_scope("hybrid_attn"):
            # Z: softmax normalization term of the attention probabilities
            accum_z_deno = tf.matmul(e_dot_logits, e_multi_logits)  # bs,[head],slf,hn
            accum_z_deno = tf.where(  # replace non-positive Z by 1 to avoid NaN and Inf
                tf.greater(accum_z_deno, tf.zeros_like(accum_z_deno)),
                accum_z_deno,
                tf.ones_like(accum_z_deno)
            )
            # Attention dropout, applied after Z is computed and split evenly
            # over the two factors.
            e_dot_logits = dropout(e_dot_logits, math.sqrt(attn_keep_prob), is_training)
            e_multi_logits = dropout(e_multi_logits, math.sqrt(attn_keep_prob), is_training)
            # sum of exp(logits) multiplied by the attention target sequence
            rep_mul_score = value_features * e_multi_logits
            accum_rep_mul_score = tf.matmul(e_dot_logits, rep_mul_score)
            # calculate the final attention results
            attn_res = accum_rep_mul_score / accum_z_deno

        if multi_head:
            attn_res = combine_head(attn_res)  # [bs,slf,hd_num*hd_dim]
        return attn_res  # [bs,slf,hn/hd_num*hd_dim]