import logging
import math

import tensorflow as tf

# NOTE: helper functions such as get_shape_list, bn_dense_layer_v2, act_name2fn,
# compress_seq_wrt_mask, split_head, etc. are defined elsewhere in this repository.


def logits_for_sketch_prediction(
        decoder_states, cls_state, num_channel, hn=None, act_name="relu",
        wd=0., keep_prob=1.0, is_train=None, compress_mask=None, scope=None):
    compressing = compress_mask is not None
    hn = hn or get_shape_list(decoder_states)[-1]
    # NOTE: the default scope name is shared with logits_for_sketch_index in the original code
    with tf.variable_scope(scope or "logits_for_sketch_index"):
        if compressing:
            new_decoder_states, _, rev_d = compress_seq_wrt_mask(decoder_states, compress_mask)
        else:
            new_decoder_states = decoder_states
            rev_d = None

        map_part1 = bn_dense_layer_v2(
            new_decoder_states, hn, True, 0., "map_part1", "linear",
            False, wd, keep_prob, is_train)
        map_part2_pre = bn_dense_layer_v2(
            cls_state, hn, False, 0., "map_part2_pre", "linear",
            False, wd, keep_prob, is_train)
        map_part2 = tf.tile(
            tf.expand_dims(map_part2_pre, axis=1),
            [1, get_shape_list(map_part1)[1], 1])
        map_res = act_name2fn(act_name)(map_part1 + map_part2)

        logits = bn_dense_layer_v2(
            map_res, num_channel, True, 0., "logits", "linear",
            False, wd, keep_prob, is_train)

        if compressing:
            logits = decompress_seq_wrt_mask(logits, rev_d)
        return logits
def logits_for_sketch_index(
        decoder_states, encoder_states, hn=None, wd=0., keep_prob=1.0,
        is_train=None, compress_mask=None, scope=None):
    compressing = compress_mask is not None
    hn = hn or get_shape_list(decoder_states)[-1]
    with tf.variable_scope(scope or "logits_for_sketch_index"):
        if compressing:
            new_decoder_states, _, rev_d = compress_seq_wrt_mask(decoder_states, compress_mask)
        else:
            new_decoder_states = decoder_states
            rev_d = None

        with tf.variable_scope("projection"):
            encoder_states_map = bn_dense_layer_v2(
                encoder_states, hn, True, 0., "encoder_states_map", "linear",
                False, wd, keep_prob, is_train)
            decoder_states_map = bn_dense_layer_v2(
                new_decoder_states, hn, True, 0., "decoder_states_map", "linear",
                False, wd, keep_prob, is_train)

        with tf.variable_scope("bi_linear"):
            bilinear_pre = bn_dense_layer_v2(
                decoder_states_map, hn, False, 0., "bilinear_map", "linear",
                False, wd, keep_prob, is_train)
            logits = tf.matmul(bilinear_pre, encoder_states_map, transpose_b=True)  # bs,dsl,esl

        if compressing:
            logits = decompress_seq_wrt_mask(logits, rev_d)
        return logits
def transformer_seq_decoder(
        dec_input_emb_mat, decoder_ids, encoder_states, decoder_mask, encoder_mask,
        n_out_channel, num_layers, decoder_history_inputs=None,
        hn=768, head_num=12, act_name="gelu", wd=0., is_training=None,
        keep_prob_dense=1., keep_prob_attn=1., keep_prob_res=1., scope=None):
    with tf.variable_scope(scope or "transformer_seq_decoder"):
        with tf.variable_scope("decoder_emb"):
            decoder_inputs = tf.nn.embedding_lookup(dec_input_emb_mat, decoder_ids)  # bs,sl,hn

        with tf.variable_scope("decoder_recurrence"):
            dec_outputs, new_decoder_history_inputs = transformer_decoder(  # bs,sl,hn
                decoder_inputs, encoder_states, decoder_mask, encoder_mask,
                num_layers, decoder_history_inputs, hn, head_num, act_name, wd,
                is_training, keep_prob_dense, keep_prob_attn, keep_prob_res,
                scope="transformer_decoder")

            # prediction logits: only the single linear layer below is active;
            # the two-layer variant is kept commented out.
            # pre_logits_seq2seq = bn_dense_layer_v2(
            #     dec_outputs, hn, True, 0., "pre_logits_seq2seq", act_name,
            #     False, 0., keep_prob_dense, is_training)
            logits_seq2seq = bn_dense_layer_v2(  # bs,sl,n_out_channel
                dec_outputs, n_out_channel, True, 0., "logits_seq2seq", "linear",
                False, 0., keep_prob_dense, is_training)
            return dec_outputs, logits_seq2seq, new_decoder_history_inputs
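# Illustrative sketch (not part of the original module): a single teacher-forced
# call to transformer_seq_decoder. The placeholder names, vocabulary size, and the
# use of a boolean is_training placeholder are assumptions for demonstration only;
# the surrounding training code in this repository may wire these differently.
def _example_transformer_seq_decoder_usage():
    emb_mat = tf.get_variable("dec_emb_mat", [30000, 768], tf.float32)  # hypothetical vocab of 30k
    decoder_ids = tf.placeholder(tf.int32, [None, None])           # bs,sl
    decoder_mask = tf.placeholder(tf.bool, [None, None])           # bs,sl
    encoder_states = tf.placeholder(tf.float32, [None, None, 768])  # bs,esl,hn
    encoder_mask = tf.placeholder(tf.bool, [None, None])           # bs,esl
    is_training = tf.placeholder(tf.bool, [], name="is_training")  # assumed to be accepted by the dropout helpers

    dec_outputs, logits, _ = transformer_seq_decoder(
        emb_mat, decoder_ids, encoder_states, decoder_mask, encoder_mask,
        n_out_channel=30000, num_layers=6, hn=768, head_num=12,
        act_name="gelu", is_training=is_training)
    return dec_outputs, logits  # bs,sl,hn and bs,sl,n_out_channel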
def _build_network_seq_label_logits(self, encoder_states):
    wp_features = get_word_level_split(  # bs,sl,hn -> bs,asl,pl,hn
        encoder_states, self.input_pos_ids, self.wordpiece_idx,
        self.input_mask, self.asl, self.pl)
    all_token_features = s2t_self_attn(  # bs,asl,hn
        wp_features, self.wordpiece_mask, self.cfg['clf_act_name'], 'multi_dim',
        0., 1. - self.cfg['clf_dropout'], self.is_training, 'all_token_features')

    # get seq_label_token_features: asl -> sll (i.e., asl-1)
    seq_label_token_features = mask_v3(  # remove the last token feature
        all_token_features[:, :-1], self.seq_label_mask, high_dim=True)

    with tf.variable_scope("output"):
        with tf.variable_scope("seq_labeling"):
            seq_label_logits = bn_dense_layer_v2(
                # one "O" label plus all (EO label, type label) combinations,
                # excluding the empty and padding labels of each label set
                seq_label_token_features,
                1 + (self.num_EO_labels - 2) * (self.num_type_labels - 2),
                True, 0., "seq_labeling_logits", "linear", False, 0.,
                1. - self.cfg['clf_dropout'], self.is_training)
    return seq_label_logits
def attn_post_proc(attn_res, inter_hn=None, wd=0., keep_prob=1., residual_keep_prob=1.,
                   is_train=None, activation='relu', sparse_opt=False, scope=None, **kwargs):
    with tf.variable_scope(scope or "attn_res"):
        assert "mask" in kwargs
        if sparse_opt:
            x1, reverse_spec = masked_dense2sparse(attn_res, kwargs.get("mask"))
        else:
            x1 = attn_res

        y = bn_dense_layer_v2(
            x1, get_shape_list(attn_res)[-1], True, 0., "dense_layer", "linear",
            False, wd, keep_prob, is_train)
        x2 = residual_connection(x1, y, is_train, residual_keep_prob, "res_con")
        res = residual_connection_with_dense(
            x2, inter_hn or 4 * get_shape_list(attn_res)[-1], True, 0.,
            "residual_connection_with_dense", activation, False, wd, keep_prob,
            is_train, residual_keep_prob)

        if sparse_opt:
            res = masked_sparse2dense(res, reverse_spec)
        return res
def compatibility_fn_lacacy(  # (sic: "legacy") does not support arbitrary-rank tensors
        tensor_from, tensor_to, method='dot_product', scope=None, **kwargs):
    def _get_val_from_kwargs(key, default_val):
        if key in kwargs:
            return kwargs[key]
        else:
            return default_val

    with tf.variable_scope(scope or 'compatibility_fn.{}'.format(method)):
        shape_from = get_shape_list(tensor_from)
        ndim_from = len(shape_from)
        shape_to = get_shape_list(tensor_to)
        ndim_to = len(shape_to)

        assert (ndim_from == 2 or ndim_from == 3) and ndim_to == 3
        if ndim_from == 2:
            tensor_from = tf.expand_dims(tensor_from, 1)
            shape_from = get_shape_list(tensor_from)
        slf, slt = shape_from[1], shape_to[1]

        # hparams parsing
        hn = _get_val_from_kwargs('hn', shape_to[-1])
        wd = _get_val_from_kwargs('wd', 0.)
        keep_prob = _get_val_from_kwargs('keep_prob', 1.)
        is_training = _get_val_from_kwargs('is_training', None)
        activation = _get_val_from_kwargs('activation', 'relu')
        head_num = _get_val_from_kwargs('head_num', 12)

        seq_dim_to_remove = 1
        if method == 'dot_product':
            align_scores = tf.matmul(tensor_from, tensor_to, transpose_b=True)  # [bs,slf,hn]*[bs,slt,hn] => bs,slf,slt
            align_scores = tf.expand_dims(align_scores, -1)  # [bs,slf,slt,1]
        elif method == 'additive':
            tensor_from_branch = bn_dense_layer_v2(
                tensor_from, hn, False, 0., 'tensor_from_branch', 'linear',
                False, wd, keep_prob, is_training)
            tensor_to_branch = bn_dense_layer_v2(
                tensor_to, hn, True, 0., 'tensor_to_branch', 'linear',
                False, wd, keep_prob, is_training)
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,slf,slt,hn]
                tf.expand_dims(tensor_from_branch, 2),  # [bs,slf,1,hn]
                tf.expand_dims(tensor_to_branch, 1)  # [bs,1,slt,hn]
            ))
            align_scores = bn_dense_layer_v2(  # [bs,slf,slt,1]
                align_scores_pre, 1, True, 0., 'align_scores', 'linear',
                False, wd, keep_prob, is_training)
        elif method == 'multi_dim':
            logging.warning("No simplified multi-dim technique used in this function!")
            tensor_from_branch = bn_dense_layer_v2(
                tensor_from, hn, False, 0., 'tensor_from_branch', 'linear',
                False, wd, keep_prob, is_training)
            tensor_to_branch = bn_dense_layer_v2(
                tensor_to, hn, True, 0., 'tensor_to_branch', 'linear',
                False, wd, keep_prob, is_training)
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,slf,slt,hn]
                tf.expand_dims(tensor_from_branch, 2),  # [bs,slf,1,hn]
                tf.expand_dims(tensor_to_branch, 1)  # [bs,1,slt,hn]
            ))
            align_scores = bn_dense_layer_v2(
                align_scores_pre, hn, True, 0., 'align_score', 'linear',
                False, wd, keep_prob, is_training)
        elif method == 'multi_head':
            seq_dim_to_remove = 2  # !!! because the multi-head dim is inserted as the 2nd dim
            assert hn % head_num == 0
            head_dim = hn // head_num

            q_heads = bn_dense_layer_v2(
                tensor_from, head_dim, True, 0., 'q_heads', 'linear',
                False, wd, keep_prob, is_training, dup_num=head_num)
            k_heads = bn_dense_layer_v2(
                tensor_to, head_dim, True, 0., 'k_heads', 'linear',
                False, wd, keep_prob, is_training, dup_num=head_num)
            q_heads = split_head(q_heads, head_num)  # bs,hd_num,slf,hd_dim
            k_heads = split_head(k_heads, head_num)  # bs,hd_num,slt,hd_dim

            # alignment score
            align_scores = tf.matmul(q_heads, k_heads, transpose_b=True)  # [bs,hd_num,slf,slt]
            align_scores = align_scores / math.sqrt(1. * head_dim)  # [bs,hd_num,slf,slt]
        elif method == 'multi_dim_head':
            seq_dim_to_remove = 2  # !!! because the multi-head dim is inserted as the 2nd dim
            assert hn % head_num == 0
            head_dim = hn // head_num

            q_heads = bn_dense_layer_v2(
                tensor_from, head_dim, True, 0., 'q_heads', 'linear',
                False, wd, keep_prob, is_training, dup_num=head_num)
            k_heads = bn_dense_layer_v2(
                tensor_to, head_dim, True, 0., 'k_heads', 'linear',
                False, wd, keep_prob, is_training, dup_num=head_num)
            q_heads = split_head(q_heads, head_num)  # bs,hd_num,slf,hd_dim
            k_heads = split_head(k_heads, head_num)  # bs,hd_num,slt,hd_dim

            # MLP
            q_heads_branch = bn_dense_layer_multi_head(
                q_heads, head_dim, False, 0., 'q_heads_branch', 'linear',
                False, wd, keep_prob, is_training)
            k_heads_branch = bn_dense_layer_multi_head(
                k_heads, head_dim, True, 0., 'k_heads_branch', 'linear',
                False, wd, keep_prob, is_training)
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,head,slf,slt,dim]
                tf.expand_dims(q_heads_branch, 3),  # [bs,head,slf,1,dim]
                tf.expand_dims(k_heads_branch, 2)  # [bs,head,1,slt,dim]
            ))
            align_scores_heads = bn_dense_layer_multi_head(  # [bs,hd_num,slf,slt,hd_dim]
                align_scores_pre, head_dim, True, 0., 'align_scores_heads', 'linear',
                False, wd, keep_prob, is_training)
            align_scores = align_scores_heads  # [bs,hd_num,slf,slt,hd_dim]
            # align_scores = combine_head(align_scores_heads)
        elif method == 'bilinear':
            raise NotImplementedError
        else:
            raise AttributeError

        if ndim_from == 2:
            align_scores = tf.squeeze(align_scores, [seq_dim_to_remove])
        return align_scores
def s2t_self_attn(
        tensor_input, tensor_mask, deep_act=None, method='multi_dim',
        wd=0., keep_prob=1., is_training=None, scope=None, **kwargs):
    use_deep = isinstance(deep_act, str)  # use two layers (otherwise a single layer) for the alignment score
    with tf.variable_scope(scope or 's2t_self_attn_{}'.format(method)):
        tensor_shape = get_shape_list(tensor_input)
        hn = tensor_shape[-1]  # hidden state number

        if method == 'additive':
            align_scores = bn_dense_layer_v2(  # bs,sl,hn/1
                tensor_input, hn if use_deep else 1, True, 0., 'align_score_1', 'linear',
                False, wd, keep_prob, is_training)
            if use_deep:
                align_scores = bn_dense_layer_v2(  # bs,sl,1
                    act_name2fn(deep_act)(align_scores), 1, True, 0., 'align_score_2', 'linear',
                    False, wd, keep_prob, is_training)
        elif method == 'multi_dim':
            align_scores = bn_dense_layer_v2(  # bs,sl,hn
                tensor_input, hn, False, 0., 'align_score_1', 'linear',
                False, wd, keep_prob, is_training)
            if use_deep:
                align_scores = bn_dense_layer_v2(  # bs,sl,hn
                    act_name2fn(deep_act)(align_scores), hn, True, 0., 'align_score_2', 'linear',
                    False, wd, keep_prob, is_training)
        elif method == 'multi_dim_head':
            get_shape_list(tensor_input, expected_rank=3)  # the input should be rank-3
            assert 'head_num' in kwargs and isinstance(kwargs['head_num'], int)
            head_num = kwargs['head_num']
            assert hn % head_num == 0
            head_dim = hn // head_num

            tensor_input_heads = split_head(tensor_input, head_num)  # [bs,hd,sl,hd_dim]
            align_scores_heads = bn_dense_layer_multi_head(  # [bs,hd,sl,hd_dim]
                tensor_input_heads, head_dim, True, 0., 'align_scores_heads_1', 'linear',
                False, wd, keep_prob, is_training)
            if use_deep:
                align_scores_heads = bn_dense_layer_multi_head(  # [bs,hd,sl,hd_dim]
                    act_name2fn(deep_act)(align_scores_heads), head_dim, True, 0.,
                    'align_scores_heads_2', 'linear', False, wd, keep_prob, is_training)
            align_scores = combine_head(align_scores_heads)  # [bs,sl,dim]
        else:
            raise AttributeError

        # attention procedure; align_scores is [bs,sl,1/dim]
        align_scores_masked = exp_mask_v3(align_scores, tensor_mask, multi_head=False, high_dim=True)  # bs,sl,hn
        attn_prob = tf.nn.softmax(align_scores_masked, axis=-2)  # bs,sl,hn

        if 'attn_keep_prob' in kwargs and isinstance(kwargs['attn_keep_prob'], float):
            attn_prob = dropout(attn_prob, kwargs['attn_keep_prob'], is_training)  # bs,sl,hn

        attn_res = tf.reduce_sum(  # [bs,sl,hn] -> [bs,dim]
            mask_v3(attn_prob * tensor_input, tensor_mask, high_dim=True), axis=-2)
        return attn_res  # [bs,hn]
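# Illustrative sketch (not part of the original module): pooling a [bs, sl, hn]
# token sequence into a [bs, hn] vector with s2t_self_attn, the same way the
# word-piece features are pooled in _build_network_seq_label_logits above.
# The placeholder names and the keep_prob value are hypothetical.
def _example_s2t_self_attn_usage():
    token_states = tf.placeholder(tf.float32, [None, None, 768])  # bs,sl,hn
    token_mask = tf.placeholder(tf.bool, [None, None])            # bs,sl
    is_training = tf.placeholder(tf.bool, [], name="is_training")

    pooled = s2t_self_attn(  # bs,hn
        token_states, token_mask, deep_act='relu', method='multi_dim',
        wd=0., keep_prob=0.9, is_training=is_training, scope='example_pooling')
    return pooled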
def compatibility_fn(tensor_from, tensor_to, method='dot_product', scope=None, **kwargs):
    def _get_val_from_kwargs(key, default_val):
        if key in kwargs:
            return kwargs[key]
        else:
            return default_val

    with tf.variable_scope(scope or 'compatibility_fn.{}'.format(method)):
        shape_from = get_shape_list(tensor_from)
        ndim_from = len(shape_from)
        shape_to = get_shape_list(tensor_to)
        ndim_to = len(shape_to)

        assert ndim_from == ndim_to or ndim_from + 1 == ndim_to
        need_extra_dim = ndim_from + 1 == ndim_to
        if need_extra_dim:
            tensor_from = tf.expand_dims(tensor_from, -2)
            shape_from = get_shape_list(tensor_from)
        slf, slt = shape_from[-2], shape_to[-2]

        # hparams parsing
        hn = _get_val_from_kwargs('hn', shape_to[-1])
        wd = _get_val_from_kwargs('wd', 0.)
        keep_prob = _get_val_from_kwargs('keep_prob', 1.)
        is_training = _get_val_from_kwargs('is_training', None)
        activation = _get_val_from_kwargs('activation', 'relu')
        head_num = _get_val_from_kwargs('head_num', 12)

        seq_dim_to_remove = -3
        if method == 'dot_product':
            align_scores = tf.matmul(tensor_from, tensor_to, transpose_b=True)  # [bs,slf,hn]*[bs,slt,hn] => bs,slf,slt
            align_scores = tf.expand_dims(align_scores, -1)  # [bs,slf,slt,1]
        elif method == 'additive':
            tensor_from_branch = bn_dense_layer_v2(
                tensor_from, hn, False, 0., 'tensor_from_branch', 'linear',
                False, wd, keep_prob, is_training)
            tensor_to_branch = bn_dense_layer_v2(
                tensor_to, hn, True, 0., 'tensor_to_branch', 'linear',
                False, wd, keep_prob, is_training)
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,slf,slt,hn]
                tf.expand_dims(tensor_from_branch, -2),  # [bs,slf,1,hn]
                tf.expand_dims(tensor_to_branch, -3)  # [bs,1,slt,hn]
            ))
            align_scores = bn_dense_layer_v2(  # [bs,slf,slt,1]
                align_scores_pre, 1, True, 0., 'align_scores', 'linear',
                False, wd, keep_prob, is_training)
        elif method == 'multi_dim':
            logging.warning("No simplified multi-dim technique used in this function!")
            tensor_from_branch = bn_dense_layer_v2(
                tensor_from, hn, False, 0., 'tensor_from_branch', 'linear',
                False, wd, keep_prob, is_training)
            tensor_to_branch = bn_dense_layer_v2(
                tensor_to, hn, True, 0., 'tensor_to_branch', 'linear',
                False, wd, keep_prob, is_training)
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,slf,slt,hn]
                tf.expand_dims(tensor_from_branch, -2),  # [bs,slf,1,hn]
                tf.expand_dims(tensor_to_branch, -3)  # [bs,1,slt,hn]
            ))
            align_scores = bn_dense_layer_v2(  # [bs,slf,slt,hn]
                align_scores_pre, hn, True, 0., 'align_score', 'linear',
                False, wd, keep_prob, is_training)
        elif method == 'multi_head':
            seq_dim_to_remove = -2  # !!! because the multi-head dim is inserted as the 2nd dim
            assert hn % head_num == 0
            head_dim = hn // head_num

            q_heads = bn_dense_layer_v2(
                tensor_from, head_dim, True, 0., 'q_heads', 'linear',
                False, wd, keep_prob, is_training, dup_num=head_num)
            k_heads = bn_dense_layer_v2(
                tensor_to, head_dim, True, 0., 'k_heads', 'linear',
                False, wd, keep_prob, is_training, dup_num=head_num)
            q_heads = split_head(q_heads, head_num)  # bs,hd_num,slf,hd_dim
            k_heads = split_head(k_heads, head_num)  # bs,hd_num,slt,hd_dim

            # alignment score
            align_scores = tf.matmul(q_heads, k_heads, transpose_b=True)  # [bs,hd_num,slf,slt]
            align_scores = align_scores / math.sqrt(1. * head_dim)  # [bs,hd_num,slf,slt]
        elif method in ['multi_head_bilinear', 'multi_head_bilinear_shared',
                        'multi_head_only', 'multi_head_linear']:
            seq_dim_to_remove = -2  # !!! because the multi-head dim is inserted as the 2nd dim
            assert hn % head_num == 0
            head_dim = hn // head_num

            q_heads = bn_dense_layer_v2(
                tensor_from, head_dim, True, 0., 'q_heads',
                kwargs.get("activation") or activation, False, wd, keep_prob, is_training,
                dup_num=head_num)
            k_heads = bn_dense_layer_v2(
                tensor_to, head_dim, True, 0., 'k_heads',
                kwargs.get("activation") or activation, False, wd, keep_prob, is_training,
                dup_num=head_num)
            q_heads = split_head(q_heads, head_num)  # bs,hd_num,slf,hd_dim
            k_heads = split_head(k_heads, head_num)  # bs,hd_num,slt,hd_dim

            # alignment score: using bilinear rather than dot product
            # align_scores = tf.matmul(q_heads, k_heads, transpose_b=True)  # [bs,hd_num,slf,slt]
            # align_scores = align_scores / math.sqrt(1. * head_dim)  # [bs,hd_num,slf,slt]
            with tf.variable_scope("bilinear"):
                if method == "multi_head_bilinear":
                    k_heads_map = bn_dense_layer_multi_head(
                        k_heads, head_dim, False, 0., 'k_heads_map', 'linear',
                        False, wd, keep_prob, is_training)
                elif method == "multi_head_bilinear_shared":
                    k_heads_map = bn_dense_layer_v2(
                        k_heads, head_dim, False, 0., 'k_heads_map', 'linear',
                        False, wd, keep_prob, is_training)
                elif method == "multi_head_only":
                    pass
                elif method == "multi_head_linear":
                    k_heads_map = bn_dense_layer_v2(
                        k_heads, head_dim, False, 0., 'k_heads_map', 'linear',
                        False, wd, keep_prob, is_training)
                    q_heads_map = bn_dense_layer_v2(
                        q_heads, head_dim, False, 0., 'q_heads_map', 'linear',
                        False, wd, keep_prob, is_training)
                else:
                    raise AttributeError
                # NOTE: in this version the mapped heads (k_heads_map / q_heads_map)
                # are not used in the matmul below; the raw q/k heads are multiplied.
                align_scores = tf.matmul(q_heads, k_heads, transpose_b=True)
                log_specific_params()
        elif method == 'multi_dim_head':
            assert hn % head_num == 0
            head_dim = hn // head_num

            q_heads = bn_dense_layer_v2(
                tensor_from, head_dim, True, 0., 'q_heads', 'linear',
                False, wd, keep_prob, is_training, dup_num=head_num)
            k_heads = bn_dense_layer_v2(
                tensor_to, head_dim, True, 0., 'k_heads', 'linear',
                False, wd, keep_prob, is_training, dup_num=head_num)
            q_heads = split_head(q_heads, head_num)  # bs,hd_num,slf,hd_dim
            k_heads = split_head(k_heads, head_num)  # bs,hd_num,slt,hd_dim

            # MLP
            q_heads_branch = bn_dense_layer_multi_head(
                q_heads, head_dim, False, 0., 'q_heads_branch', 'linear',
                False, wd, keep_prob, is_training)
            k_heads_branch = bn_dense_layer_multi_head(
                k_heads, head_dim, True, 0., 'k_heads_branch', 'linear',
                False, wd, keep_prob, is_training)
            align_scores_pre = act_name2fn(activation)(tf.add(  # [bs,head,slf,slt,dim]
                tf.expand_dims(q_heads_branch, -2),  # [bs,head,slf,1,dim]
                tf.expand_dims(k_heads_branch, -3)  # [bs,head,1,slt,dim]
            ))
            align_scores_heads = bn_dense_layer_multi_head(  # [bs,hd_num,slf,slt,hd_dim]
                align_scores_pre, head_dim, True, 0., 'align_scores_heads', 'linear',
                False, wd, keep_prob, is_training)
            align_scores = align_scores_heads  # [bs,hd_num,slf,slt,hd_dim]
        elif method == 'bilinear':
            raise NotImplementedError
        else:
            raise AttributeError

        if need_extra_dim:
            align_scores = tf.squeeze(align_scores, [seq_dim_to_remove])
        return align_scores
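# Illustrative sketch (not part of the original module): computing multi-head
# alignment scores between decoder queries and encoder states with compatibility_fn,
# the same method multihead_attention_decoder below uses internally. Tensor names
# are hypothetical.
def _example_compatibility_fn_usage():
    decoder_states = tf.placeholder(tf.float32, [None, None, 768])  # bs,slf,hn
    encoder_states = tf.placeholder(tf.float32, [None, None, 768])  # bs,slt,hn

    scores = compatibility_fn(  # bs,head_num,slf,slt
        decoder_states, encoder_states, method='multi_head',
        hn=768, head_num=12, keep_prob=1., is_training=None)
    return scores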
def multihead_attention_decoder(
        tensor_from, tensor_to, mask_to, mask_direction=None,  # [bs,slf,slt]
        act_name="relu", hn=768, head_num=12, wd=0., is_training=None,
        keep_prob_dense=1., keep_prob_attn=1.,
        tensor_to_prev=None, mask_prev_to=None, scope=None):
    head_dim = hn // head_num
    with tf.variable_scope(scope or "multihead_attention_decoder"):
        # if tensor_to_prev is not None:  # to print the shape
        #     tensor_from = tf.Print(tensor_from, [
        #         tf.shape(tensor_from), tf.shape(tensor_to), tf.shape(mask_to), tf.shape(tensor_to_prev)])
        if tensor_to_prev is None:
            tensor_to_all = tensor_to  # bs,sl,hn
            mask_to_all = mask_to  # bs,sl
        else:
            tensor_to_all = tf.concat([tensor_to_prev, tensor_to], -2)  # bs,psl+1,hn
            if mask_prev_to is None:
                mask_prev_to = tf.cast(
                    tf.ones(get_shape_list(tensor_to_prev, 3)[:2], tf.int32), tf.bool)  # bs,psl
            mask_to_all = tf.concat([mask_prev_to, mask_to], -1)  # bs,psl+1

        attn_scores = compatibility_fn(
            tensor_from, tensor_to_all, method="multi_head", head_num=head_num,
            hn=hn, wd=wd, is_training=is_training, keep_prob=keep_prob_dense,
        )  # [bs,hd_num,slf,slt]
        v_heads = bn_dense_layer_v2(  # bs,slt,hd_dim * hd_num
            tensor_to_all, head_dim, True, 0., 'v_heads', 'linear',
            False, wd, keep_prob_dense, is_training, dup_num=head_num)
        v_heads = split_head(v_heads, head_num)  # bs,hd_num,slt,hd_dim

        # mask the self-attention scores
        attn_scores_mask = tf.expand_dims(mask_to_all, 1)  # bs,1,tsl
        if (mask_direction is not None) and (tensor_to_prev is None):
            attn_scores_mask = tf.logical_and(attn_scores_mask, mask_direction)  # bs,tsl,tsl
        attn_scores_masked = exp_mask_v3(attn_scores, attn_scores_mask, multi_head=True)  # [bs,hd_num,slf,slt]
        attn_prob = tf.nn.softmax(attn_scores_masked)
        attn_prob = dropout(attn_prob, keep_prob_attn, is_training)  # [bs,hd_num,slf,slt]

        v_heads_etd = tf.expand_dims(v_heads, 2)  # bs,hd_num,1,slt,hd_dim
        attn_prob_etd = tf.expand_dims(attn_prob, -1)  # bs,hd_num,slf,slt,1
        attn_res = tf.reduce_sum(v_heads_etd * attn_prob_etd, 3)  # bs,hd_num,slf,hd_dim
        out_prev = combine_head(attn_res)  # bs,fsl,hn

        # if mask_direction is not None and tensor_to_prev is None:
        #     attn_scores = exp_mask_v3(attn_scores, mask_direction, multi_head=True)  # [bs,hd_num,slf,slt]
        # attn_scores = dropout(attn_scores, keep_prob_attn, is_training)
        #
        # attn_res = softsel(  # [bs,hd_num,slf,dhn]
        #     v_heads, attn_scores, mask_to_all,
        #     mask_add_head_dim_for_scores=True,
        #     input_add_multi_head_dim=False,
        #     score_add_hn_dim=True,
        #     axis=3)
        # out_prev = combine_head(attn_res)

        # dense layer
        out = bn_dense_layer_v2(
            out_prev, hn, True, 0., "output_transformer", act_name, False,
            wd, keep_prob_dense, is_training)
        return out
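# Illustrative sketch (not part of the original module): masked self-attention
# inside a decoder layer via multihead_attention_decoder, i.e. tensor_from and
# tensor_to are the same sequence and a lower-triangular mask_direction prevents
# attending to future positions. The mask construction here is an assumption;
# the repository may build its directional mask elsewhere.
def _example_multihead_attention_decoder_usage():
    decoder_states = tf.placeholder(tf.float32, [None, None, 768])  # bs,sl,hn
    decoder_mask = tf.placeholder(tf.bool, [None, None])            # bs,sl

    bs = tf.shape(decoder_states)[0]
    sl = tf.shape(decoder_states)[1]
    # lower-triangular [sl,sl] mask: position i may only attend to positions <= i
    lower_tri = tf.matrix_band_part(tf.ones([sl, sl]), -1, 0)
    direction_mask = tf.tile(tf.expand_dims(tf.cast(lower_tri, tf.bool), 0), [bs, 1, 1])  # bs,sl,sl

    out = multihead_attention_decoder(  # bs,sl,hn
        decoder_states, decoder_states, decoder_mask, mask_direction=direction_mask,
        hn=768, head_num=12, is_training=None)
    return out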