def transformer_decoder(dec_input,
                        enc_output,
                        dec_slf_attn_bias,
                        dec_enc_attn_bias,
                        n_layer,
                        n_head,
                        d_key,
                        d_value,
                        d_model,
                        d_inner_hid,
                        prepostprocess_dropout,
                        attention_dropout,
                        relu_dropout,
                        hidden_act,
                        preprocess_cmd,
                        postprocess_cmd,
                        caches=None,
                        gather_idx=None,
                        param_initializer=None,
                        name='transformer_decoder'):
    """
    The decoder is composed of a stack of identical transformer_decoder_layer layers.

    :param dec_input: (batch_size, tgt_len, emb_dim)
    :param enc_output: (batch_size, n_tokens, emb_dim)
    :param dec_slf_attn_bias: (batch_size, n_head, tgt_len, tgt_len)
    :param dec_enc_attn_bias: (batch_size, n_head, tgt_len, n_tokens)
    """
    for i in range(n_layer):
        # (batch_size, tgt_len, emb_dim)
        dec_output = transformer_decoder_layer(
            dec_input=dec_input,
            enc_output=enc_output,
            slf_attn_bias=dec_slf_attn_bias,
            dec_enc_attn_bias=dec_enc_attn_bias,
            n_head=n_head,
            d_key=d_key,
            d_value=d_value,
            d_model=d_model,
            d_inner_hid=d_inner_hid,
            prepostprocess_dropout=prepostprocess_dropout,
            attention_dropout=attention_dropout,
            relu_dropout=relu_dropout,
            hidden_act=hidden_act,
            preprocess_cmd=preprocess_cmd,
            postprocess_cmd=postprocess_cmd,
            cache=None if caches is None else caches[i],
            gather_idx=gather_idx,
            param_initializer=param_initializer,
            name=name + '_layer_' + str(i))
        dec_input = dec_output

    # add a final layer normalization (pre-norm architecture)
    dec_output = pre_process_layer(out=dec_output,
                                   process_cmd=preprocess_cmd,
                                   dropout_rate=prepostprocess_dropout,
                                   name=name + '_post')

    return dec_output  # (batch_size, tgt_len, emb_dim)
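# Illustrative sketch (not part of the model): the stack above follows the
# pre-norm transformer pattern, assuming the usual preprocess_cmd="n" (layer
# norm) and postprocess_cmd="da" (dropout + residual add). `sublayers` stands
# in for the attention/FFN sublayers; all names here are hypothetical.
def _sketch_pre_norm_stack():
    import numpy as np

    def layer_norm(x, eps=1e-6):
        return (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + eps)

    def pre_norm_stack(x, sublayers):
        for sublayer in sublayers:
            x = x + sublayer(layer_norm(x))  # residual around a normed sublayer
        return layer_norm(x)                 # final norm, like the '_post' call above

    x = np.random.randn(2, 5, 8)  # (batch_size, tgt_len, emb_dim)
    assert pre_norm_stack(x, [np.tanh, np.tanh]).shape == x.shape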
def transformer_encoder(enc_input,
                        attn_bias,
                        n_layer,
                        n_head,
                        d_key,
                        d_value,
                        d_model,
                        d_inner_hid,
                        prepostprocess_dropout,
                        attention_dropout,
                        relu_dropout,
                        hidden_act,
                        preprocess_cmd="n",
                        postprocess_cmd="da",
                        param_initializer=None,
                        name='transformer_encoder',
                        with_post_process=True):
    """
    The encoder is composed of a stack of identical layers returned by calling
    transformer_encoder_layer.
    """
    for i in range(n_layer):
        enc_output = transformer_encoder_layer(
            enc_input,
            None,
            attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            hidden_act,
            preprocess_cmd,
            postprocess_cmd,
            param_initializer=param_initializer,
            name=name + '_layer_' + str(i))
        enc_input = enc_output

    if with_post_process:
        # add a final layer normalization (pre-norm architecture)
        enc_output = pre_process_layer(enc_output,
                                       preprocess_cmd,
                                       prepostprocess_dropout,
                                       name="post_encoder")

    return enc_output
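# Illustrative sketch (not part of the model): attn_bias is an additive mask.
# Padding positions get a large negative value so softmax drives their weights
# to ~0; real positions get 0. Shapes follow the 4-D bias used above; the
# helper name is hypothetical.
def _sketch_padding_attn_bias():
    import numpy as np

    def padding_bias(lengths, max_len, n_head, neg_inf=-1e9):
        mask = np.arange(max_len)[None, :] < np.array(lengths)[:, None]
        bias = np.where(mask, 0.0, neg_inf)  # (batch, max_len)
        return np.tile(bias[:, None, None, :], (1, n_head, max_len, 1))

    bias = padding_bias([3, 5], max_len=5, n_head=2)  # (2, 2, 5, 5)
    scores = np.zeros_like(bias) + bias
    probs = np.exp(scores) / np.exp(scores).sum(-1, keepdims=True)
    assert probs[0, 0, 0, 3:].max() < 1e-6  # padded tokens get ~0 weight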
def self_attention_pooling_layer(enc_input,
                                 attn_bias,
                                 n_head,
                                 d_key,
                                 d_value,
                                 d_model,
                                 d_inner_hid,
                                 prepostprocess_dropout,
                                 attention_dropout,
                                 relu_dropout,
                                 n_block,
                                 preprocess_cmd="n",
                                 postprocess_cmd="da",
                                 name='self_attention_pooling'):
    """
    :param enc_input: (batch_size*n_blocks, n_tokens, emb_dim)
    :param attn_bias: (batch_size*n_blocks, n_head, n_tokens, n_tokens)
    """
    attn_output = multi_head_pooling(
        keys=pre_process_layer(enc_input,
                               preprocess_cmd,
                               prepostprocess_dropout,
                               name=name + '_pre'),  # add layer normalization
        values=None,
        attn_bias=attn_bias,  # (batch_size*n_blocks, n_head, n_tokens, n_tokens)
        d_value=d_value,
        d_model=d_model,
        n_head=n_head,
        dropout_rate=attention_dropout,
        name=name)  # (batch_size*n_blocks, d_model)

    # restore the block dimension: (batch_size, n_block, d_model)
    attn_output = layers.reshape(attn_output, shape=[-1, n_block, d_model])

    # note: is_test is hard-coded to False, so this dropout is also applied at
    # inference time
    pooling_output = layers.dropout(attn_output,
                                    dropout_prob=attention_dropout,
                                    dropout_implementation="upscale_in_train",
                                    is_test=False)

    return pooling_output
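# Illustrative sketch (not part of the model): one plausible reading of what
# multi_head_pooling computes. Each head scores every token with a learned
# vector, softmaxes over tokens (after adding the mask bias), and takes a
# weighted sum of value projections; heads are concatenated back to d_model.
# All weights and names here are hypothetical stand-ins.
def _sketch_multi_head_pooling():
    import numpy as np

    def softmax(x):
        e = np.exp(x - x.max(-1, keepdims=True))
        return e / e.sum(-1, keepdims=True)

    bsz, n_tok, d_model, n_head = 4, 6, 8, 2
    d_value = d_model // n_head
    keys = np.random.randn(bsz, n_tok, d_model)
    bias = np.zeros((bsz, n_head, n_tok))        # one row of the 4-D mask above
    w_score = np.random.randn(d_model, n_head)   # hypothetical scoring weights
    w_value = np.random.randn(d_model, d_model)  # hypothetical value projection

    scores = (keys @ w_score).transpose(0, 2, 1) + bias  # (bsz, n_head, n_tok)
    probs = softmax(scores)
    values = (keys @ w_value).reshape(bsz, n_tok, n_head, d_value)
    values = values.transpose(0, 2, 1, 3)                # (bsz, n_head, n_tok, d_value)
    pooled = (probs[..., None] * values).sum(axis=2)     # (bsz, n_head, d_value)
    assert pooled.reshape(bsz, -1).shape == (bsz, d_model)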
def encode(self, enc_input):
    """Encode the source input."""
    src_word, src_word_pos, src_sen_pos, src_words_slf_attn_bias, \
        src_sents_slf_attn_bias, graph_attn_bias = enc_input

    enc_res = self._gen_enc_input(src_word, src_word_pos, src_sen_pos,
                                  src_words_slf_attn_bias,
                                  src_sents_slf_attn_bias, graph_attn_bias)
    emb_out, src_words_slf_attn_bias, src_sents_slf_attn_bias, graph_attn_bias = \
        enc_res.emb_out, enc_res.word_slf_attn_bias, enc_res.sen_slf_attn_bias, enc_res.graph_attn_bias

    # (batch_size*n_blocks, n_tokens, emb_dim)
    emb_out = layers.reshape(emb_out,
                             shape=[-1, self.max_para_len, self._emb_size])
    # (batch_size*n_blocks, n_head, n_tokens, n_tokens)
    src_words_slf_attn_bias = layers.reshape(
        src_words_slf_attn_bias,
        shape=[-1, self._n_head, self.max_para_len, self.max_para_len])

    # the token-level transformer encoder
    # (batch_size*n_blocks, n_tokens, emb_dim)
    enc_words_out = transformer_encoder(
        enc_input=emb_out,
        attn_bias=src_words_slf_attn_bias,
        n_layer=self._enc_word_layer,
        n_head=self._n_head,
        d_key=self._emb_size // self._n_head,
        d_value=self._emb_size // self._n_head,
        d_model=self._emb_size,
        d_inner_hid=self._emb_size * 4,
        prepostprocess_dropout=self._prepostprocess_dropout,
        attention_dropout=self._attention_dropout,
        relu_dropout=self._prepostprocess_dropout,
        hidden_act=self._hidden_act,
        preprocess_cmd=self._preprocess_command,
        postprocess_cmd=self._postprocess_command,
        param_initializer=self._param_initializer,
        name='transformer_encoder',
        with_post_process=False)

    # the paragraph-level graph encoder
    # (batch_size, n_blocks, emb_dim)
    enc_sents_out = graph_encoder(
        enc_words_output=enc_words_out,  # (batch_size*n_blocks, n_tokens, emb_dim)
        src_words_slf_attn_bias=src_words_slf_attn_bias,  # (batch_size*n_blocks, n_head, n_tokens, n_tokens)
        src_sents_slf_attn_bias=src_sents_slf_attn_bias,  # (batch_size, n_head, n_blocks, n_blocks)
        graph_attn_bias=graph_attn_bias,  # (batch_size, n_head, n_blocks, n_blocks)
        pos_win=self.pos_win,
        graph_layers=self._enc_graph_layer,
        n_head=self._n_head,
        d_key=self._emb_size // self._n_head,
        d_value=self._emb_size // self._n_head,
        d_model=self._emb_size,
        d_inner_hid=self._emb_size * 4,
        prepostprocess_dropout=self._prepostprocess_dropout,
        attention_dropout=self._attention_dropout,
        relu_dropout=self._prepostprocess_dropout,
        hidden_act=self._hidden_act,
        preprocess_cmd=self._preprocess_command,
        postprocess_cmd=self._postprocess_command,
        param_initializer=self._param_initializer,
        name='graph_encoder')

    # final layer normalization for the token-level outputs
    enc_words_out = pre_process_layer(enc_words_out,
                                      self._preprocess_command,
                                      self._prepostprocess_dropout,
                                      name="post_encoder")
    # (batch_size, n_blocks, n_tokens, emb_dim)
    enc_words_out = layers.reshape(
        enc_words_out,
        shape=[-1, self.max_para_num, self.max_para_len, self._emb_size])

    return enc_words_out, enc_sents_out
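# Illustrative sketch (not part of the model): the reshape contract used by
# encode(). Merging the batch and block dims lets the token-level encoder
# treat each paragraph as an independent sequence; the final reshape restores
# the paragraph structure. Sizes are arbitrary examples.
def _sketch_block_reshapes():
    import numpy as np

    batch, n_blocks, n_tokens, d = 2, 3, 4, 8
    emb = np.random.randn(batch, n_blocks * n_tokens, d)

    words = emb.reshape(-1, n_tokens, d)  # (batch*n_blocks, n_tokens, d)
    assert words.shape == (batch * n_blocks, n_tokens, d)

    restored = words.reshape(-1, n_blocks, n_tokens, d)
    assert restored.shape == (batch, n_blocks, n_tokens, d)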
def transformer_encoder_layer(query_input,
                              key_input,
                              attn_bias,
                              n_head,
                              d_key,
                              d_value,
                              d_model,
                              d_inner_hid,
                              prepostprocess_dropout,
                              attention_dropout,
                              relu_dropout,
                              hidden_act,
                              preprocess_cmd="n",
                              postprocess_cmd="da",
                              param_initializer=None,
                              name=''):
    """The encoder layers that can be stacked to form a deep encoder.

    This module consists of a multi-head (self-)attention sublayer followed by
    a position-wise feed-forward network, each accompanied by post_process_layer
    to add residual connection, layer normalization and dropout.
    """
    if key_input is not None:
        key_input = pre_process_layer(key_input,
                                      preprocess_cmd,
                                      prepostprocess_dropout,
                                      name=name + '_pre_att')
    value_input = key_input

    attn_output = multi_head_attention(
        pre_process_layer(query_input,
                          preprocess_cmd,
                          prepostprocess_dropout,
                          name=name + '_pre_att'),
        key_input,
        value_input,
        attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        attention_dropout,
        param_initializer=param_initializer,
        name=name + '_multi_head_att')
    # add dropout and residual connection
    attn_output = post_process_layer(query_input,
                                     attn_output,
                                     postprocess_cmd,
                                     prepostprocess_dropout,
                                     name=name + '_post_att')

    ffd_output = positionwise_feed_forward(
        pre_process_layer(attn_output,
                          preprocess_cmd,
                          prepostprocess_dropout,
                          name=name + '_pre_ffn'),
        d_inner_hid,
        d_model,
        relu_dropout,
        hidden_act,
        param_initializer=param_initializer,
        name=name + '_ffn')
    # add dropout and residual connection
    return post_process_layer(attn_output,
                              ffd_output,
                              postprocess_cmd,
                              prepostprocess_dropout,
                              name=name + '_post_ffn')
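# Illustrative sketch (not part of the model): how the cmd strings are
# conventionally interpreted in Paddle-style transformer code, assuming
# 'a' = residual add, 'n' = layer norm, 'd' = dropout (so "n" pre-processes
# and "da" post-processes). A hypothetical numpy re-implementation:
def _sketch_process_cmd():
    import numpy as np

    def process(prev_out, out, cmd, dropout_rate=0.1, train=False):
        for c in cmd:
            if c == 'a' and prev_out is not None:
                out = out + prev_out
            elif c == 'n':
                out = (out - out.mean(-1, keepdims=True)) / \
                      (out.std(-1, keepdims=True) + 1e-6)
            elif c == 'd' and train and dropout_rate > 0.:
                keep = np.random.rand(*out.shape) >= dropout_rate
                out = out * keep / (1. - dropout_rate)  # upscale_in_train
        return out

    x, y = np.random.randn(2, 4, 8), np.random.randn(2, 4, 8)
    assert process(None, x, "n").shape == x.shape  # pre-process: norm only
    assert process(x, y, "da").shape == y.shape    # post-process: dropout + add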
def graph_encoder(enc_words_output,
                  src_words_slf_attn_bias,
                  src_sents_slf_attn_bias,
                  graph_attn_bias,
                  pos_win,
                  graph_layers,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  prepostprocess_dropout,
                  attention_dropout,
                  relu_dropout,
                  hidden_act,
                  preprocess_cmd="n",
                  postprocess_cmd="da",
                  param_initializer=None,
                  name='graph_encoder'):
    """
    :param enc_words_output: (batch_size*n_blocks, n_tokens, emb_dim)
    :param src_words_slf_attn_bias: (batch_size*n_blocks, n_head, n_tokens, n_tokens)
    :param src_sents_slf_attn_bias: (batch_size, n_head, n_blocks, n_blocks)
    :param graph_attn_bias: (batch_size, n_head, n_blocks, n_blocks)
    :return: (batch_size, n_blocks, emb_dim)
    """
    # pool the token representations into one vector per block
    # (batch_size, n_block, d_model)
    sents_vec = self_attention_pooling_layer(
        enc_input=enc_words_output,
        attn_bias=src_words_slf_attn_bias,
        n_head=n_head,
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        d_inner_hid=d_inner_hid,
        prepostprocess_dropout=prepostprocess_dropout,
        attention_dropout=attention_dropout,
        relu_dropout=relu_dropout,
        n_block=src_sents_slf_attn_bias.shape[2],
        preprocess_cmd="n",
        postprocess_cmd="da",
        name=name + '_pooling')

    enc_input = sents_vec  # (batch_size, n_block, d_model)

    for i in range(graph_layers):
        # (batch_size, n_block, emb_dim)
        enc_output = graph_encoder_layer(
            enc_input=enc_input,  # (batch_size, n_block, emb_dim)
            attn_bias=src_sents_slf_attn_bias,  # (batch_size, n_head, n_block, n_block)
            graph_attn_bias=graph_attn_bias,  # (batch_size, n_head, n_block, n_block)
            pos_win=pos_win,
            n_head=n_head,
            d_key=d_key,
            d_value=d_value,
            d_model=d_model,
            d_inner_hid=d_inner_hid,
            prepostprocess_dropout=prepostprocess_dropout,
            attention_dropout=attention_dropout,
            relu_dropout=relu_dropout,
            hidden_act=hidden_act,
            preprocess_cmd=preprocess_cmd,
            postprocess_cmd=postprocess_cmd,
            param_initializer=param_initializer,
            name=name + '_layer_' + str(i))
        enc_input = enc_output  # (batch_size, n_block, emb_dim)

    # add a final layer normalization (pre-norm architecture)
    enc_output = pre_process_layer(out=enc_output,
                                   process_cmd=preprocess_cmd,
                                   dropout_rate=prepostprocess_dropout,
                                   name=name + '_post')

    return enc_output  # (batch_size, n_block, emb_dim)
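# Illustrative sketch (not part of the model): the shape flow through
# graph_encoder. Pooling collapses the token axis to one vector per
# paragraph, and the graph layers keep (batch, n_blocks, d_model) throughout.
# Mean pooling and the identity stand in for the real sublayers.
def _sketch_graph_encoder_shapes():
    import numpy as np

    batch, n_blocks, n_tokens, d_model = 2, 3, 5, 8
    words = np.random.randn(batch * n_blocks, n_tokens, d_model)

    pooled = words.mean(axis=1)                    # stand-in for multi_head_pooling
    sents = pooled.reshape(-1, n_blocks, d_model)  # (batch, n_blocks, d_model)

    for _ in range(2):                             # graph_layers = 2
        sents = sents + 0.                         # stand-in for graph_encoder_layer
    assert sents.shape == (batch, n_blocks, d_model)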
def graph_encoder_layer(
        enc_input,  # (batch_size, n_block, emb_dim)
        attn_bias,  # (batch_size, n_head, n_block, n_block)
        graph_attn_bias,  # (batch_size, n_head, n_block, n_block)
        pos_win,
        n_head,
        d_key,
        d_value,
        d_model,
        d_inner_hid,
        prepostprocess_dropout,
        attention_dropout,
        relu_dropout,
        hidden_act,
        preprocess_cmd="n",
        postprocess_cmd="da",
        param_initializer=None,
        name=''):
    """
    :param enc_input: (batch_size, n_blocks, emb_dim)
    :param attn_bias: (batch_size, n_head, n_blocks, n_blocks)
    :param graph_attn_bias: (batch_size, n_head, n_blocks, n_blocks)
    """
    # (batch_size, n_block, d_model)
    attn_output = multi_head_structure_attention(
        queries=pre_process_layer(
            out=enc_input,  # add layer normalization
            process_cmd=preprocess_cmd,
            dropout_rate=prepostprocess_dropout,
            name=name + '_pre_attn'),
        keys=None,
        values=None,
        attn_bias=attn_bias,
        graph_attn_bias=graph_attn_bias,
        pos_win=pos_win,
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        n_head=n_head,
        dropout_rate=attention_dropout,
        name=name + '_graph_attn')

    # add dropout and residual connection
    attn_output = post_process_layer(prev_out=enc_input,
                                     out=attn_output,
                                     process_cmd=postprocess_cmd,
                                     dropout_rate=prepostprocess_dropout,
                                     name=name + '_post_attn')

    ffd_output = positionwise_feed_forward(
        x=pre_process_layer(
            out=attn_output,  # add layer normalization
            process_cmd=preprocess_cmd,
            dropout_rate=prepostprocess_dropout,
            name=name + '_pre_ffn'),
        d_inner_hid=d_inner_hid,
        d_hid=d_model,
        dropout_rate=relu_dropout,
        hidden_act=hidden_act,
        param_initializer=param_initializer,
        name=name + '_ffn')

    return post_process_layer(
        prev_out=attn_output,  # add dropout and residual connection
        out=ffd_output,
        process_cmd=postprocess_cmd,
        dropout_rate=prepostprocess_dropout,
        name=name + '_post_ffn')
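# Illustrative sketch (not part of the model): the core of graph-biased
# attention as used above. Beyond the padding mask, a precomputed graph bias
# shifts the logits so strongly connected paragraph pairs attend to each
# other more. The prior that pos_win presumably parameterizes inside
# multi_head_structure_attention is omitted here; names are hypothetical.
def _sketch_graph_attention():
    import numpy as np

    def graph_attention(q, k, v, attn_bias, graph_attn_bias, d_key):
        logits = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_key)
        logits = logits + attn_bias + graph_attn_bias
        probs = np.exp(logits - logits.max(-1, keepdims=True))
        probs = probs / probs.sum(-1, keepdims=True)
        return probs @ v

    b, h, n, dk = 2, 2, 3, 4
    q, k, v = (np.random.randn(b, h, n, dk) for _ in range(3))
    bias = np.zeros((b, h, n, n))
    gbias = np.random.randn(b, h, n, n)  # hypothetical graph weights as a bias
    assert graph_attention(q, k, v, bias, gbias, dk).shape == (b, h, n, dk)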
def transformer_decoder_layer(dec_input,
                              enc_output,
                              slf_attn_bias,
                              dec_enc_attn_bias,
                              n_head,
                              d_key,
                              d_value,
                              d_model,
                              d_inner_hid,
                              prepostprocess_dropout,
                              attention_dropout,
                              relu_dropout,
                              hidden_act,
                              preprocess_cmd,
                              postprocess_cmd,
                              cache=None,
                              gather_idx=None,
                              param_initializer=None,
                              name=''):
    """
    The layer to be stacked in the decoder part.

    :param dec_input: (batch_size, tgt_len, emb_dim)
    :param enc_output: (batch_size, n_tokens, emb_dim)
    :param slf_attn_bias: (batch_size, n_head, tgt_len, tgt_len)
    :param dec_enc_attn_bias: (batch_size, n_head, tgt_len, n_tokens)
    """
    # masked self-attention over the already generated prefix
    # (batch_size, tgt_len, emb_dim)
    slf_attn_output = multi_head_attention(
        queries=pre_process_layer(
            out=dec_input,  # add layer normalization
            process_cmd=preprocess_cmd,
            dropout_rate=prepostprocess_dropout,
            name=name + '_pre_slf_attn'),
        keys=None,
        values=None,
        attn_bias=slf_attn_bias,  # (batch_size, n_head, tgt_len, tgt_len)
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        n_head=n_head,
        dropout_rate=attention_dropout,
        cache=cache,
        gather_idx=gather_idx,
        param_initializer=param_initializer,
        name=name + '_slf_attn')
    # add dropout and residual connection
    # (batch_size, tgt_len, emb_dim)
    slf_attn_output = post_process_layer(prev_out=dec_input,
                                         out=slf_attn_output,
                                         process_cmd=postprocess_cmd,
                                         dropout_rate=prepostprocess_dropout,
                                         name=name + '_post_slf_attn')

    # encoder-decoder (context) attention
    # (batch_size, tgt_len, emb_dim)
    context_attn_output = multi_head_attention(
        queries=pre_process_layer(
            out=slf_attn_output,  # add layer normalization
            process_cmd=preprocess_cmd,
            dropout_rate=prepostprocess_dropout,
            name=name + '_pre_context_attn'),
        keys=enc_output,  # (batch_size, n_tokens, emb_dim)
        values=enc_output,  # (batch_size, n_tokens, emb_dim)
        attn_bias=dec_enc_attn_bias,  # (batch_size, n_head, tgt_len, n_tokens)
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        n_head=n_head,
        dropout_rate=attention_dropout,
        cache=cache,
        gather_idx=gather_idx,
        static_kv=True,
        param_initializer=param_initializer,
        name=name + '_context_attn')
    # add dropout and residual connection
    context_attn_output = post_process_layer(
        prev_out=slf_attn_output,
        out=context_attn_output,
        process_cmd=postprocess_cmd,
        dropout_rate=prepostprocess_dropout,
        name=name + '_post_context_attn')

    ffd_output = positionwise_feed_forward(
        x=pre_process_layer(
            out=context_attn_output,  # add layer normalization
            process_cmd=preprocess_cmd,
            dropout_rate=prepostprocess_dropout,
            name=name + '_pre_ffn'),
        d_inner_hid=d_inner_hid,
        d_hid=d_model,
        dropout_rate=relu_dropout,
        hidden_act=hidden_act,
        param_initializer=param_initializer,
        name=name + '_ffn')
    # add dropout and residual connection
    dec_output = post_process_layer(prev_out=context_attn_output,
                                    out=ffd_output,
                                    process_cmd=postprocess_cmd,
                                    dropout_rate=prepostprocess_dropout,
                                    name=name + '_post_ffn')

    return dec_output  # (batch_size, tgt_len, emb_dim)
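# Illustrative sketch (not part of the model): the role of `cache` during
# incremental decoding. The self-attention cache grows by one K/V position
# per step, while cross-attention (static_kv=True) projects the encoder
# output to K/V once and reuses it at every step. This class is a
# hypothetical stand-in.
def _sketch_decode_cache():
    import numpy as np

    class Cache:
        def __init__(self, d):
            self.k = np.zeros((0, d))
            self.v = np.zeros((0, d))

        def append(self, new_k, new_v):
            self.k = np.concatenate([self.k, new_k[None]], axis=0)
            self.v = np.concatenate([self.v, new_v[None]], axis=0)
            return self.k, self.v

    cache = Cache(d=8)
    for _ in range(3):  # three decoding steps
        k, v = cache.append(np.random.randn(8), np.random.randn(8))
    assert k.shape == (3, 8) and v.shape == (3, 8)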
def graph_decoder_layer(dec_input,
                        enc_words_output,
                        enc_sents_output,
                        slf_attn_bias,
                        dec_enc_words_attn_bias,
                        dec_enc_sents_attn_bias,
                        graph_attn_bias,
                        pos_win,
                        n_head,
                        d_key,
                        d_value,
                        d_model,
                        d_inner_hid,
                        prepostprocess_dropout,
                        attention_dropout,
                        relu_dropout,
                        hidden_act,
                        preprocess_cmd,
                        postprocess_cmd,
                        cache=None,
                        gather_idx=None,
                        param_initializer=None,
                        name=''):
    """
    The layer to be stacked in the graph decoder part.

    :param dec_input: (batch_size, tgt_len, emb_dim)
    :param enc_words_output: (batch_size, n_blocks, n_tokens, emb_dim)
    :param enc_sents_output: (batch_size, n_blocks, emb_dim)
    :param slf_attn_bias: (batch_size, n_head, tgt_len, tgt_len)
    :param dec_enc_words_attn_bias: (batch_size, n_blocks, n_head, tgt_len, n_tokens)
    :param dec_enc_sents_attn_bias: (batch_size, n_head, tgt_len, n_blocks)
    :param graph_attn_bias: (batch_size, n_head, n_blocks, n_blocks)
    """
    # masked self-attention over the already generated prefix
    # (batch_size, tgt_len, emb_dim)
    slf_attn_output = multi_head_attention(
        queries=pre_process_layer(
            out=dec_input,  # add layer normalization
            process_cmd=preprocess_cmd,
            dropout_rate=prepostprocess_dropout,
            name=name + '_pre_attn'),
        keys=None,
        values=None,
        attn_bias=slf_attn_bias,  # (batch_size, n_head, tgt_len, tgt_len)
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        n_head=n_head,
        dropout_rate=attention_dropout,
        cache=cache,
        gather_idx=gather_idx,
        param_initializer=param_initializer,
        name=name + '_attn')
    # add dropout and residual connection
    # (batch_size, tgt_len, emb_dim)
    slf_attn_output = post_process_layer(prev_out=dec_input,
                                         out=slf_attn_output,
                                         process_cmd=postprocess_cmd,
                                         dropout_rate=prepostprocess_dropout,
                                         name=name + '_post_attn')

    # hierarchical attention over word-level and sentence-level encoder outputs
    # (batch_size, tgt_len, emb_dim)
    hier_attn_output = multi_head_hierarchical_attention(
        queries=pre_process_layer(
            out=slf_attn_output,  # add layer normalization
            process_cmd=preprocess_cmd,
            dropout_rate=prepostprocess_dropout,
            name=name + '_pre_hier_attn'),
        keys_w=enc_words_output,  # (batch_size, n_blocks, n_tokens, emb_dim)
        values_w=enc_words_output,  # (batch_size, n_blocks, n_tokens, emb_dim)
        attn_bias_w=dec_enc_words_attn_bias,  # (batch_size, n_blocks, n_head, tgt_len, n_tokens)
        keys_s=enc_sents_output,  # (batch_size, n_blocks, emb_dim)
        values_s=enc_sents_output,  # (batch_size, n_blocks, emb_dim)
        attn_bias_s=dec_enc_sents_attn_bias,  # (batch_size, n_head, tgt_len, n_blocks)
        graph_attn_bias=graph_attn_bias,  # (batch_size, n_head, n_blocks, n_blocks)
        pos_win=pos_win,
        d_key=d_key,
        d_value=d_value,
        d_model=d_model,
        n_head=n_head,
        dropout_rate=attention_dropout,
        cache=cache,
        gather_idx=gather_idx,
        name=name + '_hier_attn')
    # add dropout and residual connection
    hier_attn_output = post_process_layer(prev_out=slf_attn_output,
                                          out=hier_attn_output,
                                          process_cmd=postprocess_cmd,
                                          dropout_rate=prepostprocess_dropout,
                                          name=name + '_post_hier_attn')

    ffd_output = positionwise_feed_forward(
        x=pre_process_layer(
            out=hier_attn_output,  # add layer normalization
            process_cmd=preprocess_cmd,
            dropout_rate=prepostprocess_dropout,
            name=name + '_pre_ffn'),
        d_inner_hid=d_inner_hid,
        d_hid=d_model,
        dropout_rate=relu_dropout,
        hidden_act=hidden_act,
        param_initializer=param_initializer,
        name=name + '_ffn')
    # add dropout and residual connection
    dec_output = post_process_layer(prev_out=hier_attn_output,
                                    out=ffd_output,
                                    process_cmd=postprocess_cmd,
                                    dropout_rate=prepostprocess_dropout,
                                    name=name + '_post_ffn')

    return dec_output  # (batch_size, tgt_len, emb_dim)
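# Illustrative sketch (not part of the model): one plausible reading of
# multi_head_hierarchical_attention for a single query and head. A
# sentence-level distribution over paragraphs reweights word-level contexts
# computed within each paragraph. graph_attn_bias, pos_win, masking and the
# multi-head machinery are omitted; all names are hypothetical.
def _sketch_hierarchical_attention():
    import numpy as np

    def softmax(x):
        e = np.exp(x - x.max(-1, keepdims=True))
        return e / e.sum(-1, keepdims=True)

    d, n_blocks, n_tokens = 8, 3, 5
    q = np.random.randn(d)
    words = np.random.randn(n_blocks, n_tokens, d)  # word-level K/V
    sents = np.random.randn(n_blocks, d)            # sentence-level K/V

    beta = softmax(q @ sents.T / np.sqrt(d))        # weight per paragraph
    ctx = np.zeros(d)
    for b in range(n_blocks):
        alpha = softmax(q @ words[b].T / np.sqrt(d))  # weights within paragraph b
        ctx += beta[b] * (alpha @ words[b])
    assert ctx.shape == (d,)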