def local_within_block_attention(x,
                                 self_attention_bias,
                                 hparams,
                                 attention_type="local_within_block_mask_right",
                                 q_padding="VALID",
                                 kv_padding="VALID"):
  """Local within block self attention."""
  x_new, x_shape, is_4d = maybe_reshape_4d_to_3d(x)
  with tf.variable_scope("local_within_block"):
    y = common_attention.multihead_attention(
        common_layers.layer_preprocess(x_new, hparams),
        None,
        self_attention_bias,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        attention_type=attention_type,
        block_width=hparams.block_width,
        block_length=hparams.block_length,
        q_padding=q_padding,
        kv_padding=kv_padding,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        name="local_within_block")
    if is_4d:
      y = tf.reshape(y, x_shape)
    return y

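# Shape contract (illustrative note; the exact behavior of
# maybe_reshape_4d_to_3d is assumed here): image-shaped inputs of shape
# [batch, rows, cols, hidden] are flattened to the [batch, length, hidden]
# layout expected by common_attention.multihead_attention, e.g.
# [8, 32, 32, 512] -> [8, 1024, 512]. The returned x_shape and is_4d let each
# helper below reshape its output back to 4D before returning.
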
def full_self_attention(x,
                        self_attention_bias,
                        hparams,
                        q_padding="LEFT",
                        kv_padding="LEFT"):
  """Full self-attention layer."""
  x, x_shape, is_4d = maybe_reshape_4d_to_3d(x)
  if self_attention_bias is not None:
    self_attention_bias = get_self_attention_bias(x)
  with tf.variable_scope("self_att"):
    y = common_attention.multihead_attention(
        x,
        None,
        self_attention_bias,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        q_padding=q_padding,
        kv_padding=kv_padding,
        name="self_att")
    if is_4d:
      y = tf.reshape(y, [x_shape[0], x_shape[1], x_shape[2], x_shape[3]])
      y.set_shape([None, None, None, hparams.hidden_size])
    return y

def dilated_attention_1d(x,
                         hparams,
                         attention_type="masked_dilated_1d",
                         q_padding="VALID",
                         kv_padding="VALID",
                         gap_size=2):
  """Dilated 1d self attention."""
  # self-attention
  x, x_shape, is_4d = maybe_reshape_4d_to_3d(x)
  with tf.variable_scope("masked_dilated_1d"):
    y = common_attention.multihead_attention(
        x,
        None,
        None,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        attention_type=attention_type,
        block_width=hparams.block_width,
        block_length=hparams.block_length,
        q_padding=q_padding,
        kv_padding=kv_padding,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        gap_size=gap_size,
        num_memory_blocks=hparams.num_memory_blocks,
        name="self_attention")
    if is_4d:
      y = tf.reshape(y, x_shape)
      y.set_shape([None, None, None, hparams.hidden_size])
    return y

def local_attention_1d(x,
                       hparams,
                       attention_type="local_unmasked",
                       q_padding="VALID",
                       kv_padding="VALID"):
  """Local 1d self attention."""
  # self-attention
  x, x_shape, is_4d = maybe_reshape_4d_to_3d(x)
  with tf.variable_scope("local_1d_self_att"):
    y = common_attention.multihead_attention(
        x,
        None,
        None,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        attention_type=attention_type,
        shared_rel=hparams.shared_rel,
        block_width=hparams.block_width,
        block_length=hparams.block_length,
        q_padding=q_padding,
        kv_padding=kv_padding,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        make_image_summary=False,
        name="self_attention")
    if is_4d:
      y = tf.reshape(y, x_shape)
    return y

def local_global_attention(x,
                           self_attention_bias,
                           hparams,
                           q_padding="LEFT",
                           kv_padding="LEFT"):
  """Local and global 1d self attention."""
  with tf.variable_scope("self_local_global_att"):
    [x_global, x_local] = tf.split(x, 2, axis=-1)
    split_hidden_size = int(hparams.hidden_size / 2)
    split_heads = int(hparams.num_heads / 2)
    if self_attention_bias is not None:
      self_attention_bias = get_self_attention_bias(x)
    y_global = common_attention.multihead_attention(
        x_global,
        None,
        self_attention_bias,
        hparams.attention_key_channels or split_hidden_size,
        hparams.attention_value_channels or split_hidden_size,
        split_hidden_size,
        split_heads,
        hparams.attention_dropout,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        q_padding=q_padding,
        kv_padding=kv_padding,
        name="global_self_att")
    y_local = common_attention.multihead_attention(
        x_local,
        None,
        None,
        hparams.attention_key_channels or split_hidden_size,
        hparams.attention_value_channels or split_hidden_size,
        split_hidden_size,
        split_heads,
        hparams.attention_dropout,
        attention_type="local_masked",
        block_length=hparams.block_length,
        block_width=hparams.block_width,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        q_padding=q_padding,
        kv_padding=kv_padding,
        name="local_self_att")
    y = tf.concat([y_global, y_local], axis=-1)
    return y

def encdec_attention_1d(x,
                        encoder_output,
                        encoder_decoder_attention_bias,
                        hparams):
  """Encoder-decoder attention."""
  x, x_shape, is_4d = maybe_reshape_4d_to_3d(x)
  encoder_output, _, _ = maybe_reshape_4d_to_3d(encoder_output)
  with tf.variable_scope("encdec_attention"):
    # Encoder Decoder attention
    y = common_attention.multihead_attention(
        x,
        encoder_output,
        encoder_decoder_attention_bias,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        name="encdec_attention")
  if is_4d:
    y = tf.reshape(y, x_shape)
    y.set_shape([None, None, None, hparams.hidden_size])
  return y

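# Illustrative composition (a sketch only; the helper name below is
# hypothetical, and the layer_preprocess/layer_postprocess wiring is an
# assumption based on how these attention helpers are typically combined):
# masked local 1d self-attention on the decoder side, followed by attention
# over the encoder output.
def _example_decoder_attention_block(x, encoder_output,
                                     encoder_decoder_attention_bias, hparams):
  """Sketch: local masked self-attention, then encoder-decoder attention."""
  y = local_attention_1d(
      common_layers.layer_preprocess(x, hparams),
      hparams,
      attention_type="local_mask_right",
      q_padding="LEFT",
      kv_padding="LEFT")
  x = common_layers.layer_postprocess(x, y, hparams)
  y = encdec_attention_1d(
      common_layers.layer_preprocess(x, hparams),
      encoder_output,
      encoder_decoder_attention_bias,
      hparams)
  return common_layers.layer_postprocess(x, y, hparams)
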
def evolved_transformer_encoder(encoder_input,
                                encoder_self_attention_bias,
                                hparams,
                                name="encoder",
                                nonpadding=None,
                                save_weights_to=None,
                                make_image_summary=True,
                                losses=None,
                                attn_bias_for_padding=None):
  """Evolved Transformer encoder. See arxiv.org/abs/1901.11117 for more details.

  Note: Pad remover is not supported.

  Args:
    encoder_input: a Tensor.
    encoder_self_attention_bias: bias Tensor for self-attention (see
      common_attention.attention_bias()).
    hparams: hyperparameters for model.
    name: a string.
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding. This must either be passed
      in, which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias. The knowledge about padding is used for
      pad_remover (efficiency) and to mask out padding in convolutional
      layers.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: Not used.
    attn_bias_for_padding: Padded attention bias in case a unidirectional
      encoder is being used where future attention is masked.

  Returns:
    Tensor encoder output.
  """
  del losses

  hidden_state = encoder_input
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))

  with tf.variable_scope(name):
    if nonpadding is not None:
      padding = 1.0 - nonpadding
    else:
      attention_bias = encoder_self_attention_bias
      if attn_bias_for_padding is not None:
        attention_bias = attn_bias_for_padding
      padding = common_attention.attention_bias_to_padding(attention_bias)
      nonpadding = 1.0 - padding

    for layer in range(
        hparams.num_encoder_layers or hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):

        with tf.variable_scope("gated_linear_unit"):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          values = tf.layers.dense(hidden_state, hparams.hidden_size)
          gates = tf.layers.dense(
              hidden_state, hparams.hidden_size, activation=tf.nn.sigmoid)
          hidden_state = values * gates

          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

        with tf.variable_scope("conv_branches"):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          # Mask padding from conv layers.
          mask = tf.tile(
              tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size])
          hidden_state *= mask

          left_output_dim = int(hparams.hidden_size * 4)
          left_state = tf.layers.dense(
              hidden_state, left_output_dim, activation=tf.nn.relu)
          left_state = tf.nn.dropout(
              left_state, 1 - hparams.layer_prepostprocess_dropout)

          right_output_dim = int(hparams.hidden_size / 2)
          right_state = tf.layers.conv1d(
              hidden_state,
              right_output_dim,
              3,
              padding="SAME",
              name="standard_conv_3x1",
              activation=tf.nn.relu)
          right_state = tf.nn.dropout(
              right_state, 1 - hparams.layer_prepostprocess_dropout)

          right_state = tf.pad(
              right_state,
              [[0, 0], [0, 0], [0, left_output_dim - right_output_dim]],
              constant_values=0)
          hidden_state = left_state + right_state

          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          # Mask padding from conv layer.
          mask = tf.tile(
              tf.expand_dims(nonpadding, 2), [1, 1, left_output_dim])
          hidden_state *= mask

          separable_conv_9x1 = tf.layers.SeparableConv1D(
              right_output_dim, 9, padding="SAME", name="separable_conv_9x1")
          hidden_state = separable_conv_9x1.apply(hidden_state)
          hidden_state = tf.pad(
              hidden_state,
              [[0, 0], [0, 0], [0, hparams.hidden_size - right_output_dim]],
              constant_values=0)

          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

        with tf.variable_scope("self_attention"):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          hidden_state = common_attention.multihead_attention(
              hidden_state,
              None,
              encoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              vars_3d=hparams.get("attention_variables_3d"),
              activation_dtype=hparams.get("activation_dtype", "float32"),
              weight_dtype=hparams.get("weight_dtype", "float32"))

          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

        with tf.variable_scope("dense_layers"):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          hidden_state = tf.layers.dense(
              hidden_state,
              int(hparams.hidden_size * 4),
              activation=tf.nn.relu)
          hidden_state = tf.nn.dropout(
              hidden_state, 1 - hparams.layer_prepostprocess_dropout)

          hidden_state = tf.layers.dense(hidden_state, hparams.hidden_size)
          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

    # If normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(hidden_state, hparams)

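# Usage sketch (hedged example; common_attention.embedding_to_padding and
# common_attention.attention_bias_ignore_padding are standard helpers in this
# library, and `hparams` is assumed to be a full Transformer hyperparameter
# set providing hidden_size, num_heads, attention_dropout,
# layer_prepostprocess_dropout, self_attention_type, etc.):
#
#   padding = common_attention.embedding_to_padding(encoder_input)
#   encoder_self_attention_bias = (
#       common_attention.attention_bias_ignore_padding(padding))
#   encoder_output = evolved_transformer_encoder(
#       encoder_input, encoder_self_attention_bias, hparams, name="encoder")
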
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder",
                        nonpadding=None,
                        save_weights_to=None,
                        make_image_summary=True,
                        losses=None,
                        attn_bias_for_padding=None):
  """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention (see
      common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding. This must either be passed
      in, which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias. The knowledge about padding is used for
      pad_remover (efficiency) and to mask out padding in convolutional
      layers.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: optional list onto which to append extra training losses
    attn_bias_for_padding: Padded attention bias in case a unidirectional
      encoder is being used where future attention is masked.

  Returns:
    y: a Tensor
  """
  x = encoder_input
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS,
      value=hparams.num_encoder_layers or hparams.num_hidden_layers)
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_ATTENTION_DROPOUT,
      value=hparams.attention_dropout)
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_ATTENTION_DENSE,
      value={
          "use_bias": "false",
          "num_heads": hparams.num_heads,
          "hidden_size": hparams.hidden_size
      })

  with tf.variable_scope(name):
    if nonpadding is not None:
      padding = 1.0 - nonpadding
    else:
      attention_bias = encoder_self_attention_bias
      if attn_bias_for_padding is not None:
        attention_bias = attn_bias_for_padding
      padding = common_attention.attention_bias_to_padding(attention_bias)
      nonpadding = 1.0 - padding
    pad_remover = None
    if hparams.use_pad_remover and not common_layers.is_xla_compiled():
      pad_remover = expert_utils.PadRemover(padding)
    for layer in range(
        hparams.num_encoder_layers or hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              encoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              vars_3d=hparams.get("attention_variables_3d"),
              activation_dtype=hparams.get("activation_dtype", "float32"),
              weight_dtype=hparams.get("weight_dtype", "float32"))
          x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams),
              hparams,
              pad_remover,
              conv_padding="SAME",
              nonpadding_mask=nonpadding,
              losses=losses)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_NORM,
        value={"hidden_size": hparams.hidden_size})
    return common_layers.layer_preprocess(x, hparams)

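# Usage sketch (hedged; assumes the companion transformer_prepare_encoder
# helper, which takes already-embedded inputs, adds target-space and
# positional signals, and returns the attention bias tensors consumed by this
# stack):
#
#   encoder_input, self_attention_bias, _ = transformer_prepare_encoder(
#       inputs, target_space, hparams, features=features)
#   encoder_input = tf.nn.dropout(
#       encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
#   encoder_output = transformer_encoder(
#       encoder_input, self_attention_bias, hparams)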