# Imports assumed from the tensor2tensor framework whose modules
# (common_attention, common_layers, transformer) are referenced below.
import tensorflow as tf

from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.models import transformer


def attn_over_sent_and_lex_2d(x_slices, pad_remover_combined, hparams):
  """Encoder layer: 2d self-attention and FFN over a (sentence, lexicon) grid."""
  with tf.variable_scope("self_attention"):
    query_antecedent = common_layers.layer_preprocess(x_slices, hparams)
    y_slices = common_attention.multihead_attention_2d(
        query_antecedent=query_antecedent,
        memory_antecedent=None,
        total_key_depth=hparams.attention_key_channels or hparams.hidden_size,
        total_value_depth=hparams.attention_value_channels
        or hparams.hidden_size,
        output_depth=hparams.hidden_size,
        num_heads=hparams.num_heads,
        query_shape=(4, 4),
        memory_flange=(4, 4))
    x_slices = common_layers.layer_postprocess(x_slices, y_slices, hparams)
  with tf.variable_scope("ffn"):
    x0_slices = common_layers.layer_preprocess(x_slices, hparams)
    # Flatten the 2d grid to run the position-wise FFN, then restore the
    # [batch, sent_len, lex_cap, hidden] shape.
    x0_slices, batch_size, sent_len, lex_cap, hid_dim = reshape_2d(x0_slices)
    y_slices = transformer.transformer_ffn_layer(x0_slices, hparams,
                                                 pad_remover_combined)
    y_slices = tf.reshape(y_slices, [batch_size, sent_len, lex_cap, hid_dim])
    x_slices = common_layers.layer_postprocess(x_slices, y_slices, hparams)
  return x_slices
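# Usage sketch (an assumption, not part of the original file): stacking
# `attn_over_sent_and_lex_2d` into a multi-layer encoder over a
# [batch, sent_len, lex_cap, hidden] input. `encode_sent_and_lex_2d` is a
# hypothetical helper name; `hparams.num_hidden_layers` is the standard
# tensor2tensor layer-count hparam.
def encode_sent_and_lex_2d(x_slices, pad_remover_combined, hparams):
  for layer_idx in range(hparams.num_hidden_layers):
    with tf.variable_scope("encoder_layer_%d" % layer_idx):
      x_slices = attn_over_sent_and_lex_2d(x_slices, pad_remover_combined,
                                           hparams)
  # Closing layer_preprocess acts as the final layer norm, as in the standard
  # transformer encoder.
  return common_layers.layer_preprocess(x_slices, hparams)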
def local_attention_2d(x, hparams, attention_type="local_attention_2d"):
  """Local 2d, self attention layer."""
  # self-attention
  with tf.variable_scope("local_2d_self_att"):
    y = common_attention.multihead_attention_2d(
        x,
        None,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        attention_type=attention_type,
        query_shape=hparams.query_shape,
        memory_flange=hparams.memory_flange,
        name="self_attention")
  return y
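# Usage sketch (assumed, not from the original file): `local_attention_2d`
# reads `hparams.query_shape` and `hparams.memory_flange`, so callers must
# provide them before applying it to a [batch, height, width, hidden] tensor:
#
#   hparams.add_hparam("query_shape", (8, 16))
#   hparams.add_hparam("memory_flange", (8, 16))
#   y = local_attention_2d(x, hparams)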
def attn_over_sent_and_lex_2d_dec(x, encoder_output,
                                  decoder_self_attention_bias, hparams):
  """Decoder layer: self-attention, 2d enc-dec attention, and FFN."""
  with tf.variable_scope("self_attention"):
    query_antecedent = common_layers.layer_preprocess(x, hparams)
    y = common_attention.multihead_attention(
        query_antecedent=query_antecedent,
        memory_antecedent=None,
        bias=decoder_self_attention_bias,
        total_key_depth=hparams.attention_key_channels or hparams.hidden_size,
        total_value_depth=hparams.attention_value_channels
        or hparams.hidden_size,
        output_depth=hparams.hidden_size,
        num_heads=hparams.num_heads,
        dropout_rate=hparams.attention_dropout,
        attention_type=hparams.self_attention_type,
        max_relative_position=hparams.max_relative_position)
    x = common_layers.layer_postprocess(x, y, hparams)
  if encoder_output is not None:
    with tf.variable_scope("encdec_attention"):
      query_antecedent = common_layers.layer_preprocess(x, hparams)
      batch_size = tf.shape(encoder_output)[0]
      src_len = tf.shape(encoder_output)[1]
      tgt_len = tf.shape(query_antecedent)[1]
      lex_cap = encoder_output.shape.as_list()[2]
      hid_size = encoder_output.shape.as_list()[3]

      # Lift the decoder query from [batch, tgt_len, hidden] to
      # [batch, tgt_len + src_len, lex_cap, hidden] so that it matches the
      # 2d (sentence x lexicon) layout of the encoder output.
      query_antecedent = tf.expand_dims(query_antecedent, 2)
      query_antecedent = tf.pad(query_antecedent,
                                [[0, 0], [0, 0], [0, lex_cap - 1], [0, 0]])
      query_pad = tf.zeros([batch_size, src_len, lex_cap, hid_size])
      query_antecedent = tf.concat([query_antecedent, query_pad], 1)

      # Pad the encoder output along the length axis so query and memory have
      # the same shape for the 2d attention call.
      memory_antecedent = encoder_output
      memory_pad = tf.zeros([batch_size, tgt_len, lex_cap, hid_size])
      memory_antecedent = tf.concat([memory_antecedent, memory_pad], 1)

      tf.logging.info(
          "dimension of decoder input at the enc-dec attention layer: {0}"
          .format(query_antecedent.get_shape()))
      tf.logging.info(
          "dimension of encoder output at the enc-dec attention layer: {0}"
          .format(memory_antecedent.get_shape()))

      y = common_attention.multihead_attention_2d(
          query_antecedent=query_antecedent,
          memory_antecedent=memory_antecedent,
          total_key_depth=hparams.attention_key_channels
          or hparams.hidden_size,
          total_value_depth=hparams.attention_value_channels
          or hparams.hidden_size,
          output_depth=hparams.hidden_size,
          num_heads=hparams.num_heads,
          attention_type="masked_local_attention_2d",
          query_shape=(4, 4),
          memory_flange=(4, 4))

      tf.logging.info(
          "dimension of enc-dec output: {0}".format(y.get_shape()))

      # Keep only the first lexicon slot and drop the padded positions,
      # recovering a [batch, tgt_len, hidden] tensor.
      y = y[:, :, 0, :]
      y = y[:, :tgt_len, :]
      x = common_layers.layer_postprocess(x, y, hparams)
  with tf.variable_scope("ffn"):
    x0 = common_layers.layer_preprocess(x, hparams)
    y = transformer.transformer_ffn_layer(x0, hparams)
    x = common_layers.layer_postprocess(x, y, hparams)
  return x
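# Usage sketch (an assumption, not part of the original file): stacking
# `attn_over_sent_and_lex_2d_dec` into a multi-layer decoder. The helper name
# `decode_over_sent_and_lex_2d` is hypothetical; `encoder_output` is the
# [batch, src_len, lex_cap, hidden] tensor produced by the 2d encoder.
def decode_over_sent_and_lex_2d(decoder_input, encoder_output,
                                decoder_self_attention_bias, hparams):
  x = decoder_input
  for layer_idx in range(hparams.num_hidden_layers):
    with tf.variable_scope("decoder_layer_%d" % layer_idx):
      x = attn_over_sent_and_lex_2d_dec(x, encoder_output,
                                        decoder_self_attention_bias, hparams)
  return common_layers.layer_preprocess(x, hparams)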