def f(x, side_input):
  """f(x) for reversible layer, self-attention and enc-dec attention."""
  decoder_self_attention_bias = side_input[0]
  encoder_decoder_attention_bias = side_input[1]
  encoder_output = side_input[2]

  old_hid_size = hparams.hidden_size
  hparams.hidden_size = old_hid_size // 2

  with tf.variable_scope("self_attention"):
    y = common_attention.multihead_attention(
        common_layers.layer_preprocess(x, hparams), None,
        decoder_self_attention_bias,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
    y = common_layers.layer_postprocess(x, y, hparams)
  if encoder_output is not None:
    with tf.variable_scope("encdec_attention"):
      y = common_attention.multihead_attention(
          common_layers.layer_preprocess(x, hparams), encoder_output,
          encoder_decoder_attention_bias,
          hparams.attention_key_channels or hparams.hidden_size,
          hparams.attention_value_channels or hparams.hidden_size,
          hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
      y = common_layers.layer_postprocess(x, y, hparams)
  hparams.hidden_size = old_hid_size
  return y
def transformer_encoder_layers(inputs,
                               num_layers,
                               hparams,
                               attention_type=AttentionType.GLOBAL,
                               self_attention_bias=None,
                               q_padding="VALID",
                               kv_padding="VALID",
                               name="transformer"):
  """Multi layer transformer encoder."""
  x = inputs
  x = tf.nn.dropout(x, 1.0 - hparams.layer_prepostprocess_dropout)

  for layer in range(num_layers):
    # attention layers + skip connections
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      if attention_type == AttentionType.LOCAL_2D:
        y = local_attention_2d(common_layers.layer_preprocess(x, hparams),
                               hparams,
                               attention_type="local_attention_2d")
      elif attention_type == AttentionType.LOCAL_1D:
        y = local_attention_1d(common_layers.layer_preprocess(x, hparams),
                               hparams,
                               attention_type="local_unmasked",
                               q_padding=q_padding,
                               kv_padding=kv_padding)
      elif attention_type == AttentionType.GLOBAL:
        y = full_self_attention(common_layers.layer_preprocess(x, hparams),
                                self_attention_bias,
                                hparams,
                                q_padding=q_padding,
                                kv_padding=kv_padding)
      x = common_layers.layer_postprocess(x, y, hparams)
      # feed-fwd layer + skip connections
      y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams)
      x = common_layers.layer_postprocess(x, y, hparams)
  return common_layers.layer_preprocess(x, hparams)
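# Every snippet in this collection leans on the same pre-norm residual idiom:
# y = sublayer(layer_preprocess(x)); x = layer_postprocess(x, y). A minimal
# self-contained sketch of that idiom, assuming the common T2T settings
# layer_preprocess_sequence="n" (layer norm) and layer_postprocess_sequence="da"
# (dropout, then residual add); the helper name below is illustrative only.
def _pre_norm_residual_sketch(x, sublayer_fn, dropout_rate=0.1, epsilon=1e-6):
  """x + dropout(sublayer(layer_norm(x))) -- what the "n"/"da" sequences do."""
  mean, variance = tf.nn.moments(x, axes=[-1], keep_dims=True)
  normalized = (x - mean) * tf.rsqrt(variance + epsilon)  # "n": layer norm
  # (the trainable scale/bias of a full layer norm are omitted for brevity)
  y = sublayer_fn(normalized)
  y = tf.nn.dropout(y, 1.0 - dropout_rate)                # "d": dropout
  return x + y                                            # "a": residual add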
def image_encoder(image_feat,
                  hparams,
                  name="image_encoder",
                  save_weights_to=None,
                  make_image_summary=True):
  """A stack of self attention layers."""
  x = image_feat
  image_hidden_size = hparams.image_hidden_size or hparams.hidden_size
  image_filter_size = hparams.image_filter_size or hparams.filter_size
  with tf.variable_scope(name):
    for layer in range(hparams.num_encoder_layers or
                       hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        with tf.variable_scope("self_attention"):
          y = vqa_layers.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              None,
              hparams.attention_key_channels or image_hidden_size,
              hparams.attention_value_channels or image_hidden_size,
              image_hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.image_self_attention_type,
              save_weights_to=save_weights_to,
              make_image_summary=make_image_summary,
              scale_dotproduct=hparams.scale_dotproduct,
          )
          utils.collect_named_outputs(
              "norms", "image_feat_self_attention_%d" % layer,
              tf.norm(y, axis=-1))
          x = common_layers.layer_postprocess(x, y, hparams)
          utils.collect_named_outputs(
              "norms", "image_feat_self_attention_postprocess_%d" % layer,
              tf.norm(x, axis=-1))
        with tf.variable_scope("ffn"):
          y = common_layers.dense_relu_dense(
              common_layers.layer_preprocess(x, hparams),
              image_filter_size,
              image_hidden_size,
              dropout=hparams.relu_dropout,
          )
          utils.collect_named_outputs(
              "norms", "image_feat_ffn_%d" % layer, tf.norm(y, axis=-1))
          x = common_layers.layer_postprocess(x, y, hparams)
          utils.collect_named_outputs(
              "norms", "image_feat_ffn_postprocess_%d" % layer,
              tf.norm(x, axis=-1))
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
def residual_block_layer(inputs, hparams):
  """Residual block over inputs.

  Runs a residual block consisting of
    conv: kernel_size x kernel_size
    conv: 1x1
    dropout, add and normalize according to hparams.layer_postprocess_sequence.

  Args:
    inputs: Tensor of shape [batch, height, width, hparams.hidden_size].
    hparams: tf.contrib.training.HParams.

  Returns:
    Tensor of shape [batch, height, width, hparams.hidden_size].
  """
  kernel = (hparams.res_kernel_size, hparams.res_kernel_size)
  x = inputs
  for i in range(hparams.num_res_layers):
    with tf.variable_scope("res_conv_%d" % i):
      # kernel_size x kernel_size conv block
      y = common_layers.conv_block(
          common_layers.layer_norm(x, hparams.hidden_size, name="lnorm"),
          hparams.hidden_size, [((1, 1), kernel)],
          strides=(1, 1),
          padding="SAME",
          name="residual_conv")
      # 1x1 conv block
      y = common_layers.conv_block(
          y,
          hparams.hidden_size, [((1, 1), (1, 1))],
          strides=(1, 1),
          padding="SAME",
          name="residual_dense")
      x = common_layers.layer_postprocess(x, y, hparams)
  return x
def transformer_layers_sharded(dp,
                               ps_devices,
                               inputs,
                               num_layers,
                               hparams,
                               self_attention_bias=None,
                               enc_output=None,
                               attention_type=AttentionType.GLOBAL,
                               name="transformer"):
  """Multi layer transformer, sharded by the data parallelism dp."""
  x = inputs
  extra_loss = tf.constant(0.0)
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  expert_fn = expert_utils.ffn_expert_fn(
      hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
  x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
  for layer in range(num_layers):
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      # self-attention
      if attention_type == AttentionType.LOCAL_2D:
        y = dp(local_attention_2d(common_layers.layer_preprocess(x, hparams),
                                  hparams,
                                  attention_type="masked_local_attention_2d"))
      elif attention_type == AttentionType.LOCAL_1D:
        y = dp(local_attention_1d(common_layers.layer_preprocess(x, hparams),
                                  hparams,
                                  attention_type="local_mask_right",
                                  q_padding="LEFT",
                                  kv_padding="LEFT"))
      elif attention_type == AttentionType.GLOCAL:
        y = dp(local_global_attention(
            common_layers.layer_preprocess(x, hparams), self_attention_bias,
            hparams, q_padding="LEFT", kv_padding="LEFT"))
      elif attention_type == AttentionType.GLOBAL:
        self_attention_bias = dp(get_self_attention_bias(x))
        y = dp(full_self_attention(common_layers.layer_preprocess(x, hparams),
                                   self_attention_bias,
                                   hparams,
                                   q_padding="LEFT",
                                   kv_padding="LEFT"))
      x = common_layers.layer_postprocess(x, y, hparams)
      if enc_output is not None:
        y = dp(encdec_attention_1d(common_layers.layer_preprocess(x, hparams),
                                   enc_output, None, hparams))
        x = dp(common_layers.layer_postprocess, x, y, hparams)
      with tf.variable_scope("ffn"):
        if str(layer) in hparams.moe_layers_decoder.split(","):
          y, loss = expert_utils.distributed_moe(
              dp,
              ps_devices,
              common_layers.layer_preprocess(x, hparams),
              hparams.mode == tf.estimator.ModeKeys.TRAIN,
              input_size=hparams.hidden_size,
              expert_fn=expert_fn,
              num_experts=hparams.moe_num_experts,
              k=hparams.moe_k,
              loss_coef=hparams.moe_loss_coef)
          extra_loss += loss
          x = dp(common_layers.layer_postprocess, x, y, hparams)
        else:
          y = dp(ffn_layer, common_layers.layer_preprocess(x, hparams), hparams)
          x = dp(common_layers.layer_postprocess, x, y, hparams)
  return dp(common_layers.layer_preprocess, x, hparams), extra_loss
def transformer_decoder_layers(inputs,
                               encoder_output,
                               bias,
                               num_layers,
                               hparams,
                               attention_type=AttentionType.LOCAL_2D,
                               name="transformer"):
  """Multi layer transformer."""
  x = inputs
  x = tf.nn.dropout(x, 1.0 - hparams.layer_prepostprocess_dropout)
  if attention_type == AttentionType.DILATED:
    assert len(hparams.gap_sizes) == num_layers
  for layer in xrange(num_layers):
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      # self-attention + skip connections
      if attention_type == AttentionType.LOCAL_2D:
        y = local_attention_2d(common_layers.layer_preprocess(x, hparams),
                               hparams,
                               attention_type="masked_local_attention_2d")
      elif attention_type == AttentionType.LOCAL_1D:
        y = local_attention_1d(common_layers.layer_preprocess(x, hparams),
                               bias,
                               hparams,
                               attention_type="local_mask_right",
                               q_padding="LEFT",
                               kv_padding="LEFT")
      elif attention_type == AttentionType.GLOCAL:
        y = local_global_attention(common_layers.layer_preprocess(x, hparams),
                                   bias,
                                   hparams,
                                   q_padding="LEFT",
                                   kv_padding="LEFT")
      elif attention_type == AttentionType.DILATED:
        y = dilated_attention_1d(common_layers.layer_preprocess(x, hparams),
                                 bias,
                                 hparams,
                                 q_padding="LEFT",
                                 kv_padding="LEFT",
                                 gap_size=hparams.gap_sizes[layer])
      elif attention_type == AttentionType.GLOBAL:
        y = full_self_attention(common_layers.layer_preprocess(x, hparams),
                                bias,
                                hparams,
                                q_padding="LEFT",
                                kv_padding="LEFT")
      x = common_layers.layer_postprocess(x, y, hparams)
      # enc-dec attention + skip connections
      if encoder_output is not None:
        y = encdec_attention_1d(common_layers.layer_preprocess(x, hparams),
                                encoder_output, hparams)
        x = common_layers.layer_postprocess(x, y, hparams)
      # feed-fwd layers + skip connections
      y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams)
      x = common_layers.layer_postprocess(x, y, hparams)
  return common_layers.layer_preprocess(x, hparams)
def g(x):
  """g(x) for reversible layer, feed-forward layer."""
  old_hid_size = hparams.hidden_size
  hparams.hidden_size = old_hid_size // 2

  with tf.variable_scope("ffn"):
    y = transformer.transformer_ffn_layer(
        common_layers.layer_preprocess(x, hparams), hparams)
    y = common_layers.layer_postprocess(x, y, hparams)

  hparams.hidden_size = old_hid_size
  return y
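# The f in the first snippet and this g are the two residual functions of a
# reversible layer. A hedged sketch of how such a pair is typically wired
# together, assuming TF 1.x's tf.contrib.layers.rev_block and an input whose
# channels are split into two halves (which is why both closures temporarily
# halve hparams.hidden_size); the helper name is illustrative only.
def _reversible_decoder_sketch(x, decoder_self_attention_bias,
                               encoder_decoder_attention_bias, encoder_output,
                               hparams):
  """Illustrative only: runs an f/g pair as a reversible residual stack."""
  x1, x2 = tf.split(x, 2, axis=-1)
  y1, y2 = tf.contrib.layers.rev_block(
      x1, x2, f, g,
      num_layers=hparams.num_hidden_layers,
      f_side_input=[decoder_self_attention_bias,
                    encoder_decoder_attention_bias,
                    encoder_output],
      is_training=hparams.mode == tf.estimator.ModeKeys.TRAIN)
  return tf.concat([y1, y2], axis=-1)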
def attend(x, source, hparams, name):
  with tf.variable_scope(name):
    x = tf.squeeze(x, axis=2)
    if len(source.get_shape()) > 3:
      source = tf.squeeze(source, axis=2)
    source = common_attention.add_timing_signal_1d(source)
    y = common_attention.multihead_attention(
        common_layers.layer_preprocess(x, hparams), source, None,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
    res = common_layers.layer_postprocess(x, y, hparams)
    return tf.expand_dims(res, axis=2)
def compress_self_attention_layer(x, hparams, name=None):
  """Attend function."""
  with tf.variable_scope(name, default_name="compress_self_attention"):
    x, xshape, _ = cia.maybe_reshape_4d_to_3d(x)
    y = common_attention.multihead_attention(
        common_layers.layer_preprocess(x, hparams), None, None,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
    res = common_layers.layer_postprocess(x, y, hparams)
    return tf.reshape(res, xshape)
def attention_lm_decoder(decoder_input,
                         decoder_self_attention_bias,
                         hparams,
                         name="decoder"):
  """A stack of attention_lm layers.

  Args:
    decoder_input: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string

  Returns:
    y: a Tensor
  """
  x = decoder_input
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams), None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size, hparams.num_heads,
              hparams.attention_dropout)
          x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = common_layers.conv_hidden_relu(
              common_layers.layer_preprocess(x, hparams),
              hparams.filter_size,
              hparams.hidden_size,
              dropout=hparams.relu_dropout)
          x = common_layers.layer_postprocess(x, y, hparams)
    return common_layers.layer_preprocess(x, hparams)
def attend(x, source, hparams, name):
  """Attend function."""
  with tf.variable_scope(name):
    # x = tf.squeeze(x, axis=2)
    x, xshape, _ = cia.maybe_reshape_4d_to_3d(x)
    if len(source.get_shape()) > 3:
      source = tf.squeeze(source, axis=2)
    source = common_attention.add_timing_signal_1d(source)
    y = common_attention.multihead_attention(
        common_layers.layer_preprocess(x, hparams), source, None,
        hparams.attention_key_channels or hparams.hidden_size,
        hparams.attention_value_channels or hparams.hidden_size,
        hparams.hidden_size, hparams.num_heads, hparams.attention_dropout)
    res = common_layers.layer_postprocess(x, y, hparams)
    return tf.reshape(res, xshape)
def transformer_encoder_gate(encoder_input,
                             encoder_self_attention_bias,
                             hparams,
                             name="encoder"):
  """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string

  Returns:
    y: a Tensor
  """
  x = encoder_input
  with tf.variable_scope(name):
    pad_remover = None
    if hparams.use_pad_remover:
      pad_remover = expert_utils.PadRemover(
          common_attention.attention_bias_to_padding(
              encoder_self_attention_bias))
    for layer in xrange(hparams.num_encoder_layers or
                        hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              encoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position)
          x = common_layers.layer_postprocess(x, y, hparams)

          gate_filter = tf.get_variable(
              'gate_layer_%d' % layer,
              [1, hparams.hidden_size, hparams.hidden_size],
              tf.float32,
              initializer=tf.contrib.layers.xavier_initializer())
          gate_x = tf.tanh(tf.nn.conv1d(x, gate_filter, 1, 'SAME'))
          x *= gate_x
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams), hparams, pad_remover)
          x = common_layers.layer_postprocess(x, y, hparams)

          gate_filter = tf.get_variable(
              'gate_layer_%d' % layer,
              [1, hparams.hidden_size, hparams.hidden_size],
              tf.float32,
              initializer=tf.contrib.layers.xavier_initializer())
          gate_x = tf.tanh(tf.nn.conv1d(x, gate_filter, 1, 'SAME'))
          x *= gate_x
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
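# The gating used above (a width-1 conv followed by tanh, multiplied back into
# x) acts position-wise, so it is equivalent to a dense layer without bias.
# A minimal illustrative sketch; the helper name is an assumption:
def _tanh_gate_sketch(x, hidden_size, name="gate"):
  """Returns x * tanh(W x), applied independently at every position."""
  gate = tf.tanh(tf.layers.dense(x, hidden_size, use_bias=False, name=name))
  return x * gate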
def transformer_decoder_fast(decoder_input,
                             encoder_output,
                             decoder_self_attention_bias,
                             encoder_decoder_attention_bias,
                             hparams,
                             cache=None,
                             name="decoder"):
  """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
      attentions, used for fast decoding.
    name: a string

  Returns:
    y: a Tensor
  """
  x = decoder_input
  with tf.variable_scope(name):
    for layer in range(hparams.num_decoder_layers or
                       hparams.num_hidden_layers):
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              cache=layer_cache)
          x = common_layers.layer_postprocess(x, y, hparams)
        if encoder_output is not None:
          with tf.variable_scope("encdec_attention"):
            y = multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                cache=layer_cache)
            x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer.transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams), hparams)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        name="decoder",
                        ctxly=None,
                        model_config=None):
  """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
      attentions, used for fast decoding.
    name: a string

  Returns:
    y: a Tensor
  """
  x = decoder_input
  contexts = None
  if ctxly is None:
    ctxly = (hparams.num_decoder_layers or hparams.num_hidden_layers) - 1
  print('Use Context layer %s for output' % ctxly)
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_decoder_layers or
                        hparams.num_hidden_layers):
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              cache=layer_cache)
          x = common_layers.layer_postprocess(x, y, hparams)
        if encoder_output is not None:
          with tf.variable_scope("encdec_attention"):
            # TODO(llion): Add caching.
            y = common_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                model_config=model_config,
                decoder_input=decoder_input)
            if layer == ctxly:
              contexts = tf.identity(y)
            x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams),
              hparams,
              conv_padding="LEFT")
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams), contexts
def image_question_encoder(encoder_inputs,
                           encoder_self_attention_bias,
                           hparams,
                           query=None,
                           name="image_question_encoder",
                           save_weights_to=None,
                           make_image_summary=True):
  """A stack of self attention layers."""
  x = encoder_inputs
  with tf.variable_scope(name):
    for layer in range(hparams.num_encoder_layers or
                       hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        with tf.variable_scope("self_attention"):
          y = vqa_layers.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              encoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              block_length=hparams.block_length,
              save_weights_to=save_weights_to,
              make_image_summary=make_image_summary,
              scale_dotproduct=hparams.scale_dotproduct,
          )
          utils.collect_named_outputs(
              "norms", "encoder_self_attention_%d" % layer,
              tf.norm(y, axis=-1))
          x = common_layers.layer_postprocess(x, y, hparams)
          utils.collect_named_outputs(
              "norms", "encoder_self_attention_postprocess_%d" % layer,
              tf.norm(x, axis=-1))
        if query is not None:
          with tf.variable_scope("encdec_attention"):
            y = common_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                query,
                None,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                attention_type=hparams.self_attention_type,
                block_length=hparams.block_length,
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary,
                scale_dotproduct=hparams.scale_dotproduct,
            )
            utils.collect_named_outputs(
                "norms", "encoder_decoder_attention_%d" % layer,
                tf.norm(y, axis=-1))
            x = common_layers.layer_postprocess(x, y, hparams)
            utils.collect_named_outputs(
                "norms", "encoder_decoder_attention_post_%d" % layer,
                tf.norm(x, axis=-1))
        with tf.variable_scope("ffn"):
          y = common_layers.dense_relu_dense(
              common_layers.layer_preprocess(x, hparams),
              hparams.filter_size,
              hparams.hidden_size,
              dropout=hparams.relu_dropout,
          )
          utils.collect_named_outputs(
              "norms", "encoder_ffn_%d" % layer, tf.norm(y, axis=-1))
          x = common_layers.layer_postprocess(x, y, hparams)
          utils.collect_named_outputs(
              "norms", "encoder_ffn_postprocess_%d" % layer,
              tf.norm(x, axis=-1))
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
def transformer_decoder_layers(inputs,
                               encoder_output,
                               num_layers,
                               hparams,
                               self_attention_bias=None,
                               encoder_decoder_attention_bias=None,
                               attention_type=AttentionType.LOCAL_2D,
                               name="transformer"):
  """Multi layer transformer."""
  x = inputs
  x = tf.nn.dropout(x, 1.0 - hparams.layer_prepostprocess_dropout)
  if attention_type == AttentionType.DILATED:
    assert len(hparams.gap_sizes) == num_layers
  for layer in xrange(num_layers):
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      # self-attention + skip connections
      if attention_type == AttentionType.LOCAL_2D:
        y = local_attention_2d(common_layers.layer_preprocess(x, hparams),
                               hparams,
                               attention_type="masked_local_attention_2d")
      elif attention_type == AttentionType.LOCAL_1D:
        y = local_attention_1d(common_layers.layer_preprocess(x, hparams),
                               hparams,
                               attention_type="local_mask_right",
                               q_padding="LEFT",
                               kv_padding="LEFT")
      elif attention_type == AttentionType.LOCAL_BLOCK:
        y = local_within_block_attention(
            common_layers.layer_preprocess(x, hparams),
            self_attention_bias,
            hparams,
            attention_type="local_within_block_mask_right",
            q_padding="LEFT",
            kv_padding="LEFT")
      elif attention_type == AttentionType.GLOCAL:
        y = local_global_attention(common_layers.layer_preprocess(x, hparams),
                                   self_attention_bias,
                                   hparams,
                                   q_padding="LEFT",
                                   kv_padding="LEFT")
      elif attention_type == AttentionType.DILATED:
        y = dilated_attention_1d(common_layers.layer_preprocess(x, hparams),
                                 hparams,
                                 q_padding="LEFT",
                                 kv_padding="LEFT",
                                 gap_size=hparams.gap_sizes[layer])
      elif attention_type == AttentionType.GLOBAL:
        y = full_self_attention(common_layers.layer_preprocess(x, hparams),
                                self_attention_bias,
                                hparams,
                                q_padding="LEFT",
                                kv_padding="LEFT")
      x = common_layers.layer_postprocess(x, y, hparams)
      # enc-dec attention + skip connections
      if encoder_output is not None:
        y = encdec_attention_1d(common_layers.layer_preprocess(x, hparams),
                                encoder_output,
                                encoder_decoder_attention_bias, hparams)
        x = common_layers.layer_postprocess(x, y, hparams)
      # feed-fwd layers + skip connections
      y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams)
      x = common_layers.layer_postprocess(x, y, hparams)
  return common_layers.layer_preprocess(x, hparams)
def evolved_transformer_decoder(decoder_input,
                                encoder_output,
                                decoder_self_attention_bias,
                                encoder_decoder_attention_bias,
                                hparams,
                                cache=None,
                                decode_loop_step=None,
                                name="decoder",
                                nonpadding=None,
                                save_weights_to=None,
                                make_image_summary=True,
                                losses=None):
  """Evolved Transformer decoder. See arxiv.org/abs/1901.11117 for details.

  Args:
    decoder_input: a Tensor.
    encoder_output: a Tensor.
    decoder_self_attention_bias: bias Tensor for self-attention (see
      common_attention.attention_bias()).
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias()).
    hparams: hyperparameters for model.
    cache: Not supported.
    decode_loop_step: An integer, step number of the decoding loop. Only used
      for inference on TPU.
    name: a string.
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding. This is used to mask out
      padding in convolutional layers. We generally only need this mask for
      "packed" datasets, because for ordinary datasets, no padding is ever
      followed by nonpadding.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: Not supported.

  Returns:
    Decoder output tensor.
  """
  del cache, losses

  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))

  with tf.variable_scope(name):
    hidden_state = decoder_input
    layer_cache = None
    for layer in range(hparams.num_decoder_layers or
                       hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):

        with tf.variable_scope("16_head_self_attention"):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          # 16 head attention. Hard coding number of heads.
          left_state = common_attention.multihead_attention(
              hidden_state,
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              16,  # Heads are hard coded to replicate paper.
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              cache=layer_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              decode_loop_step=decode_loop_step,
              vars_3d=hparams.get("attention_variables_3d"),
              activation_dtype=hparams.get("activation_dtype", "float32"),
              weight_dtype=hparams.get("weight_dtype", "float32"))

        if encoder_output is not None:
          with tf.variable_scope("first_attend_to_encoder"):
            right_state = common_attention.multihead_attention(
                hidden_state,
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                cache=layer_cache,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d"),
                activation_dtype=hparams.get("activation_dtype", "float32"),
                weight_dtype=hparams.get("weight_dtype", "float32"))

            left_state = tf.nn.dropout(
                left_state, 1 - hparams.layer_prepostprocess_dropout)
            right_state = tf.nn.dropout(
                right_state, 1 - hparams.layer_prepostprocess_dropout)

            hidden_state = residual_state + left_state + right_state

        else:
          hidden_state = common_layers.layer_postprocess(
              residual_state, left_state, hparams)

        with tf.variable_scope("conv_branches"):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          if nonpadding is not None:
            # Mask padding from conv layers.
            mask = tf.tile(
                tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size])
            hidden_state *= mask

          # Shift inputs so that future tokens cannot be seen.
          left_state = tf.pad(hidden_state, paddings=[[0, 0], [10, 0], [0, 0]])
          left_output_dim = int(hparams.hidden_size * 2)
          separable_conv_11x1 = tf.layers.SeparableConv1D(
              left_output_dim,
              11,
              padding="VALID",
              name="separable_conv11x1",
              activation=tf.nn.relu)
          left_state = separable_conv_11x1.apply(left_state)
          left_state = tf.nn.dropout(
              left_state, 1 - hparams.layer_prepostprocess_dropout)

          right_state = tf.pad(hidden_state, paddings=[[0, 0], [6, 0], [0, 0]])
          right_output_dim = int(hparams.hidden_size / 2)
          separable_conv_7x1_1 = tf.layers.SeparableConv1D(
              right_output_dim,
              7,
              padding="VALID",
              name="separable_conv_7x1_1")
          right_state = separable_conv_7x1_1.apply(right_state)
          right_state = tf.nn.dropout(
              right_state, 1 - hparams.layer_prepostprocess_dropout)
          right_state = tf.pad(
              right_state,
              [[0, 0], [0, 0], [0, left_output_dim - right_output_dim]],
              constant_values=0)

          hidden_state = left_state + right_state

          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
          if nonpadding is not None:
            # Mask padding from conv layers.
            mask = tf.tile(
                tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size * 2])
            hidden_state *= mask

          hidden_state = tf.pad(hidden_state,
                                paddings=[[0, 0], [6, 0], [0, 0]])
          separable_conv_7x1_2 = tf.layers.SeparableConv1D(
              hparams.hidden_size,
              7,
              padding="VALID",
              name="separable_conv_7x1_2")
          hidden_state = separable_conv_7x1_2.apply(hidden_state)

          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

        with tf.variable_scope("self_attention"):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          hidden_state = common_attention.multihead_attention(
              hidden_state,
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              cache=layer_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              decode_loop_step=decode_loop_step,
              vars_3d=hparams.get("attention_variables_3d"),
              activation_dtype=hparams.get("activation_dtype", "float32"),
              weight_dtype=hparams.get("weight_dtype", "float32"))
          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

        if encoder_output is not None:
          with tf.variable_scope("second_attend_to_encoder"):
            residual_state = hidden_state
            hidden_state = common_layers.layer_preprocess(
                hidden_state, hparams)

            hidden_state = common_attention.multihead_attention(
                hidden_state,
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                cache=layer_cache,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d"),
                activation_dtype=hparams.get("activation_dtype", "float32"),
                weight_dtype=hparams.get("weight_dtype", "float32"))
            hidden_state = common_layers.layer_postprocess(
                residual_state, hidden_state, hparams)

        with tf.variable_scope("dense_layers"):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          hidden_state = tf.layers.dense(
              hidden_state,
              int(hparams.hidden_size * 4),
              activation=tf.nn.swish)
          hidden_state = tf.nn.dropout(
              hidden_state, 1 - hparams.layer_prepostprocess_dropout)

          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
          hidden_state = tf.layers.dense(hidden_state, hparams.hidden_size)
          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

    return common_layers.layer_preprocess(hidden_state, hparams)
def ffn(x, hparams, name):
  with tf.variable_scope(name):
    y = transformer.transformer_ffn_layer(
        common_layers.layer_preprocess(x, hparams), hparams)
    return common_layers.layer_postprocess(x, y, hparams)
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder",
                        nonpadding=None,
                        save_weights_to=None):
  """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This must either be
      passed in, which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias.  The knowledge about padding is used
      for pad_remover (efficiency) and to mask out padding in convolutional
      layers.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).

  Returns:
    y: a Tensor
  """
  x = encoder_input
  with tf.variable_scope(name):
    if nonpadding is not None:
      padding = 1.0 - nonpadding
    else:
      padding = common_attention.attention_bias_to_padding(
          encoder_self_attention_bias)
      nonpadding = 1.0 - padding
    pad_remover = None
    if hparams.use_pad_remover:
      pad_remover = expert_utils.PadRemover(padding)
    for layer in xrange(hparams.num_encoder_layers or
                        hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              encoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              save_weights_to=save_weights_to,
              max_relative_position=hparams.max_relative_position)
          x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams),
              hparams,
              pad_remover,
              conv_padding="SAME",
              nonpadding_mask=nonpadding)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
def transformer_decoder_gate(decoder_input,
                             encoder_output,
                             decoder_self_attention_bias,
                             encoder_decoder_attention_bias,
                             hparams,
                             cache=None,
                             name="decoder"):
  """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
      attentions, used for fast decoding.
    name: a string

  Returns:
    y: a Tensor
  """
  x = decoder_input
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_decoder_layers or
                        hparams.num_hidden_layers):
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              cache=layer_cache)
          x = common_layers.layer_postprocess(x, y, hparams)

          gate_filter = tf.get_variable(
              'gate_layer_%d' % layer,
              [1, hparams.hidden_size, hparams.hidden_size],
              tf.float32,
              initializer=tf.contrib.layers.xavier_initializer())
          gate_x = tf.tanh(tf.nn.conv1d(x, gate_filter, 1, 'SAME'))
          x *= gate_x
        if encoder_output is not None:
          with tf.variable_scope("encdec_attention"):
            # TODO(llion): Add caching.
            y = common_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout)
            x = common_layers.layer_postprocess(x, y, hparams)

            gate_filter = tf.get_variable(
                'gate_layer_%d' % layer,
                [1, hparams.hidden_size, hparams.hidden_size],
                tf.float32,
                initializer=tf.contrib.layers.xavier_initializer())
            gate_x = tf.tanh(tf.nn.conv1d(x, gate_filter, 1, 'SAME'))
            x *= gate_x
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams), hparams)
          x = common_layers.layer_postprocess(x, y, hparams)

          gate_filter = tf.get_variable(
              'gate_layer_%d' % layer,
              [1, hparams.hidden_size, hparams.hidden_size],
              tf.float32,
              initializer=tf.contrib.layers.xavier_initializer())
          gate_x = tf.tanh(tf.nn.conv1d(x, gate_filter, 1, 'SAME'))
          x *= gate_x
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        name="decoder",
                        nonpadding=None,
                        save_weights_to=None,
                        make_image_summary=True):
  """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
      attentions, used for fast decoding.
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This is used to mask out
      padding in convolutional layers.  We generally only need this mask for
      "packed" datasets, because for ordinary datasets, no padding is ever
      followed by nonpadding.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.

  Returns:
    y: a Tensor
  """
  x = decoder_input
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_decoder_layers or
                        hparams.num_hidden_layers):
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              save_weights_to=save_weights_to,
              max_relative_position=hparams.max_relative_position,
              cache=layer_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims)
          x = common_layers.layer_postprocess(x, y, hparams)
        if encoder_output is not None:
          with tf.variable_scope("encdec_attention"):
            # TODO(llion): Add caching.
            y = common_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims)
            x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams),
              hparams,
              conv_padding="LEFT",
              nonpadding_mask=nonpadding)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
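# A hedged sketch of how the two bias tensors fed to transformer_decoder above
# are usually built in tensor2tensor-style code. The helper names below
# (attention_bias_lower_triangle, embedding_to_padding,
# attention_bias_ignore_padding) are assumptions about the surrounding library
# and may differ between versions; the sketch function itself is illustrative.
def _build_decoder_biases_sketch(targets_emb, inputs_emb):
  """Causal mask for self-attention, padding mask for enc-dec attention."""
  decoder_self_attention_bias = common_attention.attention_bias_lower_triangle(
      tf.shape(targets_emb)[1])
  encoder_decoder_attention_bias = (
      common_attention.attention_bias_ignore_padding(
          common_attention.embedding_to_padding(inputs_emb)))
  return decoder_self_attention_bias, encoder_decoder_attention_bias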
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        name="decoder",
                        terminal_decoder_bias=None,
                        nonterminal_decoder_bias=None,
                        nonpadding=None,
                        pos_signals=None):
  """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
      attentions, used for fast decoding.
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This is used to mask out
      padding in convolutional layers.  We generally only need this mask for
      "packed" datasets, because for ordinary datasets, no padding is ever
      followed by nonpadding.

  Returns:
    y: a Tensor
  """
  x = decoder_input
  sequence_length = usr_utils.get_length_from_nonpadding(nonpadding)
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_decoder_layers or
                        hparams.num_hidden_layers):
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        for layer_type in _iter_layer_types(hparams.decoder_layer_types,
                                            layer):
          if layer_type == "self_att":
            with tf.variable_scope("self_attention"):
              y = model_helper.multihead_attention_qkv(
                  common_layers.layer_preprocess(x, hparams),
                  None,
                  None,
                  decoder_self_attention_bias,
                  hparams.attention_key_channels or hparams.hidden_size,
                  hparams.attention_value_channels or hparams.hidden_size,
                  hparams.hidden_size,
                  hparams.num_heads,
                  hparams.attention_dropout,
                  attention_type=hparams.decoder_self_attention_type,
                  attention_order=hparams.attention_order,
                  max_relative_position=hparams.max_relative_position,
                  cache=layer_cache)
              x = common_layers.layer_postprocess(x, y, hparams)
          elif layer_type == "nt_self_att":
            with tf.variable_scope("nonterminal_self_attention"):
              y = model_helper.multihead_attention_qkv(
                  common_layers.layer_preprocess(x, hparams),
                  None,
                  None,
                  nonterminal_decoder_bias,
                  hparams.attention_key_channels or hparams.hidden_size,
                  hparams.attention_value_channels or hparams.hidden_size,
                  hparams.hidden_size,
                  hparams.num_heads,
                  hparams.attention_dropout,
                  attention_type=hparams.decoder_self_attention_type,
                  attention_order=hparams.attention_order,
                  max_relative_position=hparams.max_relative_position,
                  cache=layer_cache)
              x = common_layers.layer_postprocess(x, y, hparams)
          elif layer_type == "t_self_att":
            with tf.variable_scope("terminal_self_attention"):
              y = model_helper.multihead_attention_qkv(
                  common_layers.layer_preprocess(x, hparams),
                  None,
                  None,
                  terminal_decoder_bias,
                  hparams.attention_key_channels or hparams.hidden_size,
                  hparams.attention_value_channels or hparams.hidden_size,
                  hparams.hidden_size,
                  hparams.num_heads,
                  hparams.attention_dropout,
                  attention_type=hparams.decoder_self_attention_type,
                  attention_order=hparams.attention_order,
                  max_relative_position=hparams.max_relative_position,
                  cache=layer_cache)
              x = common_layers.layer_postprocess(x, y, hparams)
          elif layer_type == "parent_ffn":
            with tf.variable_scope("parent_ffn"):
              parent_pointers = tf.cast(pos_signals["parent_timing"], tf.int32)
              parent_x = usr_utils.gather_2d(x, parent_pointers)
              x = tf.concat([x, parent_x], axis=2)
              x = transformer_ffn_layer(x, hparams, conv_padding="LEFT")
          elif layer_type == "rnn":
            with tf.variable_scope("recurrent"):
              y = transformer_rnn_layer(
                  common_layers.layer_preprocess(x, hparams), sequence_length,
                  hparams)
              x = common_layers.layer_postprocess(x, y, hparams)
          elif layer_type == "enc_att" and encoder_output is not None:
            with tf.variable_scope("encdec_attention"):
              # TODO(llion): Add caching.
              y = model_helper.multihead_attention_qkv(
                  common_layers.layer_preprocess(x, hparams),
                  encoder_output,
                  None,
                  encoder_decoder_attention_bias,
                  hparams.attention_key_channels or hparams.hidden_size,
                  hparams.attention_value_channels or hparams.hidden_size,
                  hparams.hidden_size,
                  hparams.num_heads,
                  hparams.attention_dropout)
              x = common_layers.layer_postprocess(x, y, hparams)
          else:
            tf.logging.warn(
                "Ignoring '%s' in decoder_layer_types" % layer_type)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams),
              hparams,
              conv_padding="LEFT",
              nonpadding_mask=nonpadding)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder",
                        nonpadding=None):
  """A stack of transformer layers.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This must either be
      passed in, which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias.  The knowledge about padding is used
      for pad_remover (efficiency) and to mask out padding in convolutional
      layers.

  Returns:
    y: a Tensor
  """
  x = encoder_input
  with tf.variable_scope(name):
    if nonpadding is not None:
      padding = 1.0 - nonpadding
    else:
      padding = common_attention.attention_bias_to_padding(
          encoder_self_attention_bias)
      nonpadding = 1.0 - padding
    pad_remover = None
    if hparams.use_pad_remover:
      pad_remover = expert_utils.PadRemover(padding)
    sequence_length = usr_utils.get_length_from_nonpadding(nonpadding)
    for layer in xrange(hparams.num_encoder_layers or
                        hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        for layer_type in _iter_layer_types(hparams.encoder_layer_types,
                                            layer):
          if layer_type == "self_att":
            with tf.variable_scope("self_attention"):
              y = model_helper.multihead_attention_qkv(
                  common_layers.layer_preprocess(x, hparams),
                  None,
                  None,
                  encoder_self_attention_bias,
                  hparams.attention_key_channels or hparams.hidden_size,
                  hparams.attention_value_channels or hparams.hidden_size,
                  hparams.hidden_size,
                  hparams.num_heads,
                  hparams.attention_dropout,
                  attention_type=hparams.encoder_self_attention_type,
                  attention_order=hparams.attention_order,
                  max_relative_position=hparams.max_relative_position)
              x = common_layers.layer_postprocess(x, y, hparams)
          elif layer_type == "rnn":
            with tf.variable_scope("recurrent"):
              y = transformer_rnn_layer(
                  common_layers.layer_preprocess(x, hparams), sequence_length,
                  hparams)
              x = common_layers.layer_postprocess(x, y, hparams)
          elif layer_type == "birnn":
            with tf.variable_scope("recurrent"):
              y = transformer_rnn_layer(
                  common_layers.layer_preprocess(x, hparams),
                  sequence_length,
                  hparams,
                  bidirectional=True)
              x = common_layers.layer_postprocess(x, y, hparams)
          else:
            tf.logging.warn(
                "Ignoring '%s' in encoder_layer_types" % layer_type)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams),
              hparams,
              pad_remover,
              conv_padding="SAME",
              nonpadding_mask=nonpadding)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
def transformer_decoder_fast_aan(decoder_input,
                                 encoder_output,
                                 decoder_position_forward_mask,
                                 encoder_decoder_attention_bias,
                                 hparams,
                                 cache=None,
                                 name="decoder"):
  """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_position_forward_mask: mask Tensor for position-forward,
      shape: [1, t, 1]
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
      attentions, used for fast decoding.
    name: a string

  Returns:
    y: a Tensor
  """
  x = decoder_input
  with tf.variable_scope(name):
    for layer in range(hparams.num_decoder_layers or
                       hparams.num_hidden_layers):
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        with tf.variable_scope("position_forward"):
          if layer_cache:
            given_inputs_new = layer_cache['given_inputs'] + x
            x_fwd = given_inputs_new * decoder_position_forward_mask
            layer_cache['given_inputs'] = given_inputs_new + x
          else:
            x_fwd = tf.cumsum(x, axis=1) * decoder_position_forward_mask
          # FFN activation
          y = transformer.transformer_ffn_layer(
              common_layers.layer_preprocess(x_fwd, hparams), hparams)
          # Gating layer
          z = tf.layers.dense(
              tf.concat([x, y], axis=-1),
              hparams.hidden_size * 2,
              name="z_project")
          i, f = tf.split(z, 2, axis=-1)
          y = tf.sigmoid(i) * x + tf.sigmoid(f) * y
          x = common_layers.layer_postprocess(x, y, hparams)
        if encoder_output is not None:
          with tf.variable_scope("encdec_attention"):
            y = multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                cache=layer_cache)
            x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer.transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams), hparams)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
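# In the average-attention decoder above, tf.cumsum(x, axis=1) multiplied by
# decoder_position_forward_mask turns a cumulative sum into a cumulative mean.
# A hedged sketch of how such a mask is typically constructed by the caller
# (the helper name is an assumption; the real construction lives outside the
# snippet above):
def _make_position_forward_mask_sketch(length, dtype=tf.float32):
  """Returns a [1, length, 1] tensor holding 1/t at decoding position t."""
  positions = tf.cast(tf.range(1, length + 1), dtype)
  return tf.reshape(1.0 / positions, [1, -1, 1])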
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        name="decoder",
                        return_attention_weight=False):
  """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
      attentions, used for fast decoding.
    name: a string

  Returns:
    y: a Tensor
  """
  x = decoder_input
  additional_outputs = {}
  if return_attention_weight:
    additional_outputs["attention_weight"] = []
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_decoder_layers or
                        hparams.num_hidden_layers):
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              cache=layer_cache)
          x = common_layers.layer_postprocess(x, y, hparams)
        if encoder_output is not None:
          with tf.variable_scope("encdec_attention"):
            # TODO(llion): Add caching.
            y = common_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                return_attention_weight=return_attention_weight)
            if return_attention_weight:
              y, attention_weight = y
              additional_outputs["attention_weight"].append(attention_weight)
            x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams), hparams)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    if additional_outputs != {}:
      return common_layers.layer_preprocess(x, hparams), additional_outputs
    return common_layers.layer_preprocess(x, hparams)
def symbols_to_logits_fn(ids, ids_tag, i, cache):
  """Go from ids to logits for next symbol."""
  ids = ids[:, -1:]
  targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
  targets = preprocess_targets_method(targets, i)

  ids_tag = ids_tag[:, -1:]
  targets_tag = tf.expand_dims(tf.expand_dims(ids_tag, axis=2), axis=3)
  targets_tag = preprocess_targets_tag_method(targets_tag, i)

  bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

  with tf.variable_scope('body'):
    with tf.variable_scope('edit_ops_layer'):
      with tf.variable_scope('ffn'):
        x = targets
        preproc = lambda z: common_layers.layer_preprocess(
            z, hparams, layer_collection=None)
        layer_inputs = [
            tf.concat(preproc(x), axis=0),
            tf.concat(preproc(targets_tag), axis=0),
        ]
        y = transformer_layers.transformer_ffn_layer(
            tf.concat(layer_inputs, axis=2),
            hparams,
            conv_padding='LEFT',
            nonpadding_mask=features_to_nonpadding(features, 'targets'),
            losses=None,
            cache=cache,
            decode_loop_step=None,
            layer_collection=None,
        )
        targets = common_layers.layer_postprocess(x, y, hparams)

    if hparams.middle_prediction:
      num_decoder_layers = (
          hparams.num_decoder_layers or hparams.num_hidden_layers)
      hparams.num_decoder_layers = int(
          num_decoder_layers / hparams.middle_prediction_layer_factor)

    body_outputs = dp(
        self.decode,
        targets,
        cache.get('encoder_output'),
        cache.get('encoder_decoder_attention_bias'),
        bias,
        hparams,
        cache,
        nonpadding=features_to_nonpadding(features, 'targets'),
    )[0]

    body_outputs, logits_tag = dp(
        self._prediction_cascade_predict,
        hparams,
        features_to_nonpadding(features, 'targets'),
        cache.get('encoder_decoder_attention_bias'),
        cache.get('encoder_output'),
        body_outputs,
    )
    logits_tag = logits_tag[0]['targets_error_tag']

    if hparams.middle_prediction:
      with tf.variable_scope('after_prediction'):
        body_outputs = dp(
            self.decode,
            targets + body_outputs[0],
            cache.get('encoder_output'),
            cache.get('encoder_decoder_attention_bias'),
            bias,
            hparams,
            cache,
            nonpadding=features_to_nonpadding(features, 'targets'),
        )

  update_decoder_attention_history(cache)

  modality_name = hparams.name.get(
      'targets',
      modalities.get_name(target_modality))(hparams, target_vocab_size)
  with tf.variable_scope('targets/' + modality_name):
    top = hparams.top.get('targets', modalities.get_top(target_modality))
    logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]

  ret = tf.squeeze(logits, axis=[1, 2])
  if partial_targets is not None:
    vocab_size = tf.shape(ret)[1]

    def forced_logits():
      return tf.one_hot(
          tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0, -1e9)

    ret = tf.cond(
        tf.less(i, partial_targets_length), forced_logits, lambda: ret)
  logits_tag = tf.squeeze(logits_tag, axis=[1])
  return ret, logits_tag, cache
def transformer_decoder(decoder_input, encoder_output, decoder_self_attention_bias, encoder_decoder_attention_bias, hparams, cache=None, name="decoder", nonpadding=None, save_weights_to=None): """A stack of transformer layers. Args: decoder_input: a Tensor encoder_output: a Tensor decoder_self_attention_bias: bias Tensor for self-attention (see common_attention.attention_bias()) encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention (see common_attention.attention_bias()) hparams: hyperparameters for model cache: dict, containing tensors which are the results of previous attentions, used for fast decoding. name: a string nonpadding: optional Tensor with shape [batch_size, encoder_length] indicating what positions are not padding. This is used to mask out padding in convolutional layers. We generally only need this mask for "packed" datasets, because for ordinary datasets, no padding is ever followed by nonpadding. save_weights_to: an optional dictionary to capture attention weights for visualization; the weights tensor will be appended there under a string key created from the variable scope (including name). Returns: y: a Tensors """ x = decoder_input with tf.variable_scope(name): for layer in xrange(hparams.num_decoder_layers or hparams.num_hidden_layers): layer_name = "layer_%d" % layer layer_cache = cache[layer_name] if cache is not None else None with tf.variable_scope(layer_name): with tf.variable_scope("self_attention"): y = common_attention.multihead_attention( common_layers.layer_preprocess(x, hparams), None, decoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.self_attention_type, save_weights_to=save_weights_to, max_relative_position=hparams.max_relative_position, cache=layer_cache) x = common_layers.layer_postprocess(x, y, hparams) if encoder_output is not None: with tf.variable_scope("encdec_attention"): # TODO(llion): Add caching. y = common_attention.multihead_attention( common_layers.layer_preprocess(x, hparams), encoder_output, encoder_decoder_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, save_weights_to=save_weights_to) x = common_layers.layer_postprocess(x, y, hparams) with tf.variable_scope("ffn"): y = transformer_ffn_layer(common_layers.layer_preprocess( x, hparams), hparams, conv_padding="LEFT", nonpadding_mask=nonpadding) x = common_layers.layer_postprocess(x, y, hparams) # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. return common_layers.layer_preprocess(x, hparams)
def transformer_encoder(encoder_input, encoder_self_attention_bias, hparams, name="encoder", nonpadding=None, save_weights_to=None, make_image_summary=True, losses=None, attn_bias_for_padding=None): """A stack of transformer layers. Args: encoder_input: a Tensor encoder_self_attention_bias: bias Tensor for self-attention (see common_attention.attention_bias()) hparams: hyperparameters for model name: a string nonpadding: optional Tensor with shape [batch_size, encoder_length] indicating what positions are not padding. This must either be passed in, which we do for "packed" datasets, or inferred from encoder_self_attention_bias. The knowledge about padding is used for pad_remover(efficiency) and to mask out padding in convolutional layers. save_weights_to: an optional dictionary to capture attention weights for visualization; the weights tensor will be appended there under a string key created from the variable scope (including name). make_image_summary: Whether to make an attention image summary. losses: optional list onto which to append extra training losses attn_bias_for_padding: Padded attention bias in case a unidirectional encoder is being used where future attention is masked. Returns: y: a Tensors """ x = encoder_input attention_dropout_broadcast_dims = ( common_layers.comma_separated_string_to_integer_list( getattr(hparams, "attention_dropout_broadcast_dims", ""))) mlperf_log.transformer_print(key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS, value=hparams.num_encoder_layers or hparams.num_hidden_layers) mlperf_log.transformer_print(key=mlperf_log.MODEL_HP_ATTENTION_DROPOUT, value=hparams.attention_dropout) mlperf_log.transformer_print(key=mlperf_log.MODEL_HP_ATTENTION_DENSE, value={ "use_bias": "false", "num_heads": hparams.num_heads, "hidden_size": hparams.hidden_size }) with tf.variable_scope(name): if nonpadding is not None: padding = 1.0 - nonpadding else: attention_bias = encoder_self_attention_bias if attn_bias_for_padding is not None: attention_bias = attn_bias_for_padding padding = common_attention.attention_bias_to_padding( attention_bias) nonpadding = 1.0 - padding pad_remover = None if hparams.use_pad_remover and not common_layers.is_xla_compiled(): pad_remover = expert_utils.PadRemover(padding) for layer in range(hparams.num_encoder_layers or hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): with tf.variable_scope("self_attention"): y = common_attention.multihead_attention( common_layers.layer_preprocess(x, hparams), None, encoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.self_attention_type, max_relative_position=hparams.max_relative_position, heads_share_relative_embedding=( hparams.heads_share_relative_embedding), add_relative_to_values=hparams.add_relative_to_values, save_weights_to=save_weights_to, make_image_summary=make_image_summary, dropout_broadcast_dims=attention_dropout_broadcast_dims, max_length=hparams.get("max_length"), vars_3d=hparams.get("attention_variables_3d"), activation_dtype=hparams.get("activation_dtype", "float32"), weight_dtype=hparams.get("weight_dtype", "float32")) x = common_layers.layer_postprocess(x, y, hparams) with tf.variable_scope("ffn"): y = transformer_ffn_layer(common_layers.layer_preprocess( x, hparams), hparams, pad_remover, conv_padding="SAME", nonpadding_mask=nonpadding, losses=losses) x = common_layers.layer_postprocess(x, y, 
hparams) # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. mlperf_log.transformer_print( key=mlperf_log.MODEL_HP_NORM, value={"hidden_size": hparams.hidden_size}) return common_layers.layer_preprocess(x, hparams)
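When nonpadding is not supplied, the encoder recovers it from the self-attention bias, which marks padded key positions with a large negative value, and uses the resulting mask for the pad remover and the convolutional FFN. A small NumPy sketch of that round trip, assuming the conventional -1e9 padding bias:

import numpy as np

NEG_INF = -1e9

def attention_bias_from_padding(padding):
    # padding: [batch, length] with 1.0 at padded positions.
    # Returns a broadcastable bias of shape [batch, 1, 1, length].
    return padding[:, None, None, :] * NEG_INF

def padding_from_attention_bias(bias):
    # Invert the mapping: any strongly negative entry was a padded key position.
    return (bias[:, 0, 0, :] < NEG_INF / 2).astype(np.float32)

padding = np.array([[0., 0., 1., 1.], [0., 1., 1., 1.]])
bias = attention_bias_from_padding(padding)
print(padding_from_attention_bias(bias))   # recovers the original mask
nonpadding = 1.0 - padding                 # used to zero out conv inputs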
def transformer_decoder(decoder_input, encoder_output, decoder_self_attention_biases, encoder_decoder_attention_biases, hparams, cache=None, decode_loop_step=None, name="decoder", nonpadding=None, save_weights_to=None, make_image_summary=True, losses=None): """A stack of transformer layers. Args: decoder_input: a Tensor encoder_output: a Tensor decoder_self_attention_biases: a dict of self-attention bias Tensors keyed by the context types in hparams.transformer_context_types (see common_attention.attention_bias()) encoder_decoder_attention_biases: bias Tensor for encoder-decoder attention (see common_attention.attention_bias()) hparams: hyperparameters for model cache: dict, containing tensors which are the results of previous attentions, used for fast decoding. decode_loop_step: An integer, step number of the decoding loop. Only used for inference on TPU. name: a string nonpadding: optional Tensor with shape [batch_size, encoder_length] indicating what positions are not padding. This is used to mask out padding in convolutional layers. We generally only need this mask for "packed" datasets, because for ordinary datasets, no padding is ever followed by nonpadding. save_weights_to: an optional dictionary to capture attention weights for visualization; the weights tensor will be appended there under a string key created from the variable scope (including name). make_image_summary: Whether to make an attention image summary. losses: optional list onto which to append extra training losses Returns: y: a Tensors """ x = decoder_input attention_dropout_broadcast_dims = ( common_layers.comma_separated_string_to_integer_list( getattr(hparams, "attention_dropout_broadcast_dims", ""))) with tf.variable_scope(name): for layer in range(hparams.num_decoder_layers or hparams.num_hidden_layers): layer_name = "layer_%d" % layer layer_cache = cache[layer_name] if cache is not None else None with tf.variable_scope(layer_name): for context_type in hparams.transformer_context_types: with tf.variable_scope("self_attention_%s" % context_type): y = common_attention.multihead_attention( common_layers.layer_preprocess(x, hparams), None, decoder_self_attention_biases[context_type], hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.self_attention_type, max_relative_position=hparams.max_relative_position, heads_share_relative_embedding=( hparams.heads_share_relative_embedding), add_relative_to_values=hparams.add_relative_to_values, save_weights_to=save_weights_to, cache=layer_cache, make_image_summary=make_image_summary, dropout_broadcast_dims=attention_dropout_broadcast_dims, max_length=hparams.get("max_length"), decode_loop_step=decode_loop_step, vars_3d=hparams.get("attention_variables_3d")) x = common_layers.layer_postprocess(x, y, hparams) if encoder_output is not None: with tf.variable_scope("encdec_attention"): y = common_attention.multihead_attention( common_layers.layer_preprocess(x, hparams), encoder_output, encoder_decoder_attention_biases, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, max_relative_position=hparams.max_relative_position, heads_share_relative_embedding=( hparams.heads_share_relative_embedding), add_relative_to_values=hparams.add_relative_to_values, save_weights_to=save_weights_to, cache=layer_cache, make_image_summary=make_image_summary,
dropout_broadcast_dims=attention_dropout_broadcast_dims, max_length=hparams.get("max_length"), vars_3d=hparams.get("attention_variables_3d")) x = common_layers.layer_postprocess(x, y, hparams) with tf.variable_scope("ffn"): y = transformer_ffn_layer( common_layers.layer_preprocess(x, hparams), hparams, conv_padding="LEFT", nonpadding_mask=nonpadding, losses=losses, cache=layer_cache, decode_loop_step=decode_loop_step) x = common_layers.layer_postprocess(x, y, hparams) # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. return common_layers.layer_preprocess(x, hparams)
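This decoder variant runs one self-attention sublayer per entry in hparams.transformer_context_types, each reading its own bias from decoder_self_attention_biases. A toy NumPy sketch of that per-layer loop, with a plain softmax attention standing in for multihead attention and made-up context names:

import numpy as np

def attend(x, bias):
    # Stand-in for multihead attention over x with an additive bias/mask.
    scores = x @ x.transpose(0, 2, 1) + bias               # [batch, len, len]
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)
    return weights @ x

x = np.random.default_rng(0).standard_normal((1, 4, 8))
causal = np.triu(np.full((4, 4), -1e9), k=1)[None]         # mask future positions
biases = {"local": causal, "global": np.zeros((1, 4, 4))}  # hypothetical context types
for context_type in ("local", "global"):                   # transformer_context_types
    x = x + attend(x, biases[context_type])                # one sublayer per context
print(x.shape)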
def evolved_transformer_encoder(encoder_input, encoder_self_attention_bias, hparams, name="encoder", nonpadding=None, save_weights_to=None, make_image_summary=True, losses=None, attn_bias_for_padding=None): """Evolved Transformer encoder. See arxiv.org/abs/1901.11117 for more details. Note: Pad remover is not supported. Args: encoder_input: a Tensor. encoder_self_attention_bias: bias Tensor for self-attention (see common_attention.attention_bias()). hparams: hyperparameters for model. name: a string. nonpadding: optional Tensor with shape [batch_size, encoder_length] indicating what positions are not padding. This must either be passed in, which we do for "packed" datasets, or inferred from encoder_self_attention_bias. The knowledge about padding is used for pad_remover(efficiency) and to mask out padding in convolutional layers. save_weights_to: an optional dictionary to capture attention weights for visualization; the weights tensor will be appended there under a string key created from the variable scope (including name). make_image_summary: Whether to make an attention image summary. losses: Not used. attn_bias_for_padding: Padded attention bias in case a unidirectional encoder is being used where future attention is masked. Returns: Tensor encoder output. """ del losses hidden_state = encoder_input attention_dropout_broadcast_dims = ( common_layers.comma_separated_string_to_integer_list( getattr(hparams, "attention_dropout_broadcast_dims", ""))) with tf.variable_scope(name): if nonpadding is not None: padding = 1.0 - nonpadding else: attention_bias = encoder_self_attention_bias if attn_bias_for_padding is not None: attention_bias = attn_bias_for_padding padding = common_attention.attention_bias_to_padding( attention_bias) nonpadding = 1.0 - padding for layer in range(hparams.num_encoder_layers or hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): with tf.variable_scope("gated_linear_unit"): residual_state = hidden_state hidden_state = common_layers.layer_preprocess( hidden_state, hparams) values = common_layers.layers().Dense( hparams.hidden_size)(hidden_state) gates = common_layers.layers().Dense( hparams.hidden_size, activation=tf.nn.sigmoid)(hidden_state) hidden_state = values * gates hidden_state = common_layers.layer_postprocess( residual_state, hidden_state, hparams) with tf.variable_scope("conv_branches"): residual_state = hidden_state hidden_state = common_layers.layer_preprocess( hidden_state, hparams) # Mask padding from conv layers. mask = tf.tile(tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size]) hidden_state *= mask left_output_dim = int(hparams.hidden_size * 4) left_state = common_layers.layers().Dense( left_output_dim, activation=tf.nn.relu)(hidden_state) left_state = tf.nn.dropout( left_state, 1 - hparams.layer_prepostprocess_dropout) right_output_dim = int(hparams.hidden_size / 2) right_state = common_layers.layers().Conv1D( right_output_dim, 3, padding="SAME", name="standard_conv_3x1", activation=tf.nn.relu)(hidden_state) right_state = tf.nn.dropout( right_state, 1 - hparams.layer_prepostprocess_dropout) right_state = tf.pad( right_state, [[0, 0], [0, 0], [0, left_output_dim - right_output_dim]], constant_values=0) hidden_state = left_state + right_state hidden_state = common_layers.layer_preprocess( hidden_state, hparams) # Mask padding from conv layer. 
mask = tf.tile(tf.expand_dims(nonpadding, 2), [1, 1, left_output_dim]) hidden_state *= mask separable_conv_9x1 = common_layers.layers( ).SeparableConv1D(right_output_dim, 9, padding="SAME", name="separable_conv_9x1") hidden_state = separable_conv_9x1(hidden_state) hidden_state = tf.pad( hidden_state, [[0, 0], [0, 0], [0, hparams.hidden_size - right_output_dim]], constant_values=0) hidden_state = common_layers.layer_postprocess( residual_state, hidden_state, hparams) with tf.variable_scope("self_attention"): residual_state = hidden_state hidden_state = common_layers.layer_preprocess( hidden_state, hparams) hidden_state = common_attention.multihead_attention( hidden_state, None, encoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.self_attention_type, max_relative_position=hparams.max_relative_position, heads_share_relative_embedding=( hparams.heads_share_relative_embedding), add_relative_to_values=hparams.add_relative_to_values, save_weights_to=save_weights_to, make_image_summary=make_image_summary, dropout_broadcast_dims=attention_dropout_broadcast_dims, max_length=hparams.get("max_length"), vars_3d=hparams.get("attention_variables_3d"), activation_dtype=hparams.get("activation_dtype", "float32"), weight_dtype=hparams.get("weight_dtype", "float32")) hidden_state = common_layers.layer_postprocess( residual_state, hidden_state, hparams) with tf.variable_scope("dense_layers"): residual_state = hidden_state hidden_state = common_layers.layer_preprocess( hidden_state, hparams) hidden_state = common_layers.layers().Dense( int(hparams.hidden_size * 4), activation=tf.nn.relu)(hidden_state) hidden_state = tf.nn.dropout( hidden_state, 1 - hparams.layer_prepostprocess_dropout) hidden_state = common_layers.layers().Dense( hparams.hidden_size)(hidden_state) hidden_state = common_layers.layer_postprocess( residual_state, hidden_state, hparams) # If normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. return common_layers.layer_preprocess(hidden_state, hparams)
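The Evolved Transformer encoder layer opens with a gated linear unit: one dense projection produces values, a second sigmoid-activated projection produces gates, and the two are multiplied elementwise before the usual residual postprocess. A minimal NumPy sketch with untrained random weights:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gated_linear_unit(x, w_values, w_gates):
    # values * sigmoid(gates); both projections keep the hidden size.
    return (x @ w_values) * sigmoid(x @ w_gates)

rng = np.random.default_rng(0)
hidden = 8
x = rng.standard_normal((2, 5, hidden))
out = gated_linear_unit(x,
                        rng.standard_normal((hidden, hidden)) * 0.1,
                        rng.standard_normal((hidden, hidden)) * 0.1)
print(out.shape)   # (2, 5, 8); wrapped in the usual pre/postprocess residual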
def transformer_encoder(encoder_input, raw_inputs, encoder_self_attention_bias, hparams, name="encoder"): """A stack of transformer layers. Args: encoder_input: a Tensor raw_inputs: raw inputs Tensor, used to derive sequence lengths and positional signals encoder_self_attention_bias: bias Tensor for self-attention (see common_attention.attention_bias()) hparams: hyperparameters for model name: a string Returns: y: a Tensors """ x = encoder_input with tf.variable_scope(name): raw_encoder_input = tf.squeeze(raw_inputs, axis=[-2, -1]) sequence_length = usr_utils.get_length_from_raw( raw_encoder_input) # Used for RNNs pos_signals = generate_positional_signals(raw_encoder_input, hparams) pos_embeddings = generate_positional_embeddings( pos_signals, hparams.encoder_pos, hparams) attention_pos_embeddings = generate_positional_embeddings( pos_signals, hparams.encoder_attention_pos, hparams) if "sum" in hparams.pos_integration: x = x + pos_embeddings elif "ffn" in hparams.pos_integration: with tf.variable_scope("pos_ffn"): x = tf.concat([x, pos_embeddings], axis=2) x = transformer_ffn_layer(x, hparams) pad_remover = None if hparams.use_pad_remover: pad_remover = expert_utils.PadRemover( common_attention.attention_bias_to_padding( encoder_self_attention_bias)) for layer in xrange(hparams.num_encoder_layers or hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): for layer_type in _iter_layer_types( hparams.encoder_layer_types, layer): if layer_type == "self_att": with tf.variable_scope("self_attention"): y = model_helper.multihead_attention_qkv( common_layers.layer_preprocess(x, hparams), None, None, encoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams. encoder_self_attention_type, attention_order=hparams.attention_order, max_relative_position=hparams. max_relative_position) x = common_layers.layer_postprocess(x, y, hparams) elif layer_type == "rnn": with tf.variable_scope("recurrent"): y = transformer_rnn_layer( common_layers.layer_preprocess(x, hparams), sequence_length, hparams) x = common_layers.layer_postprocess(x, y, hparams) elif layer_type == "birnn": with tf.variable_scope("recurrent"): y = transformer_rnn_layer( common_layers.layer_preprocess(x, hparams), sequence_length, hparams, bidirectional=True) x = common_layers.layer_postprocess(x, y, hparams) elif layer_type == "pos_self_att" and attention_pos_embeddings is not None: with tf.variable_scope("pos_self_attention"): y = model_helper.multihead_attention_qkv( attention_pos_embeddings, # Query attention_pos_embeddings, # Key common_layers.layer_preprocess( x, hparams), # Value encoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.pos_self_attention_type, attention_order=hparams.attention_order, max_relative_position=hparams. max_relative_position) x = common_layers.layer_postprocess(x, y, hparams) else: tf.logging.warn( "Ignoring '%s' in encoder_layer_types" % layer_type) with tf.variable_scope("ffn"): y = transformer_ffn_layer( common_layers.layer_preprocess(x, hparams), hparams, pad_remover) x = common_layers.layer_postprocess(x, y, hparams) # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs.
return common_layers.layer_preprocess(x, hparams)
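Positional embeddings are folded into the encoder state in one of two ways selected by hparams.pos_integration: added elementwise ("sum"), or concatenated along the feature axis and projected back through an FFN ("ffn"). A sketch of both options, with a single dense projection standing in for transformer_ffn_layer:

import numpy as np

rng = np.random.default_rng(0)
batch, length, hidden = 2, 6, 8
x = rng.standard_normal((batch, length, hidden))
pos_embeddings = rng.standard_normal((1, length, hidden))   # broadcast over batch

pos_integration = "ffn"
if pos_integration == "sum":
    x = x + pos_embeddings
elif pos_integration == "ffn":
    # Concatenate on the feature axis, then project back to the hidden size.
    w = rng.standard_normal((2 * hidden, hidden)) * 0.1
    x = np.concatenate([x, np.broadcast_to(pos_embeddings, x.shape)], axis=2) @ w
print(x.shape)   # (2, 6, 8)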
def transformer_layers_sharded(dp, ps_devices, inputs, num_layers, hparams, self_attention_bias=None, enc_output=None, attention_type=AttentionType.GLOBAL, name="transformer"): """Multi layer transformer, sharded by the data parallelism dp.""" x = inputs extra_loss = tf.constant(0.0) moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")] expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size) x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout) for layer in xrange(num_layers): with tf.variable_scope("%s_layer_%d" % (name, layer)): # self-attention if attention_type == AttentionType.LOCAL_2D: y = dp( local_attention_2d( common_layers.layer_preprocess(x, hparams), hparams, attention_type="masked_local_attention_2d")) elif attention_type == AttentionType.LOCAL_1D: y = dp( local_attention_1d(common_layers.layer_preprocess( x, hparams), hparams, attention_type="local_mask_right", q_padding="LEFT", kv_padding="LEFT")) elif attention_type == AttentionType.GLOCAL: y = dp( local_global_attention(common_layers.layer_preprocess( x, hparams), self_attention_bias, hparams, q_padding="LEFT", kv_padding="LEFT")) elif attention_type == AttentionType.GLOBAL: self_attention_bias = dp(get_self_attention_bias(x)) y = dp( full_self_attention(common_layers.layer_preprocess( x, hparams), self_attention_bias, hparams, q_padding="LEFT", kv_padding="LEFT")) x = common_layers.layer_postprocess(x, y, hparams) if enc_output is not None: y = dp( encdec_attention_1d( common_layers.layer_preprocess(x, hparams), enc_output, None, hparams)) x = dp(common_layers.layer_postprocess, x, y, hparams) with tf.variable_scope("ffn"): if str(layer) in hparams.moe_layers_decoder.split(","): y, loss = expert_utils.distributed_moe( dp, ps_devices, common_layers.layer_preprocess(x, hparams), hparams.mode == tf.estimator.ModeKeys.TRAIN, input_size=hparams.hidden_size, expert_fn=expert_fn, num_experts=hparams.moe_num_experts, k=hparams.moe_k, loss_coef=hparams.moe_loss_coef) extra_loss += loss x = dp(common_layers.layer_postprocess, x, y, hparams) else: y = dp(ffn_layer, common_layers.layer_preprocess(x, hparams), hparams) x = dp(common_layers.layer_postprocess, x, y, hparams) return dp(common_layers.layer_preprocess, x, hparams), extra_loss
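In the sharded variant, selected FFN layers are replaced by a distributed mixture of experts: a gating network scores the experts, each position is routed to its top-k experts, and the layer output is the gate-weighted sum of those experts, with a load-balancing term added to extra_loss. A single-device NumPy sketch of the routing idea only; for simplicity the gates below come from a full softmax rather than being renormalized over the selected experts, and there is no sharding or auxiliary loss:

import numpy as np

def moe_layer(x, expert_ws, gate_w, k=2):
    # x: [tokens, hidden]; expert_ws: [num_experts, hidden, hidden].
    logits = x @ gate_w                               # [tokens, num_experts]
    top = np.argsort(-logits, axis=1)[:, :k]          # top-k experts per token
    gates = np.exp(logits - logits.max(axis=1, keepdims=True))
    gates /= gates.sum(axis=1, keepdims=True)
    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        for e in top[t]:
            out[t] += gates[t, e] * (x[t] @ expert_ws[e])
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 8))
out = moe_layer(x, rng.standard_normal((4, 8, 8)) * 0.1, rng.standard_normal((8, 4)))
print(out.shape)   # (5, 8)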
def body(self, features): assert self._hparams.block_size > 0 assert not common_layers.is_xla_compiled() hparams = copy.copy(self._hparams) targets = features["targets"] inputs = features["inputs"] if not (tf.get_variable_scope().reuse or hparams.mode == tf.estimator.ModeKeys.PREDICT): tf.summary.image("inputs", inputs, max_outputs=1) tf.summary.image("targets", targets, max_outputs=1) encoder_input = cia.prepare_encoder(inputs, hparams) encoder_output = cia.transformer_encoder_layers( encoder_input, hparams.num_encoder_layers, hparams, attention_type=hparams.enc_attention_type, name="encoder") decoder_input, rows, cols = cia.prepare_decoder( targets, hparams) decoder_output = cia.transformer_decoder_layers( decoder_input, encoder_output, hparams.num_decoder_layers, hparams, attention_type=hparams.dec_attention_type, name="decoder") assert not isinstance(decoder_output, tuple) assert len(decoder_output.shape) == 4 relu_dropout_broadcast_dims = ( common_layers.comma_separated_string_to_integer_list( getattr(self._hparams, "relu_dropout_broadcast_dims", ""))) with tf.variable_scope("block_size_%d" % self._hparams.block_size): tf.logging.info("Using block_size %d", self._hparams.block_size) block_output = common_layers.dense_relu_dense( decoder_output, self._hparams.block_size * self._hparams.filter_size, self._hparams.block_size * self._hparams.hidden_size, dropout=self._hparams.relu_dropout, dropout_broadcast_dims=relu_dropout_broadcast_dims) batch_size, rows, cols = common_layers.shape_list(decoder_output)[:3] decoder_output = tf.reshape(decoder_output, [ batch_size, rows, cols, 1, self._hparams.hidden_size ]) block_output = tf.reshape(block_output, [ batch_size, rows, cols, self._hparams.block_size, self._hparams.hidden_size ]) block_output = common_layers.layer_postprocess( decoder_output, block_output, self._hparams) return block_output
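body() turns each decoder position into block_size output positions: dense_relu_dense expands the hidden vector to block_size * hidden_size, and a reshape introduces a block axis so layer_postprocess can add the decoder output back in by broadcasting. A NumPy sketch of just the reshape-and-broadcast arithmetic, with a tiled copy standing in for the FFN:

import numpy as np

batch, rows, cols, hidden, block_size = 2, 4, 4, 8, 3
decoder_output = np.random.default_rng(0).standard_normal((batch, rows, cols, hidden))

# Stand-in for dense_relu_dense with output size block_size * hidden.
block_output = np.tile(decoder_output, (1, 1, 1, block_size))

decoder_output = decoder_output.reshape(batch, rows, cols, 1, hidden)
block_output = block_output.reshape(batch, rows, cols, block_size, hidden)
print((decoder_output + block_output).shape)   # (2, 4, 4, 3, 8): residual broadcasts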
def decoder( decoder_input, encoder_output, decoder_self_attention_bias, encoder_decoder_attention_bias, hparams, name="decoder", save_weights_to=None, make_image_summary=True, ): """A stack of transformer layers. Args: decoder_input: a Tensor encoder_output: a Tensor decoder_self_attention_bias: bias Tensor for self-attention (see common_attention.attention_bias()) encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention (see common_attention.attention_bias()) hparams: hyperparameters for model name: a string save_weights_to: an optional dictionary to capture attention weights for visualization; the weights tensor will be appended there under a string key created from the variable scope (including name). make_image_summary: Whether to make an attention image summary. Returns: y: a Tensors """ x = decoder_input with tf.variable_scope(name): for layer in range(hparams.num_decoder_layers or hparams.num_hidden_layers): layer_name = "layer_%d" % layer with tf.variable_scope(layer_name): with tf.variable_scope("self_attention"): y = common_attention.multihead_attention( common_layers.layer_preprocess(x, hparams), None, decoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.self_attention_type, save_weights_to=save_weights_to, make_image_summary=make_image_summary, ) utils.collect_named_outputs( "norms", "decoder_self_attention_%d" % (layer), tf.norm(y, axis=-1)) x = common_layers.layer_postprocess(x, y, hparams) utils.collect_named_outputs( "norms", "decoder_self_attention_post_%d" % (layer), tf.norm(x, axis=-1)) if encoder_output is not None: with tf.variable_scope("encdec_attention"): y = common_attention.multihead_attention( common_layers.layer_preprocess(x, hparams), encoder_output, encoder_decoder_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, save_weights_to=save_weights_to, make_image_summary=make_image_summary, ) utils.collect_named_outputs( "norms", "decoder_encoder_attention_%d" % (layer), tf.norm(y, axis=-1)) x = common_layers.layer_postprocess(x, y, hparams) utils.collect_named_outputs( "norms", "decoder_encoder_attention_post_%d" % (layer), tf.norm(x, axis=-1)) with tf.variable_scope("ffn"): y = common_layers.dense_relu_dense( common_layers.layer_preprocess(x, hparams), hparams.filter_size, hparams.hidden_size, dropout=hparams.relu_dropout, ) utils.collect_named_outputs("norms", "decoder_ffn_%d" % (layer), tf.norm(y, axis=-1)) x = common_layers.layer_postprocess(x, y, hparams) utils.collect_named_outputs( "norms", "decoder_ffn_post_%d" % (layer), tf.norm(x, axis=-1)) # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. return common_layers.layer_preprocess(x, hparams)
def transformer_decoder(decoder_input, encoder_output, raw_targets, decoder_self_attention_bias, encoder_decoder_attention_bias, hparams, cache=None, name="decoder"): """A stack of transformer layers. Args: decoder_input: a Tensor encoder_output: a Tensor decoder_self_attention_bias: bias Tensor for self-attention (see common_attention.attention_bias()) encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention (see common_attention.attention_bias()) hparams: hyperparameters for model cache: dict, containing tensors which are the results of previous attentions, used for fast decoding. name: a string Returns: y: a Tensors """ x = decoder_input with tf.variable_scope(name): sequence_length = usr_utils.get_length_from_raw( tf.squeeze(raw_targets, axis=[-2, -1])) # Used for RNNs sequence_length = sequence_length + 1 # Because of shifting raw_decoder_input = common_layers.shift_right(raw_targets) terminal_decoder_bias, nonterminal_decoder_bias = _get_t_nt_bias( raw_decoder_input, hparams, decoder_self_attention_bias) raw_decoder_input = tf.squeeze(raw_decoder_input, axis=[-2, -1]) pos_signals = generate_positional_signals(raw_decoder_input, hparams, terminal_decoder_bias, nonterminal_decoder_bias) pos_embeddings = generate_positional_embeddings( pos_signals, hparams.decoder_pos, hparams) attention_pos_embeddings = generate_positional_embeddings( pos_signals, hparams.decoder_attention_pos, hparams) if "sum" in hparams.pos_integration: x = x + pos_embeddings elif "ffn" in hparams.pos_integration: with tf.variable_scope("pos_ffn"): x = tf.concat([x, pos_embeddings], axis=2) x = transformer_ffn_layer(x, hparams) for layer in xrange(hparams.num_decoder_layers or hparams.num_hidden_layers): layer_name = "layer_%d" % layer layer_cache = cache[layer_name] if cache is not None else None with tf.variable_scope(layer_name): for layer_type in _iter_layer_types( hparams.decoder_layer_types, layer): if layer_type == "self_att": with tf.variable_scope("self_attention"): y = model_helper.multihead_attention_qkv( common_layers.layer_preprocess(x, hparams), None, None, decoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams. decoder_self_attention_type, attention_order=hparams.attention_order, max_relative_position=hparams. max_relative_position, cache=layer_cache) x = common_layers.layer_postprocess(x, y, hparams) elif layer_type == "nt_self_att": with tf.variable_scope("nonterminal_self_attention"): y = model_helper.multihead_attention_qkv( common_layers.layer_preprocess(x, hparams), None, None, nonterminal_decoder_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams. decoder_self_attention_type, attention_order=hparams.attention_order, max_relative_position=hparams. max_relative_position, cache=layer_cache) x = common_layers.layer_postprocess(x, y, hparams) elif layer_type == "t_self_att": with tf.variable_scope("terminal_self_attention"): y = model_helper.multihead_attention_qkv( common_layers.layer_preprocess(x, hparams), None, None, terminal_decoder_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams. 
decoder_self_attention_type, attention_order=hparams.attention_order, max_relative_position=hparams. max_relative_position, cache=layer_cache) x = common_layers.layer_postprocess(x, y, hparams) elif layer_type == "rnn": with tf.variable_scope("recurrent"): y = transformer_rnn_layer( common_layers.layer_preprocess(x, hparams), sequence_length, hparams) x = common_layers.layer_postprocess(x, y, hparams) elif layer_type == "pos_self_att" and attention_pos_embeddings is not None: with tf.variable_scope("pos_self_attention"): y = model_helper.multihead_attention_qkv( attention_pos_embeddings, # Query attention_pos_embeddings, # Key common_layers.layer_preprocess( x, hparams), # Value decoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.pos_self_attention_type, attention_order=hparams.attention_order, max_relative_position=hparams. max_relative_position) x = common_layers.layer_postprocess(x, y, hparams) elif layer_type == "enc_att" and encoder_output is not None: with tf.variable_scope("encdec_attention"): # TODO(llion): Add caching. y = model_helper.multihead_attention_qkv( common_layers.layer_preprocess(x, hparams), encoder_output, None, encoder_decoder_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout) x = common_layers.layer_postprocess(x, y, hparams) else: tf.logging.warn( "Ignoring '%s' in decoder_layer_types" % layer_type) with tf.variable_scope("ffn"): y = transformer_ffn_layer( common_layers.layer_preprocess(x, hparams), hparams) x = common_layers.layer_postprocess(x, y, hparams) # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. return common_layers.layer_preprocess(x, hparams)
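Both the encoder and decoder variants above dispatch on a per-layer specification string from hparams (encoder_layer_types / decoder_layer_types), with _iter_layer_types yielding one or more sublayer kinds for each layer. The exact string format belongs to that codebase; a hypothetical parser for a comma/slash-separated spec such as "self_att/enc_att,rnn/enc_att" illustrates the idea:

def iter_layer_types(spec, layer):
    # Hypothetical: comma-separated per-layer groups, '/'-separated sublayers;
    # the last group is reused when there are more layers than groups.
    groups = [group.split("/") for group in spec.split(",")]
    for layer_type in groups[min(layer, len(groups) - 1)]:
        yield layer_type

spec = "self_att/enc_att,rnn/enc_att"
for layer in range(3):
    print(layer, list(iter_layer_types(spec, layer)))
# 0 ['self_att', 'enc_att']
# 1 ['rnn', 'enc_att']
# 2 ['rnn', 'enc_att']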
def evolved_transformer_decoder(decoder_input, encoder_output, decoder_self_attention_bias, encoder_decoder_attention_bias, hparams, cache=None, decode_loop_step=None, name="decoder", nonpadding=None, save_weights_to=None, make_image_summary=True, losses=None): """Evolved Transformer decoder. See arxiv.org/abs/1901.11117 for more details. Args: decoder_input: a Tensor. encoder_output: a Tensor. decoder_self_attention_bias: bias Tensor for self-attention (see common_attention.attention_bias()). encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention (see common_attention.attention_bias()). hparams: hyperparameters for model. cache: dict, containing tensors which are the results of previous layers, used for fast decoding. decode_loop_step: An integer, step number of the decoding loop. Only used for inference on TPU. name: a string. nonpadding: optional Tensor with shape [batch_size, encoder_length] indicating what positions are not padding. This is used to mask out padding in convolutional layers. We generally only need this mask for "packed" datasets, because for ordinary datasets, no padding is ever followed by nonpadding. save_weights_to: an optional dictionary to capture attention weights for visualization; the weights tensor will be appended there under a string key created from the variable scope (including name). make_image_summary: Whether to make an attention image summary. losses: Not supported. Returns: Decoder output tensor. """ del losses num_trainable_top_decoder_layers = hparams.get( "num_trainable_top_decoder_layers", -1) # -1 means train all weights. if num_trainable_top_decoder_layers >= 0: encoder_output = tf.stop_gradient(encoder_output) attention_dropout_broadcast_dims = ( common_layers.comma_separated_string_to_integer_list( getattr(hparams, "attention_dropout_broadcast_dims", ""))) with tf.variable_scope(name): hidden_state = decoder_input num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers for layer in range(num_layers): if num_trainable_top_decoder_layers == num_layers - layer: hidden_state = tf.stop_gradient(hidden_state) layer_name = "layer_%d" % layer layer_cache = cache[layer_name] if cache is not None else None with tf.variable_scope(layer_name): with tf.variable_scope(_SIXTEEN_HEAD_ATTENTION_NAME): residual_state = hidden_state hidden_state = common_layers.layer_preprocess( hidden_state, hparams) attention_cache = layer_cache[ _SIXTEEN_HEAD_ATTENTION_NAME] if layer_cache is not None else None left_state = common_attention.multihead_attention( hidden_state, None, decoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, _capped_double_heads(hparams.num_heads), hparams.attention_dropout, attention_type=hparams.self_attention_type, max_relative_position=hparams.max_relative_position, heads_share_relative_embedding=( hparams.heads_share_relative_embedding), add_relative_to_values=hparams.add_relative_to_values, save_weights_to=save_weights_to, cache=attention_cache, make_image_summary=make_image_summary, dropout_broadcast_dims=attention_dropout_broadcast_dims, max_length=hparams.get("max_length"), decode_loop_step=decode_loop_step, vars_3d=hparams.get("attention_variables_3d"), activation_dtype=hparams.get("activation_dtype", "float32"), weight_dtype=hparams.get("weight_dtype", "float32")) if encoder_output is not None: with tf.variable_scope(_FIRST_ATTEND_TO_ENCODER_NAME): attention_cache = ( 
layer_cache[_FIRST_ATTEND_TO_ENCODER_NAME] if layer_cache is not None else None) right_state = common_attention.multihead_attention( hidden_state, encoder_output, encoder_decoder_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, max_relative_position=hparams. max_relative_position, heads_share_relative_embedding=( hparams.heads_share_relative_embedding), add_relative_to_values=hparams. add_relative_to_values, save_weights_to=save_weights_to, cache=attention_cache, make_image_summary=make_image_summary, dropout_broadcast_dims= attention_dropout_broadcast_dims, max_length=hparams.get("max_length"), vars_3d=hparams.get("attention_variables_3d"), activation_dtype=hparams.get( "activation_dtype", "float32"), weight_dtype=hparams.get("weight_dtype", "float32")) left_state = tf.nn.dropout( left_state, 1 - hparams.layer_prepostprocess_dropout) right_state = tf.nn.dropout( right_state, 1 - hparams.layer_prepostprocess_dropout) hidden_state = residual_state + left_state + right_state else: hidden_state = common_layers.layer_postprocess( residual_state, left_state, hparams) with tf.variable_scope(_CONV_BRANCHES_NAME): residual_state = hidden_state hidden_state = common_layers.layer_preprocess( hidden_state, hparams) if nonpadding is not None: # Mask padding from conv layers. mask = tf.tile(tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size]) hidden_state *= mask if layer_cache: if decode_loop_step is None: hidden_state = layer_cache[ _CONV_BRANCHES_FIRST_LAYER_NAME] = tf.concat( [ layer_cache[ _CONV_BRANCHES_FIRST_LAYER_NAME], hidden_state ], axis=1)[:, -1 * _DECODER_LEFT_CONV_PADDING - 1:, :] left_state = hidden_state right_state = hidden_state[:, _DECODER_LEFT_CONV_PADDING - _DECODER_RIGHT_CONV_PADDING:, :] else: # Inplace update is required for inference on TPU. # Inplace_ops only supports inplace_update on the first dimension. tmp = tf.transpose( layer_cache[_CONV_BRANCHES_FIRST_LAYER_NAME], perm=[1, 0, 2]) tmp = tf.expand_dims(tmp, axis=1) tmp = inplace_ops.alias_inplace_update( tmp, decode_loop_step * tf.shape(hidden_state)[1] + _DECODER_LEFT_CONV_PADDING, tf.transpose(hidden_state, perm=[1, 0, 2])) tmp = tf.squeeze(tmp, axis=1) hidden_state = layer_cache[ _CONV_BRANCHES_FIRST_LAYER_NAME] = tf.transpose( tmp, perm=[1, 0, 2]) batch_size = hidden_state.shape.as_list()[0] left_state = tf.slice( hidden_state, [0, decode_loop_step, 0], [ batch_size, _DECODER_LEFT_CONV_PADDING + 1, hparams.hidden_size ]) right_state = tf.slice(hidden_state, [ 0, decode_loop_step + _DECODER_LEFT_CONV_PADDING - _DECODER_RIGHT_CONV_PADDING, 0 ], [ batch_size, _DECODER_RIGHT_CONV_PADDING + 1, hparams.hidden_size ]) else: # No caching. 
left_state = tf.pad( hidden_state, paddings=[[0, 0], [_DECODER_LEFT_CONV_PADDING, 0], [0, 0]]) right_state = tf.pad( hidden_state, paddings=[[0, 0], [_DECODER_RIGHT_CONV_PADDING, 0], [0, 0]]) left_output_dim = int(hparams.hidden_size * 2) separable_conv_11x1 = tf.layers.SeparableConv1D( left_output_dim, 11, padding="VALID", name="separable_conv11x1", activation=tf.nn.relu) left_state = separable_conv_11x1.apply(left_state) left_state = tf.nn.dropout( left_state, 1 - hparams.layer_prepostprocess_dropout) right_output_dim = int(hparams.hidden_size / 2) separable_conv_7x1_1 = tf.layers.SeparableConv1D( right_output_dim, 7, padding="VALID", name="separable_conv_7x1_1") right_state = separable_conv_7x1_1.apply(right_state) right_state = tf.nn.dropout( right_state, 1 - hparams.layer_prepostprocess_dropout) right_state = tf.pad( right_state, [[0, 0], [0, 0], [0, left_output_dim - right_output_dim]], constant_values=0) hidden_state = left_state + right_state hidden_state = common_layers.layer_preprocess( hidden_state, hparams) if nonpadding is not None: # Mask padding from conv layers. mask = tf.tile(tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size * 2]) hidden_state *= mask if layer_cache: if decode_loop_step is None: hidden_state = layer_cache[ _CONV_BRANCHES_SECOND_LAYER_NAME] = tf.concat( [ layer_cache[ _CONV_BRANCHES_SECOND_LAYER_NAME], hidden_state ], axis=1)[:, -1 * _DECODER_FINAL_CONV_PADDING - 1:, :] else: # Inplace update is required for inference on TPU. # Inplace_ops only supports inplace_update on the first dimension. tmp = tf.transpose( layer_cache[_CONV_BRANCHES_SECOND_LAYER_NAME], perm=[1, 0, 2]) tmp = tf.expand_dims(tmp, axis=1) tmp = inplace_ops.alias_inplace_update( tmp, (decode_loop_step + _DECODER_FINAL_CONV_PADDING) * tf.shape(hidden_state)[1], tf.transpose(hidden_state, perm=[1, 0, 2])) tmp = tf.squeeze(tmp, axis=1) hidden_state = layer_cache[ _CONV_BRANCHES_SECOND_LAYER_NAME] = tf.transpose( tmp, perm=[1, 0, 2]) batch_size = hidden_state.shape.as_list()[0] hidden_state = tf.slice( hidden_state, [0, decode_loop_step, 0], [ batch_size, _DECODER_FINAL_CONV_PADDING + 1, hparams.hidden_size * 2 ]) else: hidden_state = tf.pad( hidden_state, paddings=[[0, 0], [_DECODER_FINAL_CONV_PADDING, 0], [0, 0]]) separable_conv_7x1_2 = tf.layers.SeparableConv1D( hparams.hidden_size, 7, padding="VALID", name="separable_conv_7x1_2") hidden_state = separable_conv_7x1_2.apply(hidden_state) hidden_state = common_layers.layer_postprocess( residual_state, hidden_state, hparams) with tf.variable_scope(_VANILLA_ATTENTION_NAME): residual_state = hidden_state hidden_state = common_layers.layer_preprocess( hidden_state, hparams) attention_cache = layer_cache[ _VANILLA_ATTENTION_NAME] if layer_cache is not None else None hidden_state = common_attention.multihead_attention( hidden_state, None, decoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.self_attention_type, max_relative_position=hparams.max_relative_position, heads_share_relative_embedding=( hparams.heads_share_relative_embedding), add_relative_to_values=hparams.add_relative_to_values, save_weights_to=save_weights_to, cache=attention_cache, make_image_summary=make_image_summary, dropout_broadcast_dims=attention_dropout_broadcast_dims, max_length=hparams.get("max_length"), decode_loop_step=decode_loop_step, vars_3d=hparams.get("attention_variables_3d"), 
activation_dtype=hparams.get("activation_dtype", "float32"), weight_dtype=hparams.get("weight_dtype", "float32")) hidden_state = common_layers.layer_postprocess( residual_state, hidden_state, hparams) if encoder_output is not None: with tf.variable_scope(_SECOND_ATTEND_TO_ENCODER_NAME): residual_state = hidden_state hidden_state = common_layers.layer_preprocess( hidden_state, hparams) attention_cache = ( layer_cache[_SECOND_ATTEND_TO_ENCODER_NAME] if layer_cache is not None else None) hidden_state = common_attention.multihead_attention( hidden_state, encoder_output, encoder_decoder_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, max_relative_position=hparams. max_relative_position, heads_share_relative_embedding=( hparams.heads_share_relative_embedding), add_relative_to_values=hparams. add_relative_to_values, save_weights_to=save_weights_to, cache=attention_cache, make_image_summary=make_image_summary, dropout_broadcast_dims= attention_dropout_broadcast_dims, max_length=hparams.get("max_length"), vars_3d=hparams.get("attention_variables_3d"), activation_dtype=hparams.get( "activation_dtype", "float32"), weight_dtype=hparams.get("weight_dtype", "float32")) hidden_state = common_layers.layer_postprocess( residual_state, hidden_state, hparams) with tf.variable_scope("dense_layers"): residual_state = hidden_state hidden_state = common_layers.layer_preprocess( hidden_state, hparams) hidden_state = tf.layers.dense(hidden_state, int(hparams.hidden_size * 4), activation=tf.nn.swish) hidden_state = tf.nn.dropout( hidden_state, 1 - hparams.layer_prepostprocess_dropout) hidden_state = common_layers.layer_preprocess( hidden_state, hparams) hidden_state = tf.layers.dense(hidden_state, hparams.hidden_size) hidden_state = common_layers.layer_postprocess( residual_state, hidden_state, hparams) decoder_output = common_layers.layer_preprocess(hidden_state, hparams) if num_trainable_top_decoder_layers == 0: decoder_output = tf.stop_gradient(decoder_output) return decoder_output
def transformer_encoder(encoder_input, encoder_self_attention_bias, hparams, name="encoder", nonpadding=None, save_weights_to=None, make_image_summary=True, losses=None): """A stack of transformer layers. Args: encoder_input: a Tensor encoder_self_attention_bias: bias Tensor for self-attention (see common_attention.attention_bias()) hparams: hyperparameters for model name: a string nonpadding: optional Tensor with shape [batch_size, encoder_length] indicating what positions are not padding. This must either be passed in, which we do for "packed" datasets, or inferred from encoder_self_attention_bias. The knowledge about padding is used for pad_remover(efficiency) and to mask out padding in convolutional layers. save_weights_to: an optional dictionary to capture attention weights for visualization; the weights tensor will be appended there under a string key created from the variable scope (including name). make_image_summary: Whether to make an attention image summary. losses: optional list onto which to append extra training losses Returns: y: a Tensors """ x = encoder_input attention_dropout_broadcast_dims = ( common_layers.comma_separated_string_to_integer_list( getattr(hparams, "attention_dropout_broadcast_dims", ""))) with tf.variable_scope(name): if nonpadding is not None: padding = 1.0 - nonpadding else: padding = common_attention.attention_bias_to_padding( encoder_self_attention_bias) nonpadding = 1.0 - padding pad_remover = None if hparams.use_pad_remover and not common_layers.is_on_tpu(): pad_remover = expert_utils.PadRemover(padding) for layer in range(hparams.num_encoder_layers or hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): with tf.variable_scope("self_attention"): # sg: imdb comments y = common_attention.multihead_attention( common_layers.layer_preprocess( x, hparams), # added layer norm None, encoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, # 128 hparams.attention_value_channels or hparams.hidden_size, # 128 hparams.hidden_size, # 128 hparams.num_heads, # 4 hparams.attention_dropout, # 0.1 attention_type=hparams. self_attention_type, # 'dot_product' save_weights_to=save_weights_to, max_relative_position=hparams. max_relative_position, # 0 make_image_summary=make_image_summary, dropout_broadcast_dims=attention_dropout_broadcast_dims, max_length=hparams.get("max_length")) # 256 x = common_layers.layer_postprocess(x, y, hparams) with tf.variable_scope("ffn"): y = transformer_ffn_layer(common_layers.layer_preprocess( x, hparams), hparams, pad_remover, conv_padding="SAME", nonpadding_mask=nonpadding, losses=losses) x = common_layers.layer_postprocess(x, y, hparams) # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. return common_layers.layer_preprocess(x, hparams)
def decoder(decoder_input, encoder_output, decoder_self_attention_bias, encoder_decoder_attention_bias, hparams, name="decoder", save_weights_to=None, make_image_summary=True,): """A stack of transformer layers. Args: decoder_input: a Tensor encoder_output: a Tensor decoder_self_attention_bias: bias Tensor for self-attention (see common_attention.attention_bias()) encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention (see common_attention.attention_bias()) hparams: hyperparameters for model name: a string save_weights_to: an optional dictionary to capture attention weights for visualization; the weights tensor will be appended there under a string key created from the variable scope (including name). make_image_summary: Whether to make an attention image summary. Returns: y: a Tensors """ x = decoder_input with tf.variable_scope(name): for layer in range(hparams.num_decoder_layers or hparams.num_hidden_layers): layer_name = "layer_%d" % layer with tf.variable_scope(layer_name): with tf.variable_scope("self_attention"): y = common_attention.multihead_attention( common_layers.layer_preprocess(x, hparams), None, decoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.self_attention_type, save_weights_to=save_weights_to, make_image_summary=make_image_summary, ) utils.collect_named_outputs("norms", "decoder_self_attention_%d"%(layer), tf.norm(y, axis=-1)) x = common_layers.layer_postprocess(x, y, hparams) utils.collect_named_outputs("norms", "decoder_self_attention_post_%d"%(layer), tf.norm(x, axis=-1)) if encoder_output is not None: with tf.variable_scope("encdec_attention"): y = common_attention.multihead_attention( common_layers.layer_preprocess(x, hparams), encoder_output, encoder_decoder_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, save_weights_to=save_weights_to, make_image_summary=make_image_summary, ) utils.collect_named_outputs( "norms", "decoder_encoder_attention_%d"%(layer), tf.norm(y, axis=-1)) x = common_layers.layer_postprocess(x, y, hparams) utils.collect_named_outputs( "norms", "decoder_encoder_attention_post_%d"%(layer), tf.norm(x, axis=-1)) with tf.variable_scope("ffn"): y = common_layers.dense_relu_dense( common_layers.layer_preprocess(x, hparams), hparams.filter_size, hparams.hidden_size, dropout=hparams.relu_dropout, ) utils.collect_named_outputs("norms", "decoder_ffn_%d"%(layer), tf.norm(y, axis=-1)) x = common_layers.layer_postprocess(x, y, hparams) utils.collect_named_outputs("norms", "decoder_ffn_post_%d"%(layer), tf.norm(x, axis=-1)) # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. return common_layers.layer_preprocess(x, hparams)
def transformer_encoder(encoder_input, encoder_self_attention_bias, hparams, name="encoder", nonpadding=None, save_weights_to=None, make_image_summary=True): """A stack of transformer layers. Args: encoder_input: a Tensor encoder_self_attention_bias: bias Tensor for self-attention (see common_attention.attention_bias()) hparams: hyperparameters for model name: a string nonpadding: optional Tensor with shape [batch_size, encoder_length] indicating what positions are not padding. This must either be passed in, which we do for "packed" datasets, or inferred from encoder_self_attention_bias. The knowledge about padding is used for pad_remover(efficiency) and to mask out padding in convolutional layers. save_weights_to: an optional dictionary to capture attention weights for visualization; the weights tensor will be appended there under a string key created from the variable scope (including name). make_image_summary: Whether to make an attention image summary. Returns: y: a Tensors """ x = encoder_input attention_dropout_broadcast_dims = ( common_layers.comma_separated_string_to_integer_list( getattr(hparams, "attention_dropout_broadcast_dims", ""))) with tf.variable_scope(name): if nonpadding is not None: padding = 1.0 - nonpadding else: padding = common_attention.attention_bias_to_padding( encoder_self_attention_bias) nonpadding = 1.0 - padding pad_remover = None if hparams.use_pad_remover and not common_layers.is_on_tpu(): pad_remover = expert_utils.PadRemover(padding) for layer in xrange(hparams.num_encoder_layers or hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): with tf.variable_scope("self_attention"): y = common_attention.multihead_attention( common_layers.layer_preprocess(x, hparams), None, encoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.self_attention_type, save_weights_to=save_weights_to, max_relative_position=hparams.max_relative_position, make_image_summary=make_image_summary, dropout_broadcast_dims=attention_dropout_broadcast_dims) x = common_layers.layer_postprocess(x, y, hparams) with tf.variable_scope("ffn"): y = transformer_ffn_layer( common_layers.layer_preprocess(x, hparams), hparams, pad_remover, conv_padding="SAME", nonpadding_mask=nonpadding) x = common_layers.layer_postprocess(x, y, hparams) # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. return common_layers.layer_preprocess(x, hparams)
def invertible_transformer_decoder_attention_unit( x, hparams, encoder_output, decoder_self_attention_bias, encoder_decoder_attention_bias, attention_dropout_broadcast_dims, save_weights_to=None, make_image_summary=True, split_index=0): """Applies a multihead attention function which is parametrised for decoding. Args: x: input (decoder input) hparams: model hyper-parameters encoder_output: Encoder representation. [batch_size, input_length, hidden_dim] decoder_self_attention_bias: Bias and mask weights for decoder self-attention. [batch_size, decoder_length] encoder_decoder_attention_bias: Bias and mask weights for encoder-decoder attention. [batch_size, input_length] attention_dropout_broadcast_dims: For noise broadcasting in the dropout layers to save memory during training save_weights_to: an optional dictionary to capture attention weights for visualization; the weights tensor will be appended there under a string key created from the variable scope (including name). make_image_summary: Whether to make an attention image summary. Returns: The output tensor """ ################## ## CHANGE START ## ################## with tf.variable_scope("self_attention"): # Output depth is hidden_size // 2 so it matches the width of each split (e.g. 64 when hidden_size is 128). x_splits = tf.split(x, num_or_size_splits=2, axis=2) y = common_attention.multihead_attention( common_layers.layer_preprocess(x_splits[split_index], hparams), None, decoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size // 2, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.self_attention_type, save_weights_to=save_weights_to, max_relative_position=hparams.max_relative_position, cache=None, make_image_summary=make_image_summary, dropout_broadcast_dims=attention_dropout_broadcast_dims, hard_attention_k=hparams.hard_attention_k) x_splits[1 - split_index] = common_layers.layer_postprocess( x_splits[1 - split_index], y, hparams) if encoder_output is not None: with tf.variable_scope("encdec_attention"): y = common_attention.multihead_attention( common_layers.layer_preprocess(x_splits[split_index], hparams), encoder_output, encoder_decoder_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size // 2, hparams.num_heads, hparams.attention_dropout, save_weights_to=save_weights_to, make_image_summary=make_image_summary, dropout_broadcast_dims=attention_dropout_broadcast_dims, hard_attention_k=hparams.hard_attention_k) x_splits[1 - split_index] = common_layers.layer_postprocess( x_splits[1 - split_index], y, hparams) x = tf.concat(x_splits, axis=2) ################## ## CHANGE END ## ################## return x
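The invertible unit splits the state into two halves along the feature axis, runs attention on one half, and adds the result to the other half, which is why the attention output depth must be hidden_size // 2 (the mismatch flagged and corrected above). A NumPy sketch of the additive coupling and its inverse, with a fixed linear map standing in for attention:

import numpy as np

def coupling_forward(x, f, split_index=0):
    # x: [batch, length, hidden]; f maps hidden // 2 -> hidden // 2 features.
    x1, x2 = np.split(x, 2, axis=2)
    halves = [x1, x2]
    halves[1 - split_index] = halves[1 - split_index] + f(halves[split_index])
    return np.concatenate(halves, axis=2)

def coupling_inverse(y, f, split_index=0):
    y1, y2 = np.split(y, 2, axis=2)
    halves = [y1, y2]
    halves[1 - split_index] = halves[1 - split_index] - f(halves[split_index])
    return np.concatenate(halves, axis=2)

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 4)) * 0.1           # hidden // 2 == 4
x = rng.standard_normal((2, 5, 8))
f = lambda h: h @ w                             # attention stand-in, half width
y = coupling_forward(x, f)
print(np.allclose(coupling_inverse(y, f), x))   # True: the block is invertible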
def transformer_encoder(encoder_input, encoder_self_attention_bias, hparams, name="encoder", nonpadding=None, save_weights_to=None, make_image_summary=True, losses=None): """A stack of transformer layers. Args: encoder_input: a Tensor encoder_self_attention_bias: bias Tensor for self-attention (see common_attention.attention_bias()) hparams: hyperparameters for model name: a string nonpadding: optional Tensor with shape [batch_size, encoder_length] indicating what positions are not padding. This must either be passed in, which we do for "packed" datasets, or inferred from encoder_self_attention_bias. The knowledge about padding is used for pad_remover(efficiency) and to mask out padding in convolutional layers. save_weights_to: an optional dictionary to capture attention weights for visualization; the weights tensor will be appended there under a string key created from the variable scope (including name). make_image_summary: Whether to make an attention image summary. losses: optional list onto which to append extra training losses Returns: y: a Tensors """ x = encoder_input attention_dropout_broadcast_dims = ( common_layers.comma_separated_string_to_integer_list( getattr(hparams, "attention_dropout_broadcast_dims", ""))) mlperf_log.transformer_print( key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS, value=hparams.num_encoder_layers or hparams.num_hidden_layers) mlperf_log.transformer_print( key=mlperf_log.MODEL_HP_ATTENTION_DROPOUT, value=hparams.attention_dropout) mlperf_log.transformer_print( key=mlperf_log.MODEL_HP_ATTENTION_DENSE, value={ "use_bias": "false", "num_heads": hparams.num_heads, "hidden_size": hparams.hidden_size }) with tf.variable_scope(name): if nonpadding is not None: padding = 1.0 - nonpadding else: padding = common_attention.attention_bias_to_padding( encoder_self_attention_bias) nonpadding = 1.0 - padding pad_remover = None if hparams.use_pad_remover and not common_layers.is_xla_compiled(): pad_remover = expert_utils.PadRemover(padding) for layer in range(hparams.num_encoder_layers or hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): with tf.variable_scope("self_attention"): y = common_attention.multihead_attention( common_layers.layer_preprocess(x, hparams), None, encoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.self_attention_type, max_relative_position=hparams.max_relative_position, heads_share_relative_embedding=( hparams.heads_share_relative_embedding), add_relative_to_values=hparams.add_relative_to_values, save_weights_to=save_weights_to, make_image_summary=make_image_summary, dropout_broadcast_dims=attention_dropout_broadcast_dims, max_length=hparams.get("max_length"), vars_3d=hparams.get("attention_variables_3d")) x = common_layers.layer_postprocess(x, y, hparams) with tf.variable_scope("ffn"): y = transformer_ffn_layer( common_layers.layer_preprocess(x, hparams), hparams, pad_remover, conv_padding="SAME", nonpadding_mask=nonpadding, losses=losses) x = common_layers.layer_postprocess(x, y, hparams) # if normalization is done in layer_preprocess, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. mlperf_log.transformer_print( key=mlperf_log.MODEL_HP_NORM, value={"hidden_size": hparams.hidden_size}) return common_layers.layer_preprocess(x, hparams)
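hparams.use_pad_remover is purely an efficiency device: because the position-wise FFN treats every position independently, padded positions can be dropped before the FFN and restored afterwards. A NumPy sketch of the remove/restore round trip, assuming the flat [batch * length, hidden] layout that expert_utils.PadRemover also works on:

import numpy as np

class PadRemoverSketch:
    def __init__(self, padding):
        # padding: [batch * length] with 1.0 at padded positions.
        self.keep = np.nonzero(padding < 0.5)[0]
        self.total = padding.shape[0]

    def remove(self, x):
        return x[self.keep]

    def restore(self, x):
        out = np.zeros((self.total,) + x.shape[1:], dtype=x.dtype)
        out[self.keep] = x
        return out

padding = np.array([0., 0., 1., 0., 1., 1.])       # flattened [batch * length]
x = np.arange(12, dtype=np.float32).reshape(6, 2)  # [batch * length, hidden]
pr = PadRemoverSketch(padding)
compact = pr.remove(x)                             # the FFN would run on this
print(pr.restore(compact))                         # padded rows come back as zeros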