def transformer_encoder_layers(inputs,
                               num_layers,
                               hparams,
                               attention_type=AttentionType.GLOBAL,
                               self_attention_bias=None,
                               q_padding="VALID",
                               kv_padding="VALID",
                               name="transformer"):
  """Multi layer transformer encoder.

  Applies dropout to `inputs`, then stacks `num_layers` layers. Each layer is
  a self-attention sub-layer followed by a feed-forward sub-layer, and every
  sub-layer is wrapped in layer_preprocess / layer_postprocess.

  Args:
    inputs: input Tensor.
    num_layers: number of layers to stack.
    hparams: hyperparameters; `layer_prepostprocess_dropout` is read here and
      the object is forwarded to every helper.
    attention_type: AttentionType.LOCAL_2D, LOCAL_1D or GLOBAL.
    self_attention_bias: bias handed to `full_self_attention`; only read for
      AttentionType.GLOBAL.
    q_padding: forwarded to the LOCAL_1D and GLOBAL attention helpers.
    kv_padding: forwarded to the LOCAL_1D and GLOBAL attention helpers.
    name: prefix of the per-layer variable scope ("<name>_layer_<index>").

  Returns:
    The output of the last layer after a final layer_preprocess.

  Raises:
    ValueError: if a layer is built with an unsupported `attention_type`.
  """
  x = inputs
  x = tf.nn.dropout(x, 1.0 - hparams.layer_prepostprocess_dropout)

  for layer in range(num_layers):
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      # attention layers + skip connections
      if attention_type == AttentionType.LOCAL_2D:
        y = local_attention_2d(common_layers.layer_preprocess(x, hparams),
                               hparams,
                               attention_type="local_attention_2d")
      elif attention_type == AttentionType.LOCAL_1D:
        y = local_attention_1d(common_layers.layer_preprocess(x, hparams),
                               hparams,
                               attention_type="local_unmasked",
                               q_padding=q_padding,
                               kv_padding=kv_padding)
      elif attention_type == AttentionType.GLOBAL:
        y = full_self_attention(common_layers.layer_preprocess(x, hparams),
                                self_attention_bias,
                                hparams,
                                q_padding=q_padding,
                                kv_padding=kv_padding)
      else:
        # Without this branch `y` would be unbound below and the caller
        # would only see an UnboundLocalError.
        raise ValueError(
            "Unsupported attention_type for encoder layers: %r"
            % (attention_type,))
      x = common_layers.layer_postprocess(x, y, hparams)
      # feed-fwd layer + skip connections
      y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams)
      x = common_layers.layer_postprocess(x, y, hparams)
  return common_layers.layer_preprocess(x, hparams)
def f(x, side_input):
  """f(x) for reversible layer, self-attention and enc-dec attention.

  `hparams` is taken from the enclosing scope. Its `hidden_size` is halved
  while the sub-layers are built and is put back afterwards, including when
  graph construction raises.

  Args:
    x: input Tensor.
    side_input: indexable holding, at positions 0, 1 and 2, the decoder
      self-attention bias, the encoder-decoder attention bias and the encoder
      output. The enc-dec attention is skipped when the encoder output is
      None.

  Returns:
    The post-processed output of the last attention that was built.
  """
  decoder_self_attention_bias = side_input[0]
  encoder_decoder_attention_bias = side_input[1]
  encoder_output = side_input[2]

  old_hid_size = hparams.hidden_size
  hparams.hidden_size = old_hid_size // 2
  try:
    with tf.variable_scope("self_attention"):
      y = common_attention.multihead_attention(
          common_layers.layer_preprocess(x, hparams),
          None,
          decoder_self_attention_bias,
          hparams.attention_key_channels or hparams.hidden_size,
          hparams.attention_value_channels or hparams.hidden_size,
          hparams.hidden_size,
          hparams.num_heads,
          hparams.attention_dropout)
      y = common_layers.layer_postprocess(x, y, hparams)
      if encoder_output is not None:
        # NOTE(review): this branch recomputes `y` from `x`, so the
        # self-attention result above is discarded whenever an encoder
        # output is supplied -- confirm this is intended.
        with tf.variable_scope("encdec_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              encoder_output,
              encoder_decoder_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout)
          y = common_layers.layer_postprocess(x, y, hparams)
  finally:
    # hparams is shared with the caller; never leave it with the halved size.
    hparams.hidden_size = old_hid_size
  return y
def image_encoder(image_feat,
                  hparams,
                  name="image_encoder",
                  save_weights_to=None,
                  make_image_summary=True):
  """Runs image features through a stack of self-attention + FFN layers.

  Every layer applies multi-head self-attention and then a dense-relu-dense
  block, each wrapped in layer_preprocess / layer_postprocess. After every
  step `tf.norm(..., axis=-1)` of the result is registered under the "norms"
  key via `utils.collect_named_outputs`.

  Args:
    image_feat: input image-feature Tensor.
    hparams: hyperparameters. `image_hidden_size` and `image_filter_size`
      fall back to `hidden_size` and `filter_size` when falsy; the number of
      layers is `num_encoder_layers or num_hidden_layers`.
    name: variable scope wrapping the whole stack.
    save_weights_to: forwarded to `vqa_layers.multihead_attention`.
    make_image_summary: forwarded to `vqa_layers.multihead_attention`.

  Returns:
    The output of the stack after a final layer_preprocess.
  """
  hidden_size = hparams.image_hidden_size or hparams.hidden_size
  filter_size = hparams.image_filter_size or hparams.filter_size

  def record_norm(tag, tensor):
    utils.collect_named_outputs("norms", tag, tf.norm(tensor, axis=-1))

  out = image_feat
  with tf.variable_scope(name):
    num_layers = hparams.num_encoder_layers or hparams.num_hidden_layers
    for layer in range(num_layers):
      with tf.variable_scope("layer_%d" % layer):
        with tf.variable_scope("self_attention"):
          normed = common_layers.layer_preprocess(out, hparams)
          attended = vqa_layers.multihead_attention(
              normed,
              None,
              None,
              hparams.attention_key_channels or hidden_size,
              hparams.attention_value_channels or hidden_size,
              hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.image_self_attention_type,
              save_weights_to=save_weights_to,
              make_image_summary=make_image_summary,
              scale_dotproduct=hparams.scale_dotproduct,
          )
          record_norm("image_feat_self_attention_%d" % layer, attended)
          out = common_layers.layer_postprocess(out, attended, hparams)
          record_norm(
              "image_feat_self_attention_postprocess_%d" % layer, out)
        with tf.variable_scope("ffn"):
          normed = common_layers.layer_preprocess(out, hparams)
          ffn_out = common_layers.dense_relu_dense(
              normed,
              filter_size,
              hidden_size,
              dropout=hparams.relu_dropout,
          )
          record_norm("image_feat_ffn_%d" % layer, ffn_out)
          out = common_layers.layer_postprocess(out, ffn_out, hparams)
          record_norm("image_feat_ffn_postprocess_%d" % layer, out)
    # When layer_preprocess normalizes, the stack output needs the same
    # treatment: it is a sum of unnormalized layer outputs and can grow
    # very large.
    return common_layers.layer_preprocess(out, hparams)
def local_within_block_attention(x,
                                 self_attention_bias,
                                 hparams,
                                 attention_type="local_within_block_mask_right",
                                 q_padding="VALID",
                                 kv_padding="VALID"):
  """Local within block self attention.

  The input goes through `maybe_reshape_4d_to_3d`, is pre-processed and fed
  to multi-head self-attention. If the helper reports the input was 4-D, the
  result is reshaped back to the shape the helper returned.

  Args:
    x: input Tensor.
    self_attention_bias: bias passed to the attention call.
    hparams: hyperparameters (attention sizes, heads, dropout, block and
      filter widths are read from it).
    attention_type: attention variant forwarded to `multihead_attention`.
    q_padding: forwarded to `multihead_attention`.
    kv_padding: forwarded to `multihead_attention`.

  Returns:
    The attention output. No layer_postprocess is applied here.
  """
  flat_x, original_shape, was_4d = maybe_reshape_4d_to_3d(x)
  with tf.variable_scope("local_within_block"):
    normed = common_layers.layer_preprocess(flat_x, hparams)
    key_depth = hparams.attention_key_channels or hparams.hidden_size
    value_depth = hparams.attention_value_channels or hparams.hidden_size
    attended = common_attention.multihead_attention(
        normed,
        None,
        self_attention_bias,
        key_depth,
        value_depth,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        attention_type=attention_type,
        block_width=hparams.block_width,
        block_length=hparams.block_length,
        q_padding=q_padding,
        kv_padding=kv_padding,
        q_filter_width=hparams.q_filter_width,
        kv_filter_width=hparams.kv_filter_width,
        name="local_within_block")
    if not was_4d:
      return attended
    return tf.reshape(attended, original_shape)
def precompute_edge_matrices(adjacency, hparams):
  """Precompute the a_in and a_out tensors.

  (we don't want to add to the graph every time _fprop is called)

  Args:
    adjacency: placeholder of real valued vectors of shape [B, L, L, E]
    hparams: tf.HParams object

  Returns:
    edge_matrices: [batch, L * D, L * D] the dense matrix for message passing
    viewed as a block matrix (L,L) blocks of size (D,D). Each block is a
    function of the edge vector of the adjacency matrix at that spot.
  """
  batch_size, num_nodes, _, edge_dim = common_layers.shape_list(adjacency)
  depth = hparams.hidden_size

  # Edge network: one edge vector in, one flattened (D, D) matrix out.
  with tf.variable_scope("edge_network"):
    net = tf.reshape(
        adjacency,
        [batch_size * num_nodes * num_nodes, edge_dim],
        name="adj_reshape_in")

    for layer_idx in range(hparams.edge_network_layers):
      net = tf.layers.dense(
          common_layers.layer_preprocess(net, hparams),
          hparams.edge_network_hidden_size,
          activation=tf.nn.relu,
          name="edge_network_layer_%d" % layer_idx)
    net = tf.layers.dense(
        common_layers.layer_preprocess(net, hparams),
        depth**2,
        activation=None,
        name="edge_network_output")

  # net is [batch * l * l, d * d]; split it into per-edge (d, d) matrices.
  per_edge = tf.reshape(net, [batch_size, num_nodes, num_nodes, depth, depth])

  # Interleave node and depth axes, then collapse to [batch, l * d, l * d].
  interleaved = tf.transpose(per_edge, [0, 1, 3, 2, 4])
  block_side = num_nodes * depth
  return tf.reshape(
      interleaved, [-1, block_side, block_side], name="edge_matrices")
def transformer_layers_sharded(dp,
                               ps_devices,
                               inputs,
                               num_layers,
                               hparams,
                               self_attention_bias=None,
                               enc_output=None,
                               attention_type=AttentionType.GLOBAL,
                               name="transformer"):
  """Multi layer transformer, sharded by the data parallelism dp.

  Args:
    dp: data-parallelism callable; invoked below as `dp(fn, *args)`.
    ps_devices: forwarded to `expert_utils.distributed_moe`.
    inputs: input to the stack (handed to `dp(tf.nn.dropout, ...)`).
    num_layers: number of layers to stack.
    hparams: hyperparameters. `moe_hidden_sizes` and `moe_layers_decoder` are
      comma-separated strings; the latter lists the layer indices whose
      feed-forward sub-layer is a mixture of experts.
    self_attention_bias: bias used by the GLOCAL branch. The GLOBAL branch
      overwrites it with `get_self_attention_bias(x)` on every layer.
    enc_output: optional encoder output; when not None an encoder-decoder
      attention sub-layer is added to every layer.
    attention_type: AttentionType.LOCAL_2D, LOCAL_1D, GLOCAL or GLOBAL.
    name: prefix of the per-layer variable scope ("<name>_layer_<index>").

  Returns:
    A pair (output, extra_loss), where `extra_loss` is the sum of the losses
    returned by the mixture-of-experts layers.

  Raises:
    ValueError: if a layer is built with an unsupported `attention_type`.
  """
  x = inputs
  extra_loss = tf.constant(0.0)
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  expert_fn = expert_utils.ffn_expert_fn(
      hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
  x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
  for layer in range(num_layers):
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      # self-attention
      # NOTE(review): the attention branches call the helper directly and
      # hand its *result* to dp(), and the postprocess right after them is
      # not routed through dp(), whereas the rest of this function uses
      # dp(fn, *args). Left as found because dp's contract is not visible
      # from here -- confirm.
      if attention_type == AttentionType.LOCAL_2D:
        y = dp(local_attention_2d(common_layers.layer_preprocess(x, hparams),
                                  hparams,
                                  attention_type="masked_local_attention_2d"))
      elif attention_type == AttentionType.LOCAL_1D:
        y = dp(local_attention_1d(common_layers.layer_preprocess(x, hparams),
                                  hparams,
                                  attention_type="local_mask_right",
                                  q_padding="LEFT",
                                  kv_padding="LEFT"))
      elif attention_type == AttentionType.GLOCAL:
        y = dp(local_global_attention(
            common_layers.layer_preprocess(x, hparams),
            self_attention_bias,
            hparams,
            q_padding="LEFT",
            kv_padding="LEFT"))
      elif attention_type == AttentionType.GLOBAL:
        self_attention_bias = dp(get_self_attention_bias(x))
        y = dp(full_self_attention(common_layers.layer_preprocess(x, hparams),
                                   self_attention_bias,
                                   hparams,
                                   q_padding="LEFT",
                                   kv_padding="LEFT"))
      else:
        # Without this branch `y` would be unbound below.
        raise ValueError(
            "Unsupported attention_type for sharded layers: %r"
            % (attention_type,))
      x = common_layers.layer_postprocess(x, y, hparams)
      if enc_output is not None:
        y = dp(encdec_attention_1d(common_layers.layer_preprocess(x, hparams),
                                   enc_output, None, hparams))
        x = dp(common_layers.layer_postprocess, x, y, hparams)
      with tf.variable_scope("ffn"):
        if str(layer) in hparams.moe_layers_decoder.split(","):
          y, loss = expert_utils.distributed_moe(
              dp,
              ps_devices,
              common_layers.layer_preprocess(x, hparams),
              hparams.mode == tf.estimator.ModeKeys.TRAIN,
              input_size=hparams.hidden_size,
              expert_fn=expert_fn,
              num_experts=hparams.moe_num_experts,
              k=hparams.moe_k,
              loss_coef=hparams.moe_loss_coef)
          extra_loss += loss
          x = dp(common_layers.layer_postprocess, x, y, hparams)
        else:
          y = dp(ffn_layer, common_layers.layer_preprocess(x, hparams),
                 hparams)
          x = dp(common_layers.layer_postprocess, x, y, hparams)
  return dp(common_layers.layer_preprocess, x, hparams), extra_loss
def g(x):
  """g(x) for reversible layer, feed-forward layer.

  `hparams` is taken from the enclosing scope. Its `hidden_size` is halved
  while the feed-forward sub-layer is built and is put back afterwards,
  including when graph construction raises.

  Args:
    x: input Tensor.

  Returns:
    The post-processed feed-forward output.
  """
  old_hid_size = hparams.hidden_size
  hparams.hidden_size = old_hid_size // 2
  try:
    with tf.variable_scope("ffn"):
      y = transformer.transformer_ffn_layer(
          common_layers.layer_preprocess(x, hparams), hparams)
      y = common_layers.layer_postprocess(x, y, hparams)
  finally:
    # hparams is shared with the caller; never leave it with the halved size.
    hparams.hidden_size = old_hid_size
  return y
def compress_self_attention_layer(x, hparams, name=None):
  """Self-attention sub-layer that restores the input's original shape.

  Args:
    x: input Tensor; passed through `cia.maybe_reshape_4d_to_3d` first.
    hparams: hyperparameters (attention sizes, heads and dropout are read).
    name: optional variable scope name; "compress_self_attention" is the
      default.

  Returns:
    The post-processed attention output, reshaped to the shape reported by
    `maybe_reshape_4d_to_3d`.
  """
  with tf.variable_scope(name, default_name="compress_self_attention"):
    flat_x, original_shape, _ = cia.maybe_reshape_4d_to_3d(x)
    normed = common_layers.layer_preprocess(flat_x, hparams)
    key_depth = hparams.attention_key_channels or hparams.hidden_size
    value_depth = hparams.attention_value_channels or hparams.hidden_size
    attended = common_attention.multihead_attention(
        normed,
        None,
        None,
        key_depth,
        value_depth,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout)
    combined = common_layers.layer_postprocess(flat_x, attended, hparams)
    return tf.reshape(combined, original_shape)
def attend(x, source, hparams, name):
  """Attends from `x` to `source` with multi-head attention.

  Axis 2 of `x` is squeezed before the attention and re-inserted on the
  result. `source` is squeezed on axis 2 only when its static rank exceeds 3,
  and a 1-D timing signal is added to it.

  Args:
    x: query Tensor.
    source: memory Tensor attended to.
    hparams: hyperparameters (attention sizes, heads and dropout are read).
    name: variable scope name.

  Returns:
    The post-processed attention output with axis 2 expanded again.
  """
  with tf.variable_scope(name):
    query = tf.squeeze(x, axis=2)
    source_rank = len(source.get_shape())
    if source_rank > 3:
      source = tf.squeeze(source, axis=2)
    memory = common_attention.add_timing_signal_1d(source)
    normed = common_layers.layer_preprocess(query, hparams)
    key_depth = hparams.attention_key_channels or hparams.hidden_size
    value_depth = hparams.attention_value_channels or hparams.hidden_size
    attended = common_attention.multihead_attention(
        normed,
        memory,
        None,
        key_depth,
        value_depth,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout)
    combined = common_layers.layer_postprocess(query, attended, hparams)
    return tf.expand_dims(combined, axis=2)
def transformer_decoder_layers(inputs,
                               encoder_output,
                               bias,
                               num_layers,
                               hparams,
                               attention_type=AttentionType.LOCAL_2D,
                               name="transformer"):
  """Multi layer transformer.

  Applies dropout to `inputs`, then stacks `num_layers` layers. Each layer is
  a self-attention sub-layer, an optional encoder-decoder attention sub-layer
  and a feed-forward sub-layer, all wrapped in layer_preprocess /
  layer_postprocess.

  Args:
    inputs: input Tensor.
    encoder_output: optional encoder output; the enc-dec attention sub-layer
      is built only when this is not None.
    bias: self-attention bias, forwarded to every attention helper except the
      LOCAL_2D one.
    num_layers: number of layers to stack.
    hparams: hyperparameters. For AttentionType.DILATED, `gap_sizes` must
      hold one entry per layer.
    attention_type: AttentionType.LOCAL_2D, LOCAL_1D, GLOCAL, DILATED or
      GLOBAL.
    name: prefix of the per-layer variable scope ("<name>_layer_<index>").

  Returns:
    The output of the last layer after a final layer_preprocess.

  Raises:
    ValueError: if a layer is built with an unsupported `attention_type`.
  """
  x = inputs
  x = tf.nn.dropout(x, 1.0 - hparams.layer_prepostprocess_dropout)
  if attention_type == AttentionType.DILATED:
    assert len(hparams.gap_sizes) == num_layers, (
        "hparams.gap_sizes must contain one gap size per layer")
  for layer in xrange(num_layers):
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      # self-attention + skip connections
      if attention_type == AttentionType.LOCAL_2D:
        y = local_attention_2d(common_layers.layer_preprocess(x, hparams),
                               hparams,
                               attention_type="masked_local_attention_2d")
      elif attention_type == AttentionType.LOCAL_1D:
        y = local_attention_1d(common_layers.layer_preprocess(x, hparams),
                               bias,
                               hparams,
                               attention_type="local_mask_right",
                               q_padding="LEFT",
                               kv_padding="LEFT")
      elif attention_type == AttentionType.GLOCAL:
        y = local_global_attention(common_layers.layer_preprocess(x, hparams),
                                   bias,
                                   hparams,
                                   q_padding="LEFT",
                                   kv_padding="LEFT")
      elif attention_type == AttentionType.DILATED:
        y = dilated_attention_1d(common_layers.layer_preprocess(x, hparams),
                                 bias,
                                 hparams,
                                 q_padding="LEFT",
                                 kv_padding="LEFT",
                                 gap_size=hparams.gap_sizes[layer])
      elif attention_type == AttentionType.GLOBAL:
        y = full_self_attention(common_layers.layer_preprocess(x, hparams),
                                bias,
                                hparams,
                                q_padding="LEFT",
                                kv_padding="LEFT")
      else:
        # Without this branch `y` would be unbound below.
        raise ValueError(
            "Unsupported attention_type for decoder layers: %r"
            % (attention_type,))
      x = common_layers.layer_postprocess(x, y, hparams)
      # enc-dec attention + skip connections
      if encoder_output is not None:
        y = encdec_attention_1d(common_layers.layer_preprocess(x, hparams),
                                encoder_output,
                                hparams)
        x = common_layers.layer_postprocess(x, y, hparams)
      # feed-fwd layers + skip connections
      y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams)
      x = common_layers.layer_postprocess(x, y, hparams)
  return common_layers.layer_preprocess(x, hparams)
def attention_lm_decoder(decoder_input,
                         decoder_self_attention_bias,
                         hparams,
                         name="decoder"):
  """A stack of attention_lm layers.

  Each of the `hparams.num_hidden_layers` layers is a self-attention
  sub-layer followed by a conv-hidden-relu sub-layer, both wrapped in
  layer_preprocess / layer_postprocess.

  Args:
    decoder_input: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string

  Returns:
    y: a Tensor, the stack output after a final layer_preprocess.
  """
  key_depth = hparams.attention_key_channels or hparams.hidden_size
  value_depth = hparams.attention_value_channels or hparams.hidden_size

  out = decoder_input
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        with tf.variable_scope("self_attention"):
          normed = common_layers.layer_preprocess(out, hparams)
          attended = common_attention.multihead_attention(
              normed,
              None,
              decoder_self_attention_bias,
              key_depth,
              value_depth,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout)
          out = common_layers.layer_postprocess(out, attended, hparams)
        with tf.variable_scope("ffn"):
          normed = common_layers.layer_preprocess(out, hparams)
          ffn_out = common_layers.conv_hidden_relu(
              normed,
              hparams.filter_size,
              hparams.hidden_size,
              dropout=hparams.relu_dropout)
          out = common_layers.layer_postprocess(out, ffn_out, hparams)
    return common_layers.layer_preprocess(out, hparams)
def attend(x, source, hparams, name):
  """Attends from `x` to `source` and restores the shape of `x`.

  `x` goes through `cia.maybe_reshape_4d_to_3d` before the attention and the
  result is reshaped to the shape that helper reported. `source` is squeezed
  on axis 2 only when its static rank exceeds 3, and a 1-D timing signal is
  added to it.

  Args:
    x: query Tensor.
    source: memory Tensor attended to.
    hparams: hyperparameters (attention sizes, heads and dropout are read).
    name: variable scope name.

  Returns:
    The post-processed attention output, reshaped as described above.
  """
  with tf.variable_scope(name):
    query, original_shape, _ = cia.maybe_reshape_4d_to_3d(x)
    source_rank = len(source.get_shape())
    if source_rank > 3:
      source = tf.squeeze(source, axis=2)
    memory = common_attention.add_timing_signal_1d(source)
    normed = common_layers.layer_preprocess(query, hparams)
    key_depth = hparams.attention_key_channels or hparams.hidden_size
    value_depth = hparams.attention_value_channels or hparams.hidden_size
    attended = common_attention.multihead_attention(
        normed,
        memory,
        None,
        key_depth,
        value_depth,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout)
    combined = common_layers.layer_postprocess(query, attended, hparams)
    return tf.reshape(combined, original_shape)
def top(self, body_output, _):
  """Maps the body output to per-channel output logits.

  A 1x1 convolution expands the last dimension to hidden_size * num_channels,
  the result is regrouped so that every channel has its own hidden vector,
  and a dense layer projects each of them to `self.top_dimensionality`
  values.

  Args:
    body_output: Tensor fed to the 1x1 convolution.
    _: unused.

  Returns:
    A Tensor reshaped to
    [-1, img_len, img_len, num_channels, top_dimensionality].
  """
  with tf.variable_scope(self.name):
    hidden_dim = self._model_hparams.hidden_size
    img_len = self._model_hparams.img_len
    channels = self.num_channels  # RGB
    batch = common_layers.shape_list(body_output)[0]
    x = tf.layers.conv2d(
        body_output,
        hidden_dim * channels, (1, 1),
        strides=(1, 1),
        padding="VALID",
        activation=tf.nn.relu,
        name="decompress_conv")
    x = tf.reshape(x, [batch, img_len, img_len * channels, hidden_dim])
    x = common_layers.layer_preprocess(x, self._model_hparams)
    # The width must equal the last dimension of the reshape below. With a
    # hard-coded 256 any other top_dimensionality would be absorbed by the
    # -1 batch dimension instead of failing.
    x = tf.layers.dense(
        x,
        self.top_dimensionality,
        use_bias=True,
        activation=None,
        name="output_conv")
    x = tf.reshape(x,
                   [-1, img_len, img_len, channels, self.top_dimensionality])
    return x
def top(self, body_output, _):
  """Maps the body output to per-channel output logits.

  A 1x1 convolution expands the last dimension to hidden_size * num_channels,
  the result is regrouped so that every channel has its own hidden vector,
  and a dense layer projects each of them to `self.top_dimensionality`
  values.

  Args:
    body_output: Tensor fed to the 1x1 convolution.
    _: unused.

  Returns:
    A Tensor reshaped to
    [-1, img_len, img_len, num_channels, top_dimensionality].
  """
  with tf.variable_scope(self.name):
    hidden_dim = self._model_hparams.hidden_size
    img_len = self._model_hparams.img_len
    channels = self.num_channels  # RGB
    batch = common_layers.shape_list(body_output)[0]
    x = tf.layers.conv2d(
        body_output,
        hidden_dim * channels, (1, 1),
        strides=(1, 1),
        padding="VALID",
        activation=tf.nn.relu,
        name="decompress_conv")
    x = tf.reshape(x, [batch, img_len, img_len * channels, hidden_dim])
    x = common_layers.layer_preprocess(x, self._model_hparams)
    # The width must equal the last dimension of the reshape below. With a
    # hard-coded 256 any other top_dimensionality would be absorbed by the
    # -1 batch dimension instead of failing.
    x = tf.layers.dense(
        x,
        self.top_dimensionality,
        use_bias=True,
        activation=None,
        name="output_conv")
    x = tf.reshape(x,
                   [-1, img_len, img_len, channels, self.top_dimensionality])
    return x
def _two_attn_unit(x):
  """Sums two attentions over `x` and applies the residual post-process.

  The pre-processed input attends once to `encoder_output` and once to
  `firstP_input`; the two results are added before layer_postprocess.
  `hparams`, both memories and both biases come from the enclosing scope.

  Args:
    x: input Tensor.

  Returns:
    The post-processed sum of the two attention outputs.
  """
  with tf.variable_scope("delibctx_attention"):
    normed = common_layers.layer_preprocess(x, hparams)
    key_depth = hparams.attention_key_channels or hparams.hidden_size
    value_depth = hparams.attention_value_channels or hparams.hidden_size
    over_encoder = common_attention.multihead_attention(
        normed,
        encoder_output,
        encoder_decoder_attention_bias,
        key_depth,
        value_depth,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        name="encdec_attention")
    over_first_pass = common_attention.multihead_attention(
        normed,
        firstP_input,
        firstP_self_attention_bias,
        key_depth,
        value_depth,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        name="decdec_attention")
    return common_layers.layer_postprocess(
        x, over_encoder + over_first_pass, hparams)
def moe_transformer_decoder_layer(decoder_input,
                                  decoder_self_attention_bias,
                                  layer_idx,
                                  hparams,
                                  encoder_output=None,
                                  encoder_decoder_attention_bias=None,
                                  cache=None,
                                  decode_loop_step=None,
                                  save_weights_to=None,
                                  make_image_summary=False,
                                  layer_collection=None,
                                  recurrent_memory_by_layer=None,
                                  chunk_number=None):
  """A single transformer decoder layer with MoE module.

  The attention part is delegated to
  `transformer.transformer_self_attention_layer`, which receives every
  argument unchanged. Its first result is then fed through this module's
  `transformer_ffn_layer` inside the "layer_<layer_idx>/ffn" variable scope.
  The second element of both helper results is discarded.

  Args:
    decoder_input: input Tensor of the layer.
    decoder_self_attention_bias: forwarded to the attention helper.
    layer_idx: index of this layer; also names the variable scope.
    hparams: hyperparameters for model.
    encoder_output: forwarded to the attention helper.
    encoder_decoder_attention_bias: forwarded to the attention helper.
    cache: forwarded to the attention helper.
    decode_loop_step: forwarded to the attention helper.
    save_weights_to: forwarded to the attention helper.
    make_image_summary: forwarded to the attention helper.
    layer_collection: forwarded to the attention helper.
    recurrent_memory_by_layer: forwarded to the attention helper.
    chunk_number: forwarded to the attention helper.

  Returns:
    The layer output after the feed-forward post-process.
  """
  attention_kwargs = dict(
      decoder_input=decoder_input,
      decoder_self_attention_bias=decoder_self_attention_bias,
      layer_idx=layer_idx,
      hparams=hparams,
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      cache=cache,
      decode_loop_step=decode_loop_step,
      save_weights_to=save_weights_to,
      make_image_summary=make_image_summary,
      layer_collection=layer_collection,
      recurrent_memory_by_layer=recurrent_memory_by_layer,
      chunk_number=chunk_number)
  attended, _ = transformer.transformer_self_attention_layer(
      **attention_kwargs)

  with tf.variable_scope("layer_%d" % layer_idx):
    with tf.variable_scope("ffn"):
      normed = common_layers.layer_preprocess(attended, hparams)
      ffn_out, _ = transformer_ffn_layer(normed, hparams)
      return common_layers.layer_postprocess(attended, ffn_out, hparams)
def after2daggregate(encoder_input, encoder_self_attention_bias_slices,
                     hparams, name):
  """Encodes every slice of axis 2 separately, then aggregates the slices.

  Each slice `encoder_input[:, :, i, :]` is run through
  `num_encoder_layers or num_hidden_layers` `attn_over_sent` layers in its
  own "encoder_lex<i>" variable scope. The encoded slices are stacked back on
  axis 2 and passed through `attn_over_sent_and_lex_2d` and `lex_aggregate`.

  Args:
    encoder_input: input Tensor whose axis 2 has a static size.
    encoder_self_attention_bias_slices: sequence of bias Tensors, indexed by
      slice.
    hparams: hyperparameters for model.
    name: variable scope name.

  Returns:
    The aggregated output after a final layer_preprocess.
  """
  with tf.variable_scope(name):
    combined_pad_bias = get_pad_remover(
        hparams, encoder_self_attention_bias_slices, is_combined=True)
    encoded_slices = []
    for slice_idx in range(encoder_input.get_shape()[2].value):
      with tf.variable_scope("encoder_lex" + str(slice_idx)):
        encoded = encoder_input[:, :, slice_idx, :]
        slice_pad_bias = get_pad_remover(
            hparams, [encoder_self_attention_bias_slices[slice_idx]])
        num_layers = hparams.num_encoder_layers or hparams.num_hidden_layers
        for layer in xrange(num_layers):
          with tf.variable_scope("layer_%d" % layer):
            encoded = attn_over_sent(
                encoded, slice_pad_bias[0], slice_pad_bias[1], hparams)
        encoded_slices.append(encoded)
    stacked = tf.stack(encoded_slices, 2)
    attended = attn_over_sent_and_lex_2d(
        stacked, combined_pad_bias[0], hparams)
    aggregated = lex_aggregate(attended, hparams)
    return common_layers.layer_preprocess(aggregated, hparams)
def transformer_decoder_layers(inputs,
                               encoder_output,
                               bias,
                               num_layers,
                               hparams,
                               attention_type=AttentionType.LOCAL_2D,
                               name="transformer"):
  """Multi layer transformer.

  Applies dropout to `inputs`, then stacks `num_layers` layers. Each layer is
  a self-attention sub-layer, an optional encoder-decoder attention sub-layer
  and a feed-forward sub-layer, all wrapped in layer_preprocess /
  layer_postprocess.

  Args:
    inputs: input Tensor.
    encoder_output: optional encoder output; the enc-dec attention sub-layer
      is built only when this is not None.
    bias: self-attention bias, forwarded to every attention helper except the
      LOCAL_2D one.
    num_layers: number of layers to stack.
    hparams: hyperparameters for model.
    attention_type: AttentionType.LOCAL_2D, LOCAL_1D, GLOCAL or GLOBAL.
    name: prefix of the per-layer variable scope ("<name>_layer_<index>").

  Returns:
    The output of the last layer after a final layer_preprocess.

  Raises:
    ValueError: if a layer is built with an unsupported `attention_type`.
  """
  x = inputs
  x = tf.nn.dropout(x, 1.0 - hparams.layer_prepostprocess_dropout)
  for layer in xrange(num_layers):
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      # self-attention + skip connections
      if attention_type == AttentionType.LOCAL_2D:
        y = local_attention_2d(common_layers.layer_preprocess(x, hparams),
                               hparams,
                               attention_type="masked_local_attention_2d")
      elif attention_type == AttentionType.LOCAL_1D:
        y = local_attention_1d(common_layers.layer_preprocess(x, hparams),
                               bias,
                               hparams,
                               attention_type="local_mask_right",
                               q_padding="LEFT",
                               kv_padding="LEFT")
      elif attention_type == AttentionType.GLOCAL:
        y = local_global_attention(common_layers.layer_preprocess(x, hparams),
                                   bias,
                                   hparams,
                                   q_padding="LEFT",
                                   kv_padding="LEFT")
      elif attention_type == AttentionType.GLOBAL:
        y = full_self_attention(common_layers.layer_preprocess(x, hparams),
                                bias,
                                hparams,
                                q_padding="LEFT",
                                kv_padding="LEFT")
      else:
        # TODO(nikip): Add support for dilated attention.
        # Until then fail loudly; without this branch `y` would be unbound
        # below.
        raise ValueError(
            "Unsupported attention_type for decoder layers: %r"
            % (attention_type,))
      x = common_layers.layer_postprocess(x, y, hparams)
      # enc-dec attention + skip connections
      if encoder_output is not None:
        y = encdec_attention_1d(common_layers.layer_preprocess(x, hparams),
                                encoder_output,
                                hparams)
        x = common_layers.layer_postprocess(x, y, hparams)
      # feed-fwd layers + skip connections
      y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams)
      x = common_layers.layer_postprocess(x, y, hparams)
  return common_layers.layer_preprocess(x, hparams)
def transformer_edit_ops_layer(
    decoder_input,
    hparams,
    encoder_output,
    features,
    cache=None,
    decode_loop_step=None,
    nonpadding=None,
    losses=None,
    layer_collection=None,
):
  """Layer that conditions on the error tag and start and end token pointers.

  The pre-processed decoder input is concatenated on axis 2 with the
  pre-processed, right-shifted `features['targets_error_tag']` and fed to a
  feed-forward layer with LEFT padding. The result is combined with the
  decoder input by layer_postprocess.

  Args:
    decoder_input: input Tensor of the layer.
    hparams: hyperparameters for model.
    encoder_output: encoder output; when it is a list only its first element
      is selected. It is not read afterwards in this function.
    features: dict-like holding the 'targets_error_tag' entry.
    cache: forwarded to the feed-forward layer.
    decode_loop_step: forwarded to the feed-forward layer.
    nonpadding: forwarded to the feed-forward layer as `nonpadding_mask`.
    losses: forwarded to the feed-forward layer.
    layer_collection: forwarded to layer_preprocess and the feed-forward
      layer.

  Returns:
    The post-processed layer output.
  """
  if isinstance(encoder_output, list):
    # Select forward encoder
    encoder_output = encoder_output[0]

  def preprocess(tensor):
    return common_layers.layer_preprocess(
        tensor, hparams, layer_collection=layer_collection)

  with tf.variable_scope('edit_ops_layer'):
    with tf.variable_scope('ffn'):
      normed_input = preprocess(decoder_input)
      flat_tags = common_layers.flatten4d3d(features['targets_error_tag'])
      shifted_tags = common_layers.shift_right_3d(flat_tags)
      normed_tags = preprocess(shifted_tags)
      ffn_out = transformer_layers.transformer_ffn_layer(
          tf.concat([normed_input, normed_tags], axis=2),
          hparams,
          conv_padding='LEFT',
          nonpadding_mask=nonpadding,
          losses=losses,
          cache=cache,
          decode_loop_step=decode_loop_step,
          layer_collection=layer_collection,
      )
      return common_layers.layer_postprocess(decoder_input, ffn_out, hparams)
def transformer_between_predictions_layer(x,
                                          hparams,
                                          name,
                                          cache=None,
                                          decode_loop_step=None,
                                          nonpadding=None,
                                          losses=None,
                                          layer_collection=None):
  """Stack between prediction layers.

  Builds `hparams.ffn_in_prediction_cascade` feed-forward layers with LEFT
  padding, each wrapped in layer_preprocess / layer_postprocess and placed
  in its own "layer_<i>" variable scope.

  Args:
    x: input Tensor.
    hparams: hyperparameters for model.
    name: variable scope wrapping the stack.
    cache: forwarded to every feed-forward layer.
    decode_loop_step: forwarded to every feed-forward layer.
    nonpadding: forwarded to every feed-forward layer as `nonpadding_mask`.
    losses: forwarded to every feed-forward layer.
    layer_collection: forwarded to layer_preprocess and every feed-forward
      layer.

  Returns:
    The output of the last layer, or `x` itself when no layer is built.
  """
  out = x
  with tf.variable_scope(name):
    for layer_idx in range(hparams.ffn_in_prediction_cascade):
      with tf.variable_scope("layer_%d" % layer_idx):
        normed = common_layers.layer_preprocess(
            out, hparams, layer_collection=layer_collection)
        ffn_out = transformer_layers.transformer_ffn_layer(
            normed,
            hparams,
            conv_padding="LEFT",
            nonpadding_mask=nonpadding,
            losses=losses,
            cache=cache,
            decode_loop_step=decode_loop_step,
            layer_collection=layer_collection)
        out = common_layers.layer_postprocess(out, ffn_out, hparams)
  return out
def recurrent_transformer_decoder(
    decoder_input,
    encoder_output,
    decoder_self_attention_bias,
    encoder_decoder_attention_bias,
    hparams,
    name="decoder",
    nonpadding=None,
    save_weights_to=None,
    make_image_summary=True):
  """Recurrent decoder function.

  Binds the feed-forward and attention units to their arguments and hands
  them to `universal_transformer_util.universal_transformer_layer`.

  Args:
    decoder_input: input Tensor.
    encoder_output: bound into the attention unit.
    decoder_self_attention_bias: bound into the attention unit.
    encoder_decoder_attention_bias: bound into the attention unit.
    hparams: hyperparameters for model. The optional
      `attention_dropout_broadcast_dims` attribute is a comma-separated
      string of integers; it defaults to "".
    name: variable scope name.
    nonpadding: bound into the feed-forward unit as `nonpadding_mask`.
    save_weights_to: bound into the attention unit.
    make_image_summary: bound into the attention unit.

  Returns:
    A pair (output after layer_preprocess, extra_output), where
    `extra_output` is the second result of `universal_transformer_layer`.
  """
  broadcast_dims_spec = getattr(
      hparams, "attention_dropout_broadcast_dims", "")
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          broadcast_dims_spec))

  with tf.variable_scope(name):
    # use encoder ffn, since decoder ffn use left padding
    ffn_unit = functools.partial(
        universal_transformer_util.transformer_encoder_ffn_unit,
        hparams=hparams,
        nonpadding_mask=nonpadding)

    attention_kwargs = dict(
        hparams=hparams,
        encoder_output=encoder_output,
        decoder_self_attention_bias=decoder_self_attention_bias,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        attention_dropout_broadcast_dims=attention_dropout_broadcast_dims,
        save_weights_to=save_weights_to,
        make_image_summary=make_image_summary)
    attention_unit = functools.partial(
        universal_transformer_util.transformer_decoder_attention_unit,
        **attention_kwargs)

    layer_result = universal_transformer_util.universal_transformer_layer(
        decoder_input, hparams, ffn_unit, attention_unit)
    out, extra_output = layer_result

    return common_layers.layer_preprocess(out, hparams), extra_output
def recurrent_transformer_decoder(
    decoder_input,
    encoder_output,
    decoder_self_attention_bias,
    encoder_decoder_attention_bias,
    hparams,
    name="decoder",
    nonpadding=None,
    save_weights_to=None,
    make_image_summary=True):
  """Recurrent decoder function.

  Binds the feed-forward and attention units to their arguments and hands
  them to `universal_transformer_util.universal_transformer_layer`.

  Args:
    decoder_input: input Tensor.
    encoder_output: bound into the attention unit.
    decoder_self_attention_bias: bound into the attention unit.
    encoder_decoder_attention_bias: bound into the attention unit.
    hparams: hyperparameters for model. The optional
      `attention_dropout_broadcast_dims` attribute is a comma-separated
      string of integers; it defaults to "".
    name: variable scope name.
    nonpadding: bound into the feed-forward unit as `nonpadding_mask`.
    save_weights_to: bound into the attention unit.
    make_image_summary: bound into the attention unit.

  Returns:
    A pair (output after layer_preprocess, extra_output), where
    `extra_output` is the second result of `universal_transformer_layer`.
  """
  broadcast_dims_spec = getattr(
      hparams, "attention_dropout_broadcast_dims", "")
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          broadcast_dims_spec))

  with tf.variable_scope(name):
    # use encoder ffn, since decoder ffn use left padding
    ffn_unit = functools.partial(
        universal_transformer_util.transformer_encoder_ffn_unit,
        hparams=hparams,
        nonpadding_mask=nonpadding)

    attention_kwargs = dict(
        hparams=hparams,
        encoder_output=encoder_output,
        decoder_self_attention_bias=decoder_self_attention_bias,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        attention_dropout_broadcast_dims=attention_dropout_broadcast_dims,
        save_weights_to=save_weights_to,
        make_image_summary=make_image_summary)
    attention_unit = functools.partial(
        universal_transformer_util.transformer_decoder_attention_unit,
        **attention_kwargs)

    layer_result = universal_transformer_util.universal_transformer_layer(
        decoder_input, hparams, ffn_unit, attention_unit)
    out, extra_output = layer_result

    return common_layers.layer_preprocess(out, hparams), extra_output
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        name="decoder",
                        terminal_decoder_bias=None,
                        nonterminal_decoder_bias=None,
                        pop_decoder_bias=None,
                        nonpadding=None,
                        decoder_input_raw=None,
                        pos_signals=None):
  """A stack of transformer layers.

  Every layer first runs the sub-layers that `_iter_layer_types` yields for
  `hparams.decoder_layer_types` and the layer index, in that order, and then
  one feed-forward sub-layer with LEFT padding. Layer types that match no
  branch are skipped with a logged warning; this includes "enc_att" and
  "enc_pop_pos_att" when `encoder_output` is None.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
      attentions, used for fast decoding. It is indexed by "layer_<index>".
    name: a string
    terminal_decoder_bias: bias passed to the "t_self_att" and
      "osm_self_att" sub-layers.
    nonterminal_decoder_bias: bias passed to the "nt_self_att" sub-layer.
    pop_decoder_bias: read by the "enc_pop_att" and "enc_pop_pos_att"
      sub-layers.
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding. This is used
      to mask out padding in convolutional layers. We generally only
      need this mask for "packed" datasets, because for ordinary datasets,
      no padding is ever followed by nonpadding.
    decoder_input_raw: passed as `query_antecedent_raw` to the
      "osm_self_att" sub-layer.
    pos_signals: mapping read at key "parent_timing" by the "parent_ffn"
      sub-layer.

  Returns:
    y: a Tensor, the stack output after a final layer_preprocess.
  """
  x = decoder_input
  # NOTE(review): nonpadding defaults to None and is handed over
  # unconditionally -- confirm get_length_from_nonpadding accepts None.
  sequence_length = usr_utils.get_length_from_nonpadding(nonpadding)
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_decoder_layers or
                        hparams.num_hidden_layers):
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        for layer_type in _iter_layer_types(
            hparams.decoder_layer_types, layer):
          if layer_type == "self_att":
            with tf.variable_scope("self_attention"):
              y = model_helper.multihead_attention_qkv(
                  common_layers.layer_preprocess(x, hparams),
                  None,
                  None,
                  decoder_self_attention_bias,
                  hparams.attention_key_channels or hparams.hidden_size,
                  hparams.attention_value_channels or hparams.hidden_size,
                  hparams.hidden_size,
                  hparams.num_heads,
                  hparams.attention_dropout,
                  attention_type=hparams.decoder_self_attention_type,
                  attention_order=hparams.attention_order,
                  max_relative_position=hparams.max_relative_position,
                  cache=layer_cache)
              x = common_layers.layer_postprocess(x, y, hparams)
          elif layer_type == "nt_self_att":
            with tf.variable_scope("nonterminal_self_attention"):
              y = model_helper.multihead_attention_qkv(
                  common_layers.layer_preprocess(x, hparams),
                  None,
                  None,
                  nonterminal_decoder_bias,
                  hparams.attention_key_channels or hparams.hidden_size,
                  hparams.attention_value_channels or hparams.hidden_size,
                  hparams.hidden_size,
                  hparams.num_heads,
                  hparams.attention_dropout,
                  attention_type=hparams.decoder_self_attention_type,
                  attention_order=hparams.attention_order,
                  max_relative_position=hparams.max_relative_position,
                  cache=layer_cache)
              x = common_layers.layer_postprocess(x, y, hparams)
          elif layer_type == "t_self_att":
            with tf.variable_scope("terminal_self_attention"):
              y = model_helper.multihead_attention_qkv(
                  common_layers.layer_preprocess(x, hparams),
                  None,
                  None,
                  terminal_decoder_bias,
                  hparams.attention_key_channels or hparams.hidden_size,
                  hparams.attention_value_channels or hparams.hidden_size,
                  hparams.hidden_size,
                  hparams.num_heads,
                  hparams.attention_dropout,
                  attention_type=hparams.decoder_self_attention_type,
                  attention_order=hparams.attention_order,
                  max_relative_position=hparams.max_relative_position,
                  cache=layer_cache)
              x = common_layers.layer_postprocess(x, y, hparams)
          elif layer_type == "osm_self_att":
            with tf.variable_scope("osm_self_attention"):
              y = model_helper.multihead_attention_osm(
                  common_layers.layer_preprocess(x, hparams),
                  terminal_decoder_bias,
                  hparams.attention_key_channels or hparams.hidden_size,
                  hparams.attention_value_channels or hparams.hidden_size,
                  hparams.hidden_size,
                  hparams.num_heads,
                  hparams.attention_dropout,
                  query_antecedent_raw=decoder_input_raw)
              x = common_layers.layer_postprocess(x, y, hparams)
          elif layer_type == "parent_ffn":
            # Unlike the attention sub-layers, this one replaces x outright:
            # there is no layer_preprocess / layer_postprocess around it.
            with tf.variable_scope("parent_ffn"):
              parent_pointers = tf.cast(
                  pos_signals["parent_timing"], tf.int32)
              parent_x = usr_utils.gather_2d(x, parent_pointers)
              x = tf.concat([x, parent_x], axis=2)
              x = transformer_ffn_layer(x, hparams, conv_padding="LEFT")
          elif layer_type == "enc_pop_att":
            # Same pattern as "parent_ffn": x is replaced, no residual.
            with tf.variable_scope("enc_pop_att"):
              enc_x = usr_utils.expand_memory_by_pop(
                  pop_decoder_bias, encoder_output, offset=0)
              x = tf.concat([x, enc_x], axis=2)
              x = transformer_ffn_layer(x, hparams, conv_padding="LEFT")
          elif layer_type == "rnn":
            with tf.variable_scope("recurrent"):
              y = transformer_rnn_layer(
                  common_layers.layer_preprocess(x, hparams),
                  sequence_length,
                  hparams)
              x = common_layers.layer_postprocess(x, y, hparams)
          elif layer_type == "enc_pop_pos_att" and encoder_output is not None:
            with tf.variable_scope("encdec_pop_pos_attention"):
              # Running sum of the pop bias along axis 1, embedded as a
              # sinusoidal signal and appended to the pre-processed query.
              src_pos = tf.cumsum(tf.cast(
                  pop_decoder_bias, tf.float32), axis=1)
              src_pos_embed = embed_position_signal_sine(
                  src_pos, hparams.hidden_size // 2, max_timescale=5.0e3)
              x_with_src_pos = tf.concat([
                  common_layers.layer_preprocess(x, hparams),
                  src_pos_embed
              ], 2)
              x_with_src_pos.set_shape([
                  None, None,
                  hparams.hidden_size + hparams.hidden_size // 2
              ])
              # TODO Add src pos embedding to x_with_src_pos
              y = model_helper.multihead_attention_qkv(
                  x_with_src_pos,
                  encoder_output,
                  None,
                  encoder_decoder_attention_bias,
                  hparams.attention_key_channels or hparams.hidden_size,
                  hparams.attention_value_channels or hparams.hidden_size,
                  hparams.hidden_size,
                  hparams.num_heads,
                  hparams.attention_dropout)
              x = common_layers.layer_postprocess(x, y, hparams)
          elif layer_type == "enc_att" and encoder_output is not None:
            with tf.variable_scope("encdec_attention"):
              # TODO(llion): Add caching.
              y = model_helper.multihead_attention_qkv(
                  common_layers.layer_preprocess(x, hparams),
                  encoder_output,
                  None,
                  encoder_decoder_attention_bias,
                  hparams.attention_key_channels or hparams.hidden_size,
                  hparams.attention_value_channels or hparams.hidden_size,
                  hparams.hidden_size,
                  hparams.num_heads,
                  hparams.attention_dropout)
              x = common_layers.layer_postprocess(x, y, hparams)
          else:
            tf.logging.warn(
                "Ignoring '%s' in decoder_layer_types" % layer_type)
        # One feed-forward sub-layer closes every layer.
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(common_layers.layer_preprocess(
              x, hparams), hparams, conv_padding="LEFT",
                                    nonpadding_mask=nonpadding)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be
    # done on the output, since the output can grow very large, being the
    # sum of a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder",
                        nonpadding=None):
  """Builds an encoder stack whose sub-layers are selected per layer.

  For every layer, the sub-layer types yielded by
  `_iter_layer_types(hparams.encoder_layer_types, layer)` are applied in
  order ("self_att", "rnn" or "birnn"; anything else is logged and ignored),
  followed by one feed-forward sub-layer.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This must either be
      passed in, which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias.  The knowledge about padding is used
      for pad_remover (efficiency) and to mask out padding in convolutional
      layers.

  Returns:
    y: a Tensor
  """
  preprocess = common_layers.layer_preprocess
  postprocess = common_layers.layer_postprocess

  hidden = encoder_input
  with tf.variable_scope(name):
    # Derive whichever of padding / nonpadding was not supplied.
    if nonpadding is None:
      padding = common_attention.attention_bias_to_padding(
          encoder_self_attention_bias)
      nonpadding = 1.0 - padding
    else:
      padding = 1.0 - nonpadding

    pad_remover = (
        expert_utils.PadRemover(padding) if hparams.use_pad_remover else None)
    sequence_length = usr_utils.get_length_from_nonpadding(nonpadding)

    num_layers = hparams.num_encoder_layers or hparams.num_hidden_layers
    for layer_idx in xrange(num_layers):
      with tf.variable_scope("layer_%d" % layer_idx):
        for sublayer in _iter_layer_types(hparams.encoder_layer_types,
                                          layer_idx):
          if sublayer == "self_att":
            with tf.variable_scope("self_attention"):
              normed = preprocess(hidden, hparams)
              attended = model_helper.multihead_attention_qkv(
                  normed,
                  None,
                  None,
                  encoder_self_attention_bias,
                  hparams.attention_key_channels or hparams.hidden_size,
                  hparams.attention_value_channels or hparams.hidden_size,
                  hparams.hidden_size,
                  hparams.num_heads,
                  hparams.attention_dropout,
                  attention_type=hparams.encoder_self_attention_type,
                  attention_order=hparams.attention_order,
                  max_relative_position=hparams.max_relative_position)
              hidden = postprocess(hidden, attended, hparams)
          elif sublayer in ("rnn", "birnn"):
            # Both recurrent flavours share the "recurrent" scope; only the
            # bidirectional one passes the extra keyword.
            rnn_kwargs = {"bidirectional": True} if sublayer == "birnn" else {}
            with tf.variable_scope("recurrent"):
              normed = preprocess(hidden, hparams)
              recurrent_out = transformer_rnn_layer(
                  normed, sequence_length, hparams, **rnn_kwargs)
              hidden = postprocess(hidden, recurrent_out, hparams)
          else:
            tf.logging.warn(
                "Ignoring '%s' in encoder_layer_types" % sublayer)

        with tf.variable_scope("ffn"):
          normed = preprocess(hidden, hparams)
          ffn_out = transformer_ffn_layer(
              normed,
              hparams,
              pad_remover,
              conv_padding="SAME",
              nonpadding_mask=nonpadding)
          hidden = postprocess(hidden, ffn_out, hparams)

    # When layer_preprocess normalizes, the stack output must be normalized
    # as well: it is the sum of a whole stack of unnormalized layer outputs
    # and can grow very large.
    return preprocess(hidden, hparams)
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        name="decoder",
                        nonpadding=None,
                        save_weights_to=None):
  """A stack of transformer decoder layers.

  Each layer applies self-attention, then (only when `encoder_output` is not
  None) encoder-decoder attention, then a feed-forward sub-layer.  Every
  sub-layer is wrapped in layer_preprocess / layer_postprocess.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor, or None to skip encoder-decoder attention
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
      attentions, used for fast decoding.  It is indexed per layer with the
      key "layer_<index>".
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This is used
      to mask out padding in convolutional layers.  We generally only
      need this mask for "packed" datasets, because for ordinary datasets,
      no padding is ever followed by nonpadding.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).

  Returns:
    y: a Tensor
  """
  x = decoder_input
  with tf.variable_scope(name):
    for layer in xrange(hparams.num_decoder_layers or
                        hparams.num_hidden_layers):
      layer_name = "layer_%d" % layer
      # Per-layer slice of the decoding cache; None when not caching.
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              save_weights_to=save_weights_to,
              max_relative_position=hparams.max_relative_position,
              cache=layer_cache)
          x = common_layers.layer_postprocess(x, y, hparams)
        if encoder_output is not None:
          with tf.variable_scope("encdec_attention"):
            # TODO(llion): Add caching.
            y = common_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                save_weights_to=save_weights_to)
            x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams),
              hparams,
              conv_padding="LEFT",
              nonpadding_mask=nonpadding)
          x = common_layers.layer_postprocess(x, y, hparams)
    # If normalization is done in layer_preprocess, then it should also be
    # done on the output, since the output can grow very large, being the sum
    # of a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder",
                        nonpadding=None,
                        save_weights_to=None):
  """Builds a stack of self-attention + feed-forward encoder layers.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This must either be
      passed in, which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias.  The knowledge about padding is used
      for pad_remover (efficiency) and to mask out padding in convolutional
      layers.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).

  Returns:
    y: a Tensor
  """
  preprocess = common_layers.layer_preprocess
  postprocess = common_layers.layer_postprocess

  hidden = encoder_input
  with tf.variable_scope(name):
    # Derive whichever of padding / nonpadding was not supplied.
    if nonpadding is None:
      padding = common_attention.attention_bias_to_padding(
          encoder_self_attention_bias)
      nonpadding = 1.0 - padding
    else:
      padding = 1.0 - nonpadding

    pad_remover = (
        expert_utils.PadRemover(padding) if hparams.use_pad_remover else None)

    num_layers = hparams.num_encoder_layers or hparams.num_hidden_layers
    for layer_idx in xrange(num_layers):
      with tf.variable_scope("layer_%d" % layer_idx):
        with tf.variable_scope("self_attention"):
          normed = preprocess(hidden, hparams)
          key_depth = hparams.attention_key_channels or hparams.hidden_size
          value_depth = (
              hparams.attention_value_channels or hparams.hidden_size)
          attended = common_attention.multihead_attention(
              normed,
              None,
              encoder_self_attention_bias,
              key_depth,
              value_depth,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              save_weights_to=save_weights_to,
              max_relative_position=hparams.max_relative_position)
          hidden = postprocess(hidden, attended, hparams)

        with tf.variable_scope("ffn"):
          normed = preprocess(hidden, hparams)
          ffn_out = transformer_ffn_layer(
              normed,
              hparams,
              pad_remover,
              conv_padding="SAME",
              nonpadding_mask=nonpadding)
          hidden = postprocess(hidden, ffn_out, hparams)

    # When layer_preprocess normalizes, the stack output must be normalized
    # as well: it is the sum of a whole stack of unnormalized layer outputs
    # and can grow very large.
    return preprocess(hidden, hparams)
def transformer_layers_sharded(dp,
                               ps_devices,
                               inputs,
                               num_layers,
                               hparams,
                               self_attention_bias=None,
                               enc_output=None,
                               attention_type=AttentionType.GLOBAL,
                               name="transformer"):
  """Multi layer transformer, sharded by the data parallelism dp.

  Every per-shard computation is dispatched as `dp(fn, *args)`, so that `fn`
  runs once per shard.  After the initial dropout, `x` is whatever `dp`
  returns (presumably one tensor per shard -- confirm against
  expert_utils.Parallelism), so layer functions must never be called on it
  directly.

  Args:
    dp: data-parallelism callable, invoked as dp(fn, *args, **kwargs).
    ps_devices: forwarded unchanged to expert_utils.distributed_moe.
    inputs: input passed through `dp(tf.nn.dropout, ...)` first.
    num_layers: number of layers to build.
    hparams: hyperparameters for model.  Reads moe_hidden_sizes,
      moe_layers_decoder (comma-separated layer indices that use a
      mixture-of-experts feed-forward), moe_num_experts, moe_k,
      moe_loss_coef, hidden_size, mode and layer_prepostprocess_dropout.
    self_attention_bias: bias used by the GLOCAL attention; for GLOBAL
      attention it is recomputed from `x` in every layer.
    enc_output: optional encoder output; when not None an encoder-decoder
      attention sub-layer is added to every layer.
    attention_type: one of AttentionType.LOCAL_2D, LOCAL_1D, GLOCAL, GLOBAL.
    name: prefix of the per-layer variable scope.

  Returns:
    A tuple (output, extra_loss): the layer-preprocessed output of the last
    layer as returned by `dp`, and the summed mixture-of-experts loss.

  Raises:
    ValueError: if attention_type is not one of the supported values.
  """
  x = inputs
  extra_loss = tf.constant(0.0)
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size, moe_hidden_sizes,
                                         hparams.hidden_size)
  # Loop-invariant: parse the list of mixture-of-experts layers once.
  moe_layers = hparams.moe_layers_decoder.split(",")
  x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
  for layer in range(num_layers):
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      # self-attention
      if attention_type == AttentionType.LOCAL_2D:
        y = dp(local_attention_2d,
               dp(common_layers.layer_preprocess, x, hparams),
               hparams,
               attention_type="masked_local_attention_2d")
      elif attention_type == AttentionType.LOCAL_1D:
        y = dp(local_attention_1d,
               dp(common_layers.layer_preprocess, x, hparams),
               hparams,
               attention_type="local_mask_right",
               q_padding="LEFT",
               kv_padding="LEFT")
      elif attention_type == AttentionType.GLOCAL:
        y = dp(local_global_attention,
               dp(common_layers.layer_preprocess, x, hparams),
               self_attention_bias,
               hparams,
               q_padding="LEFT",
               kv_padding="LEFT")
      elif attention_type == AttentionType.GLOBAL:
        self_attention_bias = dp(get_self_attention_bias, x)
        y = dp(full_self_attention,
               dp(common_layers.layer_preprocess, x, hparams),
               self_attention_bias,
               hparams,
               q_padding="LEFT",
               kv_padding="LEFT")
      else:
        # Previously this fell through and failed later on an unbound `y`.
        raise ValueError(
            "Unsupported attention_type for transformer_layers_sharded: %s" %
            (attention_type,))
      x = dp(common_layers.layer_postprocess, x, y, hparams)
      if enc_output is not None:
        y = dp(encdec_attention_1d,
               dp(common_layers.layer_preprocess, x, hparams),
               enc_output,
               None,
               hparams)
        x = dp(common_layers.layer_postprocess, x, y, hparams)
      with tf.variable_scope("ffn"):
        if str(layer) in moe_layers:
          # distributed_moe takes dp itself and shards internally, so it is
          # handed the per-shard preprocessed inputs directly.
          y, loss = expert_utils.distributed_moe(
              dp,
              ps_devices,
              dp(common_layers.layer_preprocess, x, hparams),
              hparams.mode == tf.estimator.ModeKeys.TRAIN,
              input_size=hparams.hidden_size,
              expert_fn=expert_fn,
              num_experts=hparams.moe_num_experts,
              k=hparams.moe_k,
              loss_coef=hparams.moe_loss_coef)
          extra_loss += loss
          x = dp(common_layers.layer_postprocess, x, y, hparams)
        else:
          y = dp(ffn_layer,
                 dp(common_layers.layer_preprocess, x, hparams),
                 hparams)
          x = dp(common_layers.layer_postprocess, x, y, hparams)
  return dp(common_layers.layer_preprocess, x, hparams), extra_loss
def transformer_decoder_layers(inputs,
                               encoder_output,
                               num_layers,
                               hparams,
                               self_attention_bias=None,
                               encoder_decoder_attention_bias=None,
                               attention_type=AttentionType.LOCAL_2D,
                               name="transformer"):
  """Multi layer transformer decoder.

  Each layer applies the self-attention flavour selected by `attention_type`,
  then (only when `encoder_output` is not None) encoder-decoder attention,
  then a feed-forward sub-layer.  Every sub-layer is wrapped in
  layer_preprocess / layer_postprocess.

  Args:
    inputs: input Tensor; dropout is applied to it before the first layer.
    encoder_output: encoder output Tensor, or None to skip encoder-decoder
      attention.
    num_layers: number of layers to build.
    hparams: hyperparameters for model.  For AttentionType.DILATED,
      hparams.gap_sizes must hold one gap size per layer.
    self_attention_bias: bias forwarded to the LOCAL_BLOCK, GLOCAL and GLOBAL
      self-attention variants.
    encoder_decoder_attention_bias: bias forwarded to encdec_attention_1d.
    attention_type: one of AttentionType.LOCAL_2D, LOCAL_1D, NON_CAUSAL_1D,
      LOCAL_BLOCK, GLOCAL, DILATED, GLOBAL.
    name: prefix of the per-layer variable scope.

  Returns:
    The layer-preprocessed output Tensor of the last layer.

  Raises:
    ValueError: if attention_type is not one of the supported values.
  """
  x = inputs
  x = tf.nn.dropout(x, 1.0 - hparams.layer_prepostprocess_dropout)
  if attention_type == AttentionType.DILATED:
    assert len(hparams.gap_sizes) == num_layers
  for layer in range(num_layers):
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      # self-attention + skip connections
      if attention_type == AttentionType.LOCAL_2D:
        y = local_attention_2d(
            common_layers.layer_preprocess(x, hparams),
            hparams,
            attention_type="masked_local_attention_2d")
      elif attention_type == AttentionType.LOCAL_1D:
        y = local_attention_1d(
            common_layers.layer_preprocess(x, hparams),
            hparams,
            attention_type="local_mask_right",
            q_padding="LEFT",
            kv_padding="LEFT")
      elif attention_type == AttentionType.NON_CAUSAL_1D:
        y = local_attention_1d(
            common_layers.layer_preprocess(x, hparams),
            hparams,
            attention_type="local_unmasked",
            q_padding="VALID",
            kv_padding="VALID")
      elif attention_type == AttentionType.LOCAL_BLOCK:
        y = local_within_block_attention(
            common_layers.layer_preprocess(x, hparams),
            self_attention_bias,
            hparams,
            attention_type="local_within_block_mask_right",
            q_padding="LEFT",
            kv_padding="LEFT")
      elif attention_type == AttentionType.GLOCAL:
        y = local_global_attention(
            common_layers.layer_preprocess(x, hparams),
            self_attention_bias,
            hparams,
            q_padding="LEFT",
            kv_padding="LEFT")
      elif attention_type == AttentionType.DILATED:
        y = dilated_attention_1d(
            common_layers.layer_preprocess(x, hparams),
            hparams,
            q_padding="LEFT",
            kv_padding="LEFT",
            gap_size=hparams.gap_sizes[layer])
      elif attention_type == AttentionType.GLOBAL:
        y = full_self_attention(
            common_layers.layer_preprocess(x, hparams),
            self_attention_bias,
            hparams,
            q_padding="LEFT",
            kv_padding="LEFT")
      else:
        # Previously an unknown type fell through the chain and crashed on
        # the next line with an unbound `y`; fail with a clear message.
        raise ValueError(
            "Unsupported attention_type for transformer_decoder_layers: %s" %
            (attention_type,))
      x = common_layers.layer_postprocess(x, y, hparams)
      # enc-dec attention + skip connections
      if encoder_output is not None:
        y = encdec_attention_1d(
            common_layers.layer_preprocess(x, hparams),
            encoder_output,
            encoder_decoder_attention_bias,
            hparams)
        x = common_layers.layer_postprocess(x, y, hparams)
      # feed-fwd layers + skip connections
      y = ffn_layer(common_layers.layer_preprocess(x, hparams), hparams)
      x = common_layers.layer_postprocess(x, y, hparams)
  return common_layers.layer_preprocess(x, hparams)
def nas_decoder(decoder_input,
                encoder_cell_outputs,
                decoder_self_attention_bias,
                encoder_decoder_attention_bias,
                hparams,
                final_layer_norm=True):
  """Builds the decoder of the configurable (NAS) model.

  Args:
    decoder_input: Input tensor.
    encoder_cell_outputs: List of tensors. The encoder cell outputs, listed in
      order.
    decoder_self_attention_bias: Attention bias that the decoder uses when
      attending to itself. This should have 0s for all valid positions and
      large negative numbers for all hidden future positions.
    encoder_decoder_attention_bias: Attention bias that the decoder uses when
      attending to the encoder. This should be 0s at all valid positions and
      large negative numbers for all padded positions.
    hparams: transformer.Transformer hparams that must also contain:
      + decoder_<left|right>_inputs: List of ints specifying the hidden layer
        input indexes for the <left|right> branches.
      + decoder_<left|right>_layers: String list of layers. Each string must
        be the name of a TranslationLayer registered in layers.py's
        DECODER_LAYERS.
      + decoder_<left|right>_activations: String list of activations. Each
        string in this list must have a corresponding activation in
        ACTIVATION_MAP.
      + decoder_<left|right>_output_dims: Int list of output dimensions for
        <left|right> branch layers.
      + decoder_<left|right>_norms: String list of norms to apply to the
        <left|right> layer branches. Each item must be either LAYER_NORM_KEY
        or NO_NORM_KEY.
      + decoder_num_cells: The number of cells in the decoder. This determines
        how many times the given layers will be repeated.
      + decoder_combiner_functions: String list of functions used to combine
        left and right branches. Must be a COMBINER_FUNCTION key.
      hparams may also optionally contain:
      + enforce_output_size: Boolean that determines whether or not the
        decoder output must be resized to hparams.hidden_size. If True, the
        output will be resized if it is not equal to hparams.hidden_size. If
        False, the output will not be resized. If this field is not set,
        behavior defaults to True.
    final_layer_norm: Whether or not to apply a final layer norm to the output
      of the decoder.

  Returns:
    Decoder output tensor.
  """
  hidden_size = hparams.hidden_size
  decoder_layers = layers.DECODER_LAYERS

  # Work out the depth the cell stack will produce so that a mismatch with
  # the encoding depth can be detected.
  _, output_depth, _, _ = calculate_branching_model_parameters(
      encoding_depth=hidden_size,
      left_inputs=hparams.decoder_left_inputs,
      left_layers=hparams.decoder_left_layers,
      left_output_dims=hparams.decoder_left_output_dims,
      right_inputs=hparams.decoder_right_inputs,
      right_layers=hparams.decoder_right_layers,
      right_output_dims=hparams.decoder_right_output_dims,
      combiner_functions=hparams.decoder_combiner_functions,
      final_combiner_function=hparams.decoder_final_combiner_function,
      layer_registry=decoder_layers,
      num_cells=hparams.decoder_num_cells,
      encoder_depth=hidden_size)
  depth_mismatch = output_depth != hparams.hidden_size

  # A missing enforce_output_size hparam means "enforce".
  must_enforce = getattr(hparams, "enforce_output_size", True)

  cells_output, _ = apply_nas_layers(
      input_tensor=decoder_input,
      left_inputs=hparams.decoder_left_inputs,
      left_layers=hparams.decoder_left_layers,
      left_activations=hparams.decoder_left_activations,
      left_output_dims=hparams.decoder_left_output_dims,
      left_norms=hparams.decoder_left_norms,
      right_inputs=hparams.decoder_right_inputs,
      right_layers=hparams.decoder_right_layers,
      right_activations=hparams.decoder_right_activations,
      right_output_dims=hparams.decoder_right_output_dims,
      right_norms=hparams.decoder_right_norms,
      num_cells=hparams.decoder_num_cells,
      combiner_functions=hparams.decoder_combiner_functions,
      final_combiner_function=hparams.decoder_final_combiner_function,
      nonpadding=None,
      layer_registry=decoder_layers,
      mask_future=True,
      hparams=hparams,
      var_scope="decoder",
      decoder_self_attention_bias=decoder_self_attention_bias,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      encoder_cell_outputs=encoder_cell_outputs,
      final_layer_norm=final_layer_norm)

  if must_enforce and depth_mismatch:
    # Project the cell output to hparams.hidden_size with the registered
    # 1x1 convolution layer.
    resize_layer = layers.DECODER_LAYERS.get(
        layers.STANDARD_CONV_1X1_REGISTRY_KEY)
    resized = resize_layer.apply_layer(
        cells_output,
        None,
        hparams.hidden_size,
        None,
        hparams,
        "decoder_resize_dense",
        mask_future=True,
        layer_preprocess_fn=None,
        postprocess_dropout=True,
        nonpadding=None,
        attention_dropout_broadcast_dims=None,
        encoder_decoder_attention_bias=None,
        encoder_cell_outputs=None,
        decoder_self_attention_bias=None,
    )
    if final_layer_norm:
      resized = common_layers.layer_preprocess(resized, hparams)
    return resized

  return cells_output
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        name="decoder",
                        nonpadding=None,
                        save_weights_to=None,
                        make_image_summary=True):
  """Builds a stack of transformer decoder layers.

  Each layer applies self-attention, then (only when `encoder_output` is not
  None) encoder-decoder attention, then a feed-forward sub-layer.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
      attentions, used for fast decoding.
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This is used
      to mask out padding in convolutional layers.  We generally only
      need this mask for "packed" datasets, because for ordinary datasets,
      no padding is ever followed by nonpadding.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.

  Returns:
    y: a Tensor
  """
  preprocess = common_layers.layer_preprocess
  postprocess = common_layers.layer_postprocess

  hidden = decoder_input
  broadcast_dims = common_layers.comma_separated_string_to_integer_list(
      getattr(hparams, "attention_dropout_broadcast_dims", ""))

  with tf.variable_scope(name):
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
    for layer_idx in xrange(num_layers):
      scope_name = "layer_%d" % layer_idx
      # Per-layer slice of the decoding cache; None when not caching.
      layer_cache = None if cache is None else cache[scope_name]
      with tf.variable_scope(scope_name):
        with tf.variable_scope("self_attention"):
          self_attended = common_attention.multihead_attention(
              query_antecedent=preprocess(hidden, hparams),
              memory_antecedent=None,
              bias=decoder_self_attention_bias,
              total_key_depth=(
                  hparams.attention_key_channels or hparams.hidden_size),
              total_value_depth=(
                  hparams.attention_value_channels or hparams.hidden_size),
              output_depth=hparams.hidden_size,
              num_heads=hparams.num_heads,
              dropout_rate=hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              save_weights_to=save_weights_to,
              max_relative_position=hparams.max_relative_position,
              cache=layer_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=broadcast_dims)
          hidden = postprocess(hidden, self_attended, hparams)

        if encoder_output is not None:
          with tf.variable_scope("encdec_attention"):
            # TODO(llion): Add caching.
            cross_attended = common_attention.multihead_attention(
                query_antecedent=preprocess(hidden, hparams),
                memory_antecedent=encoder_output,
                bias=encoder_decoder_attention_bias,
                total_key_depth=(
                    hparams.attention_key_channels or hparams.hidden_size),
                total_value_depth=(
                    hparams.attention_value_channels or hparams.hidden_size),
                output_depth=hparams.hidden_size,
                num_heads=hparams.num_heads,
                dropout_rate=hparams.attention_dropout,
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=broadcast_dims)
            hidden = postprocess(hidden, cross_attended, hparams)

        with tf.variable_scope("ffn"):
          ffn_out = transformer_ffn_layer(
              preprocess(hidden, hparams),
              hparams,
              conv_padding="LEFT",
              nonpadding_mask=nonpadding)
          hidden = postprocess(hidden, ffn_out, hparams)

    # When layer_preprocess normalizes, the stack output must be normalized
    # as well: it is the sum of a whole stack of unnormalized layer outputs
    # and can grow very large.
    return preprocess(hidden, hparams)
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder",
                        nonpadding=None,
                        save_weights_to=None,
                        make_image_summary=True,
                        losses=None):
  """A stack of transformer encoder layers.

  Each layer applies self-attention followed by a feed-forward sub-layer,
  both wrapped in layer_preprocess / layer_postprocess.  Model
  hyperparameters are also reported through mlperf_log.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
       (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This must either be
      passed in, which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias.  The knowledge about padding is used
      for pad_remover(efficiency) and to mask out padding in convolutional
      layers.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: optional list onto which to append extra training losses

  Returns:
    y: a Tensor
  """
  x = encoder_input
  # The hparam is optional; getattr falls back to "" when it is absent.
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS,
      value=hparams.num_encoder_layers or hparams.num_hidden_layers)
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_ATTENTION_DROPOUT,
      value=hparams.attention_dropout)
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_ATTENTION_DENSE,
      value={
          "use_bias": "false",
          "num_heads": hparams.num_heads,
          "hidden_size": hparams.hidden_size
      })

  with tf.variable_scope(name):
    # Derive whichever of padding / nonpadding was not supplied.
    if nonpadding is not None:
      padding = 1.0 - nonpadding
    else:
      padding = common_attention.attention_bias_to_padding(
          encoder_self_attention_bias)
      nonpadding = 1.0 - padding
    pad_remover = None
    # The pad remover is only built when enabled and not XLA-compiled.
    if hparams.use_pad_remover and not common_layers.is_xla_compiled():
      pad_remover = expert_utils.PadRemover(padding)
    for layer in range(hparams.num_encoder_layers or
                       hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              encoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              vars_3d=hparams.get("attention_variables_3d"))
          x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams),
              hparams,
              pad_remover,
              conv_padding="SAME",
              nonpadding_mask=nonpadding,
              losses=losses)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_NORM,
        value={"hidden_size": hparams.hidden_size})
    return common_layers.layer_preprocess(x, hparams)
def transformer_revnet_encoder(encoder_input,
                               encoder_self_attention_bias,
                               hparams,
                               name="encoder"):
  """A stack of reversible transformer encoder layers.

  The input is split in two halves along the last axis and run through
  rev_block.rev_block, with self-attention as `f` and the feed-forward layer
  as `g`.  Both operate on half of the hidden size, so they temporarily halve
  `hparams.hidden_size` while building their sub-graph.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
       (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string

  Returns:
    y: a Tensor
  """

  def f(x, side_input):
    """f(x) for reversible layer, self-attention layer."""
    encoder_self_attention_bias = side_input[0]

    old_hid_size = hparams.hidden_size
    hparams.hidden_size = old_hid_size // 2
    # hparams is shared with the caller: restore the original size even if
    # graph construction raises, so a failure cannot leave it halved.
    try:
      with tf.variable_scope("self_attention"):
        y = common_attention.multihead_attention(
            common_layers.layer_preprocess(x, hparams),
            None,
            encoder_self_attention_bias,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout)
        y = common_layers.layer_postprocess(x, y, hparams)
    finally:
      hparams.hidden_size = old_hid_size
    return y

  def g(x):
    """g(x) for reversible layer, feed-forward layer."""
    old_hid_size = hparams.hidden_size
    hparams.hidden_size = old_hid_size // 2
    # Same restore-on-error guarantee as in f().
    try:
      with tf.variable_scope("ffn"):
        y = transformer.transformer_ffn_layer(
            common_layers.layer_preprocess(x, hparams), hparams)
        y = common_layers.layer_postprocess(x, y, hparams)
    finally:
      hparams.hidden_size = old_hid_size
    return y

  x1, x2 = tf.split(encoder_input, 2, axis=-1)

  with tf.variable_scope(name):
    y1, y2 = rev_block.rev_block(
        x1,
        x2,
        f,
        g,
        num_layers=hparams.num_hidden_layers,
        f_side_input=[encoder_self_attention_bias],
        is_training=hparams.mode == tf.estimator.ModeKeys.TRAIN)
    y = tf.concat([y1, y2], axis=-1)

  return common_layers.layer_preprocess(y, hparams)
def attn_over_sent_and_lex_2d_dec(x, encoder_output,
                                  decoder_self_attention_bias, hparams):
  """One decoder layer with 2-D attention over the encoder output.

  Applies self-attention, then (only when `encoder_output` is not None) a
  2-D encoder-decoder attention, then a feed-forward sub-layer.  Every
  sub-layer is wrapped in layer_preprocess / layer_postprocess.

  Args:
    x: decoder input Tensor.  The code indexes axis 1 as the target length
      and expands a new axis 2, so it is presumably
      [batch, tgt_len, hidden] -- TODO confirm.
    encoder_output: Tensor whose static shape is indexed up to axis 3, i.e.
      at least rank 4; axis 1 is read as the source length, axis 2 as
      `lex_cap` and axis 3 as the hidden size.  May be None.
    decoder_self_attention_bias: bias Tensor for decoder self-attention.
    hparams: hyperparameters for model.

  Returns:
    The output Tensor of the feed-forward sub-layer (after postprocess).
  """
  with tf.variable_scope("self_attention"):
    query_antecedent = common_layers.layer_preprocess(x, hparams)
    y = common_attention.multihead_attention(
        query_antecedent=query_antecedent,
        memory_antecedent=None,
        bias=decoder_self_attention_bias,
        total_key_depth=hparams.attention_key_channels or hparams.hidden_size,
        total_value_depth=hparams.attention_value_channels or
        hparams.hidden_size,
        output_depth=hparams.hidden_size,
        num_heads=hparams.num_heads,
        dropout_rate=hparams.attention_dropout,
        attention_type=hparams.self_attention_type,
        max_relative_position=hparams.max_relative_position)
    x = common_layers.layer_postprocess(x, y, hparams)
  if encoder_output is not None:
    with tf.variable_scope("encdec_attention"):
      query_antecedent = common_layers.layer_preprocess(x, hparams)

      # Dynamic sizes for batch / lengths, static sizes for the last two axes.
      batch_size = tf.shape(encoder_output)[0]
      src_len = tf.shape(encoder_output)[1]
      tgt_len = tf.shape(query_antecedent)[1]
      lex_cap = encoder_output.shape.as_list()[2]
      hid_size = encoder_output.shape.as_list()[3]

      # Give the query a new axis 2 and zero-pad it to lex_cap entries, then
      # append src_len all-zero rows along axis 1.
      query_antecedent = tf.expand_dims(query_antecedent, 2)
      query_antecedent = tf.pad(
          query_antecedent, [[0, 0], [0, 0], [0, lex_cap - 1], [0, 0]])
      query_pad = tf.zeros([batch_size, src_len, lex_cap, hid_size])
      query_antecedent = tf.concat([query_antecedent, query_pad], 1)

      # Append tgt_len all-zero rows to the memory along axis 1, so query and
      # memory both have tgt_len + src_len rows.
      memory_antecedent = encoder_output
      memory_pad = tf.zeros([batch_size, tgt_len, lex_cap, hid_size])
      memory_antecedent = tf.concat([memory_antecedent, memory_pad], 1)

      tf.logging.info(
          "dimension of decoder input at the enc-dec attention layer: {0}"
          .format(query_antecedent.get_shape()))
      tf.logging.info(
          "dimension of encoder output at the enc-dec attention layer: {0}"
          .format(memory_antecedent.get_shape()))

      y = common_attention.multihead_attention_2d(
          query_antecedent=query_antecedent,
          memory_antecedent=memory_antecedent,
          total_key_depth=hparams.attention_key_channels or hparams.hidden_size,
          total_value_depth=hparams.attention_value_channels or
          hparams.hidden_size,
          output_depth=hparams.hidden_size,
          num_heads=hparams.num_heads,
          attention_type="masked_local_attention_2d",
          query_shape=(4, 4),
          memory_flange=(4, 4))

      tf.logging.info("dimension of enc-dec output: {0}".format(
          y.get_shape()))
      # Undo the padding: keep index 0 of axis 2 (the only non-padded query
      # slot) and the first tgt_len rows.
      y = y[:, :, 0, :]
      y = y[:, :tgt_len, :]

      x = common_layers.layer_postprocess(x, y, hparams)
  with tf.variable_scope("ffn"):
    x0 = common_layers.layer_preprocess(x, hparams)
    y = transformer.transformer_ffn_layer(x0, hparams)
    x = common_layers.layer_postprocess(x, y, hparams)
  return x
def transformer_dual_decoder(decoder_input,
                             wav_encoder_output,
                             txt_encoder_output,
                             decoder_self_attention_bias,
                             wav_enc_dec_attention_bias,
                             txt_enc_dec_attention_bias,
                             hparams,
                             cache=None,
                             name="dual_decoder",
                             nonpadding=None,
                             save_weights_to=None,
                             make_image_summary=True,
                             losses=None):
  """A stack of transformer decoder layers attending to two encoders.

  Each layer computes decoder self-attention, one encoder-decoder attention
  per available encoder ("wav" and "txt"), and a feed-forward sub-layer.
  When both encoder outputs are given, the feed-forward input is the
  concatenation of the two attention results along the last axis; otherwise
  it is the layer input `x`.

  NOTE(review): the self-attention result `y` is never merged into `x` (the
  residual line is commented out) and is overwritten by the feed-forward
  output, so it does not reach the layer output -- confirm this is intended.

  NOTE(review): when exactly one encoder output is given, its attention
  result (x1 or x2) is computed but the feed-forward sub-layer reads `x`, so
  that result is unused -- confirm this is intended.

  Args:
    decoder_input: a Tensor
    wav_encoder_output: a Tensor, or None to skip the "wav" attention
    txt_encoder_output: a Tensor, or None to skip the "txt" attention
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    wav_enc_dec_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    txt_enc_dec_attention_bias: same as wav_enc_dec_attention_bias, for the
      "txt" encoder
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
      attentions, used for fast decoding.
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This is used
      to mask out padding in convolutional layers.  We generally only
      need this mask for "packed" datasets, because for ordinary datasets,
      no padding is ever followed by nonpadding.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: optional list onto which to append extra training losses

  Returns:
    y: a Tensor
  """
  x = decoder_input
  # Per-encoder attention results; stay None when that encoder is absent.
  x1 = None
  x2 = None
  # The hparam is optional; getattr falls back to "" when it is absent.
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))
  with tf.variable_scope(name):
    for layer in range(hparams.num_decoder_layers or
                       hparams.num_hidden_layers):
      layer_name = "layer_%d" % layer
      # Per-layer slice of the decoding cache; None when not caching.
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        with tf.variable_scope("self_attention"):
          # decoder self-attention
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              save_weights_to=save_weights_to,
              max_relative_position=hparams.max_relative_position,
              cache=layer_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_target_seq_length"))
          # NOTE(review): residual connection disabled in the original code,
          # which leaves `y` above unused:
          # x = common_layers.layer_postprocess(x, y, hparams)
        if wav_encoder_output is not None:
          with tf.variable_scope("wav_encdec_attention"):
            y1 = common_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                wav_encoder_output,
                wav_enc_dec_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                # Falls back to num_heads when wav_num_heads is falsy.
                hparams.wav_num_heads or hparams.num_heads,
                hparams.attention_dropout,
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                # Original author's note suggested "max_wav_seq_length" * 80
                # as an alternative here.
                max_length=hparams.get("max_length"))
            x1 = common_layers.layer_postprocess(x, y1, hparams)
        if txt_encoder_output is not None:
          with tf.variable_scope("txt_encdec_attention"):
            y2 = common_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                txt_encoder_output,
                txt_enc_dec_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                # Falls back to num_heads when txt_num_heads is falsy.
                hparams.txt_num_heads or hparams.num_heads,
                hparams.attention_dropout,
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_txt_seq_length"))
            x2 = common_layers.layer_postprocess(x, y2, hparams)
        with tf.variable_scope("ffn"):
          if wav_encoder_output is not None and txt_encoder_output is not None:
            # Two encoders to attend to: feed both attention results,
            # concatenated on the last axis, to the feed-forward layer.
            y = transformer_ffn_layer(
                common_layers.layer_preprocess(tf.concat([x1,x2],axis=-1), hparams),
                hparams,
                conv_padding="LEFT",
                nonpadding_mask=nonpadding,
                losses=losses,
                cache=layer_cache)
          else:
            y = transformer_ffn_layer(
                common_layers.layer_preprocess(x, hparams),
                hparams,
                conv_padding="LEFT",
                nonpadding_mask=nonpadding,
                losses=losses,
                cache=layer_cache)
          # The residual is taken from the layer input x in both cases.
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder",
                        nonpadding=None,
                        save_weights_to=None,
                        make_image_summary=True):
    """Apply a stack of sparse self-attention encoder layers.

    Every layer runs `sparse_attention.multihead_attention` as
    self-attention and then `transformer_ffn_layer`; both sub-layers are
    wrapped in `layer_preprocess` / `layer_postprocess`.

    Args:
      encoder_input: a Tensor
      encoder_self_attention_bias: bias Tensor for self-attention
        (see common_attention.attention_bias())
      hparams: hyperparameters for model
      name: a string, used as the enclosing variable scope
      nonpadding: optional Tensor with shape [batch_size, encoder_length]
        indicating what positions are not padding. When it is omitted the
        padding is inferred from encoder_self_attention_bias. In this
        function the padding is only used to build the pad remover.
      save_weights_to: an optional dictionary to capture attention weights
        for visualization; the weights tensor will be appended there under
        a string key created from the variable scope (including name).
      make_image_summary: Whether to make an attention image summary.

    Returns:
      y: a Tensor, the last layer's output after a final layer_preprocess.
    """
    # NOTE(review): scope nesting was reconstructed from whitespace-collapsed
    # source following the usual tensor2tensor layout -- confirm.
    x = encoder_input
    dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "attention_dropout_broadcast_dims", "")))
    layer_count = hparams.num_encoder_layers or hparams.num_hidden_layers

    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS, value=layer_count)
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_ATTENTION_DROPOUT,
        value=hparams.attention_dropout)
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_ATTENTION_DENSE,
        value={
            "use_bias": "false",
            "num_heads": hparams.num_heads,
            "hidden_size": hparams.hidden_size
        })

    with tf.variable_scope(name):
        if nonpadding is None:
            padding = common_attention.attention_bias_to_padding(
                encoder_self_attention_bias)
            nonpadding = 1.0 - padding
        else:
            padding = 1.0 - nonpadding

        wants_pad_remover = (
            hparams.use_pad_remover and not common_layers.is_xla_compiled())
        pad_remover = (
            expert_utils.PadRemover(padding) if wants_pad_remover else None)

        key_depth = hparams.attention_key_channels or hparams.hidden_size
        value_depth = hparams.attention_value_channels or hparams.hidden_size

        for layer_idx in range(layer_count):
            # The initial sparsity is only forwarded when masks are loaded.
            initial_sparsity = (
                hparams.get("initial_sparsity")
                if hparams.get("load_masks_from") else None)

            with tf.variable_scope("layer_%d" % layer_idx):
                with tf.variable_scope("self_attention"):
                    attention_input = common_layers.layer_preprocess(
                        x, hparams)
                    y = sparse_attention.multihead_attention(
                        attention_input,
                        None,
                        encoder_self_attention_bias,
                        key_depth,
                        value_depth,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position,
                        heads_share_relative_embedding=(
                            hparams.heads_share_relative_embedding),
                        add_relative_to_values=hparams.add_relative_to_values,
                        save_weights_to=save_weights_to,
                        make_image_summary=make_image_summary,
                        dropout_broadcast_dims=dropout_broadcast_dims,
                        max_length=hparams.get("max_length"),
                        vars_3d=hparams.get("attention_variables_3d"),
                        sparsity_technique=hparams.get("sparsity_technique"),
                        threshold=hparams.get("log_alpha_threshold"),
                        training=(hparams.get("mode") ==
                                  tf_estimator.ModeKeys.TRAIN),
                        clip_alpha=hparams.get("clip_log_alpha"),
                        initial_sparsity=initial_sparsity,
                        split_heads=hparams.get("split_heads"))
                    x = common_layers.layer_postprocess(x, y, hparams)
                with tf.variable_scope("ffn"):
                    ffn_input = common_layers.layer_preprocess(x, hparams)
                    y = transformer_ffn_layer(ffn_input, hparams, pad_remover)
                    x = common_layers.layer_postprocess(x, y, hparams)

        # When layer_preprocess normalizes, the stack output must be
        # normalized as well: it is the sum of many unnormalized layer
        # outputs and can grow very large.
        mlperf_log.transformer_print(
            key=mlperf_log.MODEL_HP_NORM,
            value={"hidden_size": hparams.hidden_size})
        return common_layers.layer_preprocess(x, hparams)
def hierarchical_context_encoder(encoder_input,
                                 encoder_self_attention_bias,
                                 contexts,
                                 context_self_attention_biases,
                                 features,
                                 hparams,
                                 name="discourse_aware_encoder",
                                 save_weights_to=None,
                                 make_image_summary=True,
                                 losses=None):
    """Encode the input together with a set of named context sequences.

    Steps, in order:
      1. The input and every context go through
         `transformer_with_contexts_layers.transformer_encoder` using a copy
         of `hparams` whose `num_hidden_layers` is reduced by one.
      2. Each encoded context is passed through `ffn_self_attention_layer`
         and averaged over axis 1; the results are concatenated along
         axis 1 and run through a one-layer `transformer_encoder`.
      3. A final layer attends from the encoded input to the encoded
         contexts, applies input self-attention, mixes both with a sigmoid
         gate and finishes with a feed-forward sub-layer.

    Args:
      encoder_input: a Tensor, the input sequence to encode.
      encoder_self_attention_bias: bias Tensor for input self-attention.
      contexts: mapping from context name to context Tensor.
      context_self_attention_biases: mapping from context name to the
        self-attention bias Tensor of that context.
      features: passed to `features_to_nonpadding` for "inputs" and for
        every context name.
      hparams: hyperparameters for model.
      name: a string, used as the enclosing variable scope.
      save_weights_to: optional dictionary for attention weights; when it
        is truthy the gate tensor is also stored under "gated_sum".
      make_image_summary: Whether to make an attention image summary.
      losses: optional list forwarded to the final feed-forward layer.

    Returns:
      The gated encoder output after a final `layer_preprocess`.
    """
    # NOTE(review): scope nesting was reconstructed from whitespace-collapsed
    # source (in particular which statements sit inside which
    # tf.variable_scope) -- confirm against the original file.

    def _hparams_with_layers(num_layers):
        # Copy every hparam, then override num_hidden_layers.
        clone = tf.contrib.training.HParams()
        for key, val in hparams.values().items():
            clone.add_hparam(key, val)
        clone.set_hparam("num_hidden_layers", num_layers)
        return clone

    context_inputs = {ctx: contexts[ctx] for ctx in contexts}
    context_paddings = {}
    context_nonpaddings = {}
    dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "attention_dropout_broadcast_dims", "")))

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        input_padding = common_attention.attention_bias_to_padding(
            encoder_self_attention_bias)
        input_nonpadding = 1.0 - input_padding
        for ctx in context_self_attention_biases:
            context_paddings[ctx] = (
                common_attention.attention_bias_to_padding(
                    context_self_attention_biases[ctx]))
            context_nonpaddings[ctx] = 1.0 - context_paddings[ctx]

        input_pad_remover = None
        context_pad_removers = dict.fromkeys(context_paddings)
        if hparams.use_pad_remover and not common_layers.is_xla_compiled():
            input_pad_remover = expert_utils.PadRemover(input_padding)
            for ctx in context_paddings:
                context_pad_removers[ctx] = expert_utils.PadRemover(
                    context_paddings[ctx])

        # All layers but the last one are the plain encoder stack.
        lower_hparams = _hparams_with_layers(hparams.num_hidden_layers - 1)
        encoder_output = transformer_with_contexts_layers.transformer_encoder(
            encoder_input,
            encoder_self_attention_bias,
            lower_hparams,
            nonpadding=features_to_nonpadding(features, "inputs"),
            save_weights_to=save_weights_to,
            make_image_summary=make_image_summary)

        context_encoded_outputs = {
            ctx: transformer_with_contexts_layers.transformer_encoder(
                context_inputs[ctx],
                context_self_attention_biases[ctx],
                lower_hparams,
                nonpadding=features_to_nonpadding(features, ctx),
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary)
            for ctx in context_inputs
        }

        with tf.variable_scope("hierarchical_context_encoder",
                               reuse=tf.AUTO_REUSE):
            pooled_contexts = []
            for ctx in context_encoded_outputs:
                # self attention feed-forward
                attended = ffn_self_attention_layer(
                    context_encoded_outputs[ctx],
                    hparams.hidden_size,
                    hparams.hidden_size,
                    hparams.num_heads,
                    hparams.attention_dropout,
                    save_weights_to=save_weights_to,
                    name="attentive_sum")
                # mean over axis 1, kept as a length-1 axis
                pooled_contexts.append(
                    tf.reduce_mean(attended, axis=1, keep_dims=True))
            encoded_contexts = tf.concat(pooled_contexts, axis=1)

            single_layer_hparams = _hparams_with_layers(1)
            context_padding = common_attention.embedding_to_padding(
                encoded_contexts)
            ignore_padding = common_attention.attention_bias_ignore_padding(
                context_padding)
            encoded_contexts = transformer_encoder(
                encoded_contexts, ignore_padding, single_layer_hparams)

        key_depth = hparams.attention_key_channels or hparams.hidden_size
        value_depth = hparams.attention_value_channels or hparams.hidden_size
        attention_kwargs = dict(
            attention_type=hparams.self_attention_type,
            save_weights_to=save_weights_to,
            make_image_summary=make_image_summary,
            max_relative_position=hparams.max_relative_position,
            dropout_broadcast_dims=dropout_broadcast_dims,
            max_length=hparams.get("max_length"),
            vars_3d=hparams.get("attention_variables_3d"))

        with tf.variable_scope(
                "encoder/layer_%d" % hparams.num_hidden_layers,
                reuse=tf.AUTO_REUSE):
            with tf.variable_scope("context_input_attention"):
                context_padding = common_attention.embedding_to_padding(
                    encoded_contexts)
                ignore_padding = (
                    common_attention.attention_bias_ignore_padding(
                        context_padding))
                attended = common_attention.multihead_attention(
                    common_layers.layer_preprocess(encoder_output, hparams),
                    encoded_contexts,
                    ignore_padding,
                    key_depth,
                    value_depth,
                    hparams.hidden_size,
                    hparams.num_heads,
                    hparams.attention_dropout,
                    **attention_kwargs)
                encoded_contexts = common_layers.layer_postprocess(
                    encoder_output, attended, hparams)

            with tf.variable_scope("input_self_attention"):
                attended = common_attention.multihead_attention(
                    common_layers.layer_preprocess(encoder_output, hparams),
                    None,
                    encoder_self_attention_bias,
                    key_depth,
                    value_depth,
                    hparams.hidden_size,
                    hparams.num_heads,
                    hparams.attention_dropout,
                    **attention_kwargs)
                encoder_output = common_layers.layer_postprocess(
                    encoder_output, attended, hparams)

            with tf.variable_scope("gated_sum"):
                depth = common_layers.shape_list(encoder_output)[-1]
                gate_input = tf.concat(
                    [encoded_contexts, encoder_output], axis=-1)
                gate = tf.layers.dense(
                    gate_input, depth, activation=tf.nn.sigmoid)
                if save_weights_to:
                    save_weights_to["gated_sum"] = gate
                encoder_output = (
                    gate * encoder_output + (1. - gate) * encoded_contexts)

            with tf.variable_scope("ffn"):
                ffn_out = transformer_ffn_layer(
                    common_layers.layer_preprocess(encoder_output, hparams),
                    hparams,
                    input_pad_remover,
                    conv_padding="SAME",
                    nonpadding_mask=input_nonpadding,
                    losses=losses)
                encoder_output = common_layers.layer_postprocess(
                    encoder_output, ffn_out, hparams)

        return common_layers.layer_preprocess(encoder_output, hparams)
def decoder(
    decoder_input,
    encoder_output,
    decoder_self_attention_bias,
    encoder_decoder_attention_bias,
    hparams,
    name="decoder",
    save_weights_to=None,
    make_image_summary=True,
):
    """Run a stack of transformer decoder layers and record tensor norms.

    Each layer applies self-attention, encoder-decoder attention (only when
    `encoder_output` is given) and a dense-relu-dense block. The last-axis
    norm of every sub-layer output, before and after `layer_postprocess`,
    is handed to `utils.collect_named_outputs` under "norms".

    Args:
      decoder_input: a Tensor
      encoder_output: a Tensor, or None to skip encoder-decoder attention
      decoder_self_attention_bias: bias Tensor for self-attention
        (see common_attention.attention_bias())
      encoder_decoder_attention_bias: bias Tensor for encoder-decoder
        attention (see common_attention.attention_bias())
      hparams: hyperparameters for model
      name: a string
      save_weights_to: an optional dictionary to capture attention weights
        for visualization; the weights tensor will be appended there under
        a string key created from the variable scope (including name).
      make_image_summary: Whether to make an attention image summary.

    Returns:
      y: a Tensor
    """
    # NOTE(review): scope nesting was reconstructed from whitespace-collapsed
    # source following the usual tensor2tensor layout -- confirm.

    def _record_norm(tag, tensor):
        # Forward the last-axis norm of `tensor` under the given tag.
        utils.collect_named_outputs("norms", tag, tf.norm(tensor, axis=-1))

    layer_count = hparams.num_decoder_layers or hparams.num_hidden_layers
    x = decoder_input
    with tf.variable_scope(name):
        for layer in range(layer_count):
            with tf.variable_scope("layer_%d" % layer):
                with tf.variable_scope("self_attention"):
                    y = common_attention.multihead_attention(
                        common_layers.layer_preprocess(x, hparams),
                        None,
                        decoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        (hparams.attention_value_channels or
                         hparams.hidden_size),
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        save_weights_to=save_weights_to,
                        make_image_summary=make_image_summary,
                    )
                    _record_norm("decoder_self_attention_%d" % (layer), y)
                    x = common_layers.layer_postprocess(x, y, hparams)
                    _record_norm(
                        "decoder_self_attention_post_%d" % (layer), x)

                if encoder_output is not None:
                    with tf.variable_scope("encdec_attention"):
                        y = common_attention.multihead_attention(
                            common_layers.layer_preprocess(x, hparams),
                            encoder_output,
                            encoder_decoder_attention_bias,
                            (hparams.attention_key_channels or
                             hparams.hidden_size),
                            (hparams.attention_value_channels or
                             hparams.hidden_size),
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,
                            save_weights_to=save_weights_to,
                            make_image_summary=make_image_summary,
                        )
                        _record_norm(
                            "decoder_encoder_attention_%d" % (layer), y)
                        x = common_layers.layer_postprocess(x, y, hparams)
                        _record_norm(
                            "decoder_encoder_attention_post_%d" % (layer), x)

                with tf.variable_scope("ffn"):
                    y = common_layers.dense_relu_dense(
                        common_layers.layer_preprocess(x, hparams),
                        hparams.filter_size,
                        hparams.hidden_size,
                        dropout=hparams.relu_dropout,
                    )
                    _record_norm("decoder_ffn_%d" % (layer), y)
                    x = common_layers.layer_postprocess(x, y, hparams)
                    _record_norm("decoder_ffn_post_%d" % (layer), x)

        # When layer_preprocess normalizes, the stack output must be
        # normalized as well: it is the sum of many unnormalized layer
        # outputs and can grow very large.
        return common_layers.layer_preprocess(x, hparams)
def hierarchical_attention_network_encoder(
        encoder_input,
        encoder_self_attention_bias,
        contexts,
        context_self_attention_biases,
        features,
        hparams,
        name="hierarchical_attention_network_encoder",
        save_weights_to=None,
        make_image_summary=True,
        losses=None):
    """Encode the input and merge in context via two attention levels.

    Steps, in order:
      1. The input is encoded with a copy of `hparams` whose
         `num_hidden_layers` is reduced by one; every context is encoded
         with the unmodified `hparams`.
      2. "word_abstraction": a dense projection of the encoded input
         attends to each encoded context; the per-context results are
         concatenated along axis 1.
      3. "sentence_abstraction": a second dense projection of the encoded
         input attends to that concatenation, followed by
         `dense_relu_dense`.
      4. "context_gating": a sigmoid gate mixes the encoded input with the
         contextual information.

    Args:
      encoder_input: a Tensor, the input sequence to encode.
      encoder_self_attention_bias: bias Tensor for input self-attention.
      contexts: mapping from context name to context Tensor.
      context_self_attention_biases: mapping from context name to the
        attention bias Tensor of that context.
      features: passed to `features_to_nonpadding` for "inputs" and for
        every context name.
      hparams: hyperparameters for model.
      name: a string, used as the enclosing variable scope.
      save_weights_to: optional dictionary to capture attention weights.
      make_image_summary: Whether to make an attention image summary.
      losses: accepted for signature compatibility; not read here.

    Returns:
      The gated encoder output after a final `layer_preprocess`.
    """
    # NOTE(review): scope nesting was reconstructed from whitespace-collapsed
    # source (in particular which statements sit inside which
    # tf.variable_scope) -- confirm against the original file.
    context_inputs = {ctx: contexts[ctx] for ctx in contexts}
    context_paddings = {}
    context_nonpaddings = {}
    dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "attention_dropout_broadcast_dims", "")))

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        input_padding = common_attention.attention_bias_to_padding(
            encoder_self_attention_bias)
        input_nonpadding = 1.0 - input_padding
        for ctx in context_self_attention_biases:
            context_paddings[ctx] = (
                common_attention.attention_bias_to_padding(
                    context_self_attention_biases[ctx]))
            context_nonpaddings[ctx] = 1.0 - context_paddings[ctx]

        input_pad_remover = None
        context_pad_removers = dict.fromkeys(context_paddings)
        if hparams.use_pad_remover and not common_layers.is_xla_compiled():
            input_pad_remover = expert_utils.PadRemover(input_padding)
            for ctx in context_paddings:
                context_pad_removers[ctx] = expert_utils.PadRemover(
                    context_paddings[ctx])

        # Copy hparams, except num_hidden_layers -> num_hidden_layers - 1.
        lower_hparams = tf.contrib.training.HParams()
        for key, val in hparams.values().items():
            lower_hparams.add_hparam(key, val)
        lower_hparams.set_hparam("num_hidden_layers",
                                 hparams.num_hidden_layers - 1)

        encoder_output = transformer_with_contexts_layers.transformer_encoder(
            encoder_input,
            encoder_self_attention_bias,
            lower_hparams,
            nonpadding=features_to_nonpadding(features, "inputs"),
            save_weights_to=save_weights_to,
            make_image_summary=make_image_summary)

        # NOTE(review): contexts are encoded with the full `hparams`, not
        # the reduced copy used for the input -- confirm this is intended.
        context_encoded_outputs = {
            ctx: transformer_with_contexts_layers.transformer_encoder(
                context_inputs[ctx],
                context_self_attention_biases[ctx],
                hparams,
                nonpadding=features_to_nonpadding(features, ctx),
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary)
            for ctx in context_inputs
        }

        key_depth = hparams.attention_key_channels or hparams.hidden_size
        value_depth = hparams.attention_value_channels or hparams.hidden_size
        attention_kwargs = dict(
            attention_type=hparams.self_attention_type,
            save_weights_to=save_weights_to,
            make_image_summary=make_image_summary,
            max_relative_position=hparams.max_relative_position,
            dropout_broadcast_dims=dropout_broadcast_dims,
            max_length=hparams.get("max_length"),
            vars_3d=hparams.get("attention_variables_3d"))

        with tf.variable_scope('word_abstraction', reuse=tf.AUTO_REUSE):
            # q_w = f_w(h_t)
            word_level_query = common_layers.dense(
                encoder_output, hparams.hidden_size)
            word_level_abstraction = {}
            for ctx in context_encoded_outputs:
                # s^j: the input attends to the j-th encoded context.
                word_level_abstraction[ctx] = (
                    transformer_with_contexts_layers.multihead_attention(
                        common_layers.layer_preprocess(
                            word_level_query, hparams),
                        context_encoded_outputs[ctx],
                        context_self_attention_biases[ctx],
                        key_depth,
                        value_depth,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        **attention_kwargs))
            sentence_information = tf.concat(
                [word_level_abstraction[ctx]
                 for ctx in word_level_abstraction],
                axis=1)

        with tf.variable_scope('sentence_abstraction', reuse=tf.AUTO_REUSE):
            # q_s = f_s(h_t)
            sentence_level_query = common_layers.dense(
                encoder_output, hparams.hidden_size)
            context_padding = common_attention.embedding_to_padding(
                sentence_information)
            ignore_padding = common_attention.attention_bias_ignore_padding(
                context_padding)
            # MultiHead(q_s, s^j)
            contextual_information = (
                transformer_with_contexts_layers.multihead_attention(
                    common_layers.layer_preprocess(
                        sentence_level_query, hparams),
                    sentence_information,
                    ignore_padding,
                    key_depth,
                    value_depth,
                    hparams.hidden_size,
                    hparams.num_heads,
                    hparams.attention_dropout,
                    **attention_kwargs))
            contextual_information = common_layers.dense_relu_dense(
                contextual_information, hparams.filter_size,
                hparams.hidden_size)

        with tf.variable_scope('context_gating', reuse=tf.AUTO_REUSE):
            context_logits = common_layers.dense(
                contextual_information, hparams.hidden_size)
            input_logits = common_layers.dense(
                encoder_output, hparams.hidden_size)
            gate_lambda = tf.nn.sigmoid(context_logits + input_logits)
            encoder_output = (
                gate_lambda * encoder_output +
                (1 - gate_lambda) * contextual_information)

        return common_layers.layer_preprocess(encoder_output, hparams)
def encode_lex(self, encoder_input, target_space, hparams):
    """Encode lexicon-expanded input along two axes and aggregate it.

    First pass: the input is sliced along axis 2 and every slice goes
    through its own stack of self-attention + feed-forward layers (scopes
    "encoder<i>"); the outputs are stacked back on axis 2.
    Second pass: that result is sliced along axis 3 and every slice goes
    through one extra self-attention + feed-forward layer (scopes
    "encoder_extra<i>") built from a reduced hparams set whose hidden size
    is `hparams.lex_cap`; the outputs are stacked back on axis 3.
    Finally axes 2 and 3 are contracted with the learned "Aggregate"
    variable through `tf.tensordot`.

    Args:
      encoder_input: 4-D Tensor indexed as [:, :, i, :]; the size of axis 2
        must be statically known. Presumably
        [batch_size, input_len, lex_cap, hidden_dim] -- TODO confirm.
      target_space: passed to `common_layers.embedding` to build the
        target-space embedding added to every slice.
      hparams: hyperparameters for model. An assertion requires axis 2 of
        the second-pass output to equal `hparams.lex_cap`.

    Returns:
      A tuple `(encoder_output, encoder_decoder_attention_bias)`. The bias
      is the padding bias computed for the last slice of the first pass.
    """
    # NOTE(review): scope nesting was reconstructed from whitespace-collapsed
    # source -- confirm against the original file. The print() calls are
    # the original shape-debugging output and are kept unchanged.

    def _make_pad_remover(bias):
        # Build a PadRemover from an attention bias when enabled.
        if not hparams.use_pad_remover:
            return None
        return expert_utils.PadRemover(
            common_attention.attention_bias_to_padding(bias))

    first_pass_outputs = []
    for i in range(encoder_input.get_shape()[2].value):
        input_slice = encoder_input[:, :, i, :]

        # Padding bias for this slice.
        slice_padding = common_attention.embedding_to_padding(input_slice)
        print(slice_padding.shape.as_list())
        ignore_padding = common_attention.attention_bias_ignore_padding(
            slice_padding)
        encoder_self_attention_bias = ignore_padding
        encoder_decoder_attention_bias = ignore_padding
        print(ignore_padding.shape.as_list())

        # Add the target-space embedding to the slice.
        ishape_static = input_slice.shape.as_list()
        print(ishape_static)
        emb_target_space = common_layers.embedding(
            target_space, 32, ishape_static[-1],
            name="target_space_embedding")
        print(emb_target_space.shape.as_list())
        emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
        print(emb_target_space.shape.as_list())
        input_slice = input_slice + emb_target_space
        print(input_slice.shape.as_list())

        # Optional timing signal, then dropout.
        if hparams.pos == "timing":
            input_slice = common_attention.add_timing_signal_1d(input_slice)
        input_slice = tf.nn.dropout(
            input_slice, 1.0 - hparams.layer_prepostprocess_dropout)

        # Self-attention along the sentence dimension: query is x, there is
        # no memory antecedent, and the bias is the slice's padding bias.
        x = input_slice
        with tf.variable_scope("encoder" + str(i)):
            pad_remover = _make_pad_remover(encoder_self_attention_bias)
            for layer in xrange(hparams.num_encoder_layers or
                                hparams.num_hidden_layers):
                with tf.variable_scope("layer_%d" % layer):
                    with tf.variable_scope("self_attention"):
                        query_antecedent = common_layers.layer_preprocess(
                            x, hparams)
                        y = common_attention.multihead_attention(
                            query_antecedent=query_antecedent,
                            memory_antecedent=None,
                            bias=encoder_self_attention_bias,
                            total_key_depth=(
                                hparams.attention_key_channels or
                                hparams.hidden_size),
                            total_value_depth=(
                                hparams.attention_value_channels or
                                hparams.hidden_size),
                            output_depth=hparams.hidden_size,
                            num_heads=hparams.num_heads,
                            dropout_rate=hparams.attention_dropout,
                            attention_type=hparams.self_attention_type,
                            max_relative_position=(
                                hparams.max_relative_position))
                        x = common_layers.layer_postprocess(x, y, hparams)
                    with tf.variable_scope("ffn"):
                        y = transformer.transformer_ffn_layer(
                            common_layers.layer_preprocess(x, hparams),
                            hparams, pad_remover)
                        x = common_layers.layer_postprocess(x, y, hparams)
            output_slice = common_layers.layer_preprocess(x, hparams)
            print(output_slice.shape.as_list())
            first_pass_outputs.append(output_slice)

    encoder_output = tf.stack(first_pass_outputs, 2)
    print(encoder_output.shape.as_list())

    # -------- second pass: attention along the lexicon dimension --------
    second_pass_outputs = []
    num_heads = int(hparams.lex_cap / 2)
    # Reduced hparams: same pre/post-processing and FFN settings, but the
    # hidden size is lex_cap and the head count is derived from it.
    hparams2 = tf.contrib.training.HParams(
        layer_preprocess_sequence=hparams.layer_preprocess_sequence,
        layer_postprocess_sequence=hparams.layer_postprocess_sequence,
        layer_prepostprocess_dropout=hparams.layer_prepostprocess_dropout,
        norm_type=hparams.norm_type,
        hidden_size=hparams.lex_cap,
        norm_epsilon=hparams.norm_epsilon,
        ffn_layer=hparams.ffn_layer,
        filter_size=hparams.filter_size,
        relu_dropout=hparams.relu_dropout,
        num_heads=num_heads,
        attention_dropout=hparams.attention_dropout,
        parameter_attention_key_channels=(
            hparams.parameter_attention_key_channels),
        parameter_attention_value_channels=(
            hparams.parameter_attention_value_channels))

    for i in range(encoder_output.get_shape()[3].value):
        input_slice = encoder_output[:, :, :, i]
        slice_padding = common_attention.embedding_to_padding(input_slice)
        ignore_padding = common_attention.attention_bias_ignore_padding(
            slice_padding)
        encoder_self_attention_bias = ignore_padding

        x = input_slice
        with tf.variable_scope("encoder_extra" + str(i)):
            pad_remover = _make_pad_remover(encoder_self_attention_bias)
            with tf.variable_scope("layer_extra"):
                with tf.variable_scope("self_attention"):
                    query_antecedent = common_layers.layer_preprocess(
                        x, hparams2)
                    y = common_attention.multihead_attention(
                        query_antecedent=query_antecedent,
                        memory_antecedent=None,
                        bias=encoder_self_attention_bias,
                        total_key_depth=(
                            hparams.attention_key_channels or
                            hparams.lex_cap),
                        total_value_depth=(
                            hparams.attention_value_channels or
                            hparams.lex_cap),
                        output_depth=hparams.lex_cap,
                        num_heads=num_heads,
                        dropout_rate=hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position)
                    x = common_layers.layer_postprocess(x, y, hparams2)
                with tf.variable_scope("ffn"):
                    y = transformer.transformer_ffn_layer(
                        common_layers.layer_preprocess(x, hparams2),
                        hparams2, pad_remover)
                    x = common_layers.layer_postprocess(x, y, hparams2)
            output_slice = common_layers.layer_preprocess(x, hparams2)
            second_pass_outputs.append(output_slice)

    encoder_output = tf.stack(second_pass_outputs, 3)
    print(encoder_output.shape.as_list())

    # -------- aggregation over axes 2 and 3 --------
    lex_cap = encoder_output.get_shape()[2].value
    embed_len = encoder_output.get_shape()[3].value
    assert (lex_cap == hparams.lex_cap)
    aggregate_layer = tf.get_variable(
        name="Aggregate",
        shape=[embed_len, embed_len, lex_cap],
        initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1))
    encoder_output = tf.tensordot(
        encoder_output, aggregate_layer, axes=[[2, 3], [1, 2]])
    print(encoder_output.shape.as_list())
    return encoder_output, encoder_decoder_attention_bias
def decoder(decoder_input,
            encoder_output,
            decoder_self_attention_bias,
            encoder_decoder_attention_bias,
            hparams,
            name="decoder",
            save_weights_to=None,
            make_image_summary=True,):
    """Apply a stack of transformer decoder layers, logging tensor norms.

    A layer consists of self-attention, encoder-decoder attention (skipped
    when `encoder_output` is None) and a dense-relu-dense block. For each
    sub-layer the last-axis norm of its output, before and after
    `layer_postprocess`, is passed to `utils.collect_named_outputs` under
    "norms".

    Args:
      decoder_input: a Tensor
      encoder_output: a Tensor, or None
      decoder_self_attention_bias: bias Tensor for self-attention
        (see common_attention.attention_bias())
      encoder_decoder_attention_bias: bias Tensor for encoder-decoder
        attention (see common_attention.attention_bias())
      hparams: hyperparameters for model
      name: a string
      save_weights_to: an optional dictionary to capture attention weights
        for visualization; the weights tensor will be appended there under
        a string key created from the variable scope (including name).
      make_image_summary: Whether to make an attention image summary.

    Returns:
      y: a Tensor
    """
    # NOTE(review): scope nesting was reconstructed from whitespace-collapsed
    # source following the usual tensor2tensor layout -- confirm.
    x = decoder_input
    with tf.variable_scope(name):
        for layer in range(hparams.num_decoder_layers or
                           hparams.num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer):

                with tf.variable_scope("self_attention"):
                    normed_x = common_layers.layer_preprocess(x, hparams)
                    y = common_attention.multihead_attention(
                        normed_x,
                        None,
                        decoder_self_attention_bias,
                        hparams.attention_key_channels or hparams.hidden_size,
                        (hparams.attention_value_channels or
                         hparams.hidden_size),
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        save_weights_to=save_weights_to,
                        make_image_summary=make_image_summary,
                    )
                    utils.collect_named_outputs(
                        "norms",
                        "decoder_self_attention_%d" % (layer),
                        tf.norm(y, axis=-1))
                    x = common_layers.layer_postprocess(x, y, hparams)
                    utils.collect_named_outputs(
                        "norms",
                        "decoder_self_attention_post_%d" % (layer),
                        tf.norm(x, axis=-1))

                if encoder_output is not None:
                    with tf.variable_scope("encdec_attention"):
                        normed_x = common_layers.layer_preprocess(x, hparams)
                        y = common_attention.multihead_attention(
                            normed_x,
                            encoder_output,
                            encoder_decoder_attention_bias,
                            (hparams.attention_key_channels or
                             hparams.hidden_size),
                            (hparams.attention_value_channels or
                             hparams.hidden_size),
                            hparams.hidden_size,
                            hparams.num_heads,
                            hparams.attention_dropout,
                            save_weights_to=save_weights_to,
                            make_image_summary=make_image_summary,
                        )
                        utils.collect_named_outputs(
                            "norms",
                            "decoder_encoder_attention_%d" % (layer),
                            tf.norm(y, axis=-1))
                        x = common_layers.layer_postprocess(x, y, hparams)
                        utils.collect_named_outputs(
                            "norms",
                            "decoder_encoder_attention_post_%d" % (layer),
                            tf.norm(x, axis=-1))

                with tf.variable_scope("ffn"):
                    normed_x = common_layers.layer_preprocess(x, hparams)
                    y = common_layers.dense_relu_dense(
                        normed_x,
                        hparams.filter_size,
                        hparams.hidden_size,
                        dropout=hparams.relu_dropout,
                    )
                    utils.collect_named_outputs(
                        "norms",
                        "decoder_ffn_%d" % (layer),
                        tf.norm(y, axis=-1))
                    x = common_layers.layer_postprocess(x, y, hparams)
                    utils.collect_named_outputs(
                        "norms",
                        "decoder_ffn_post_%d" % (layer),
                        tf.norm(x, axis=-1))

        # The stack output is a sum of unnormalized layer outputs and can
        # grow very large, so it gets the same normalization that
        # layer_preprocess applies inside the layers.
        return common_layers.layer_preprocess(x, hparams)
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        decode_loop_step=None,
                        name="decoder",
                        nonpadding=None,
                        save_weights_to=None,
                        make_image_summary=True,
                        losses=None,
                        layer_collection=None,
                        recurrent_memory_by_layer=None,
                        chunk_number=None):
    """Apply a stack of decoder layers with a selectable attention mode.

    `hparams.sparse_attention_mode` selects how self-attention is set up
    before the layers are built:
      * "sparse": `hparams.self_attention_type` is wrapped so that it
        receives a `SparseTopology` for a [seqlen, seqlen] matrix.
      * "masked": the self-attention bias is replaced by the mask from
        `generate_sparse_attention_mask`; the same mask is used by every
        layer.
      * "dense": `hparams.self_attention_type` is wrapped so that it
        receives the self-attention bias.
      * None: nothing is changed. Any other value fails an assertion.

    Note that the "sparse" and "dense" modes assign to
    `hparams.self_attention_type`, i.e. they modify the caller's hparams
    object.

    Args:
      decoder_input: a Tensor
      encoder_output: a Tensor
      decoder_self_attention_bias: bias Tensor for self-attention
        (see common_attention.attention_bias())
      encoder_decoder_attention_bias: bias Tensor for encoder-decoder
        attention (see common_attention.attention_bias())
      hparams: hyperparameters for model
      cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
      decode_loop_step: An integer, step number of the decoding loop.
        Only used for inference on TPU.
      name: a string
      nonpadding: optional Tensor with shape [batch_size, encoder_length]
        indicating what positions are not padding. This is used to mask
        out padding in convolutional layers. We generally only need this
        mask for "packed" datasets, because for ordinary datasets, no
        padding is ever followed by nonpadding.
      save_weights_to: an optional dictionary to capture attention weights
        for visualization; the weights tensor will be appended there under
        a string key created from the variable scope (including name).
      make_image_summary: Whether to make an attention image summary.
      losses: optional list onto which to append extra training losses
      layer_collection: A tensorflow_kfac.LayerCollection. Only used by the
        KFAC optimizer. Default is None.
      recurrent_memory_by_layer: Optional dict, mapping layer names to
        instances of transformer_memory.RecurrentMemory. Default is None.
      chunk_number: an optional integer Tensor with shape [batch] used to
        operate the recurrent_memory.

    Returns:
      y: a Tensor
    """
    # NOTE(review): scope nesting was reconstructed from whitespace-collapsed
    # source following the usual tensor2tensor layout -- confirm.
    x = decoder_input
    mode = hparams.sparse_attention_mode

    if mode == "sparse":
        # Run with the actual sparse kernels: intercept the
        # self_attention_type and bind the sparse topology to it.
        seqlen = common_layers.shape_list(x)[1]
        topology = sparse_matrix.SparseTopology(
            "sparse_attention", [seqlen, seqlen],
            connector=connectors.Uniform(0.955411645))  # 0.955411659
        hparams.self_attention_type = functools.partial(
            hparams.self_attention_type, topology=topology)
    elif mode == "masked":
        # Training with sparse attention but dense kernels: build the
        # attention bias that describes the sparsity pattern.
        #
        # NOTE: The pattern is shared across all attention heads within a
        # layer due to memory constraints (sparse kernels are not actually
        # used here). Per-head patterns would likely perform better.
        #
        # NOTE: The pattern is also shared across all layers, as protobuf
        # can't save all of these large tensors if more than one is
        # created.
        decoder_self_attention_bias = generate_sparse_attention_mask(
            common_layers.shape_list(x)[1], hparams, 0)
        tf.logging.info("Generated sparse attention mask.")
    elif mode == "dense":
        # Replace the dot-product attention with the memory efficient
        # version by binding the bias to it.
        hparams.self_attention_type = functools.partial(
            hparams.self_attention_type, bias=decoder_self_attention_bias)
    else:
        # For training on TPU, use T2T's standard attention.
        assert hparams.sparse_attention_mode is None

    layer_count = hparams.num_decoder_layers or hparams.num_hidden_layers
    with tf.variable_scope(name):
        for layer_idx in range(layer_count):
            x = transformer.transformer_decoder_layer(
                x,
                decoder_self_attention_bias,
                layer_idx,
                hparams,
                encoder_decoder_attention_bias=encoder_decoder_attention_bias,
                encoder_output=encoder_output,
                cache=cache,
                decode_loop_step=decode_loop_step,
                nonpadding=nonpadding,
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary,
                losses=losses,
                layer_collection=layer_collection,
                recurrent_memory_by_layer=recurrent_memory_by_layer,
                chunk_number=chunk_number)

        # When layer_preprocess normalizes, the stack output must be
        # normalized as well: it is the sum of many unnormalized layer
        # outputs and can grow very large.
        return common_layers.layer_preprocess(
            x, hparams, layer_collection=layer_collection)
def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        hparams,
                        name="encoder",
                        nonpadding=None,
                        save_weights_to=None,
                        make_image_summary=True):
  """A stack of transformer layers.

  Each layer applies self-attention followed by a feed-forward sublayer, each
  wrapped in layer_preprocess / layer_postprocess. The stack output is passed
  through layer_preprocess once more before being returned.

  Args:
    encoder_input: a Tensor
    encoder_self_attention_bias: bias Tensor for self-attention
       (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This must either be
      passed in, which we do for "packed" datasets, or inferred from
      encoder_self_attention_bias.  The knowledge about padding is used
      for pad_remover(efficiency) and to mask out padding in convolutional
      layers.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.

  Returns:
    y: a Tensors
  """
  x = encoder_input
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))
  with tf.variable_scope(name):
    # Derive whichever of padding / nonpadding the caller did not supply.
    if nonpadding is not None:
      padding = 1.0 - nonpadding
    else:
      padding = common_attention.attention_bias_to_padding(
          encoder_self_attention_bias)
      nonpadding = 1.0 - padding
    # The pad remover is only built when enabled and not running on TPU.
    pad_remover = None
    if hparams.use_pad_remover and not common_layers.is_on_tpu():
      pad_remover = expert_utils.PadRemover(padding)
    # `range` (not the Python-2-only `xrange`) keeps this loop working on
    # Python 3 without a compatibility shim; the layer count is small.
    for layer in range(hparams.num_encoder_layers or
                       hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % layer):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              encoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              save_weights_to=save_weights_to,
              max_relative_position=hparams.max_relative_position,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims)
          x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams),
              hparams,
              pad_remover,
              conv_padding="SAME",
              nonpadding_mask=nonpadding)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        name="decoder"):
  """A stack of transformer layers.

  Each layer applies self-attention, then (only when `encoder_output` is not
  None) encoder-decoder attention, then a feed-forward sublayer. Every
  sublayer is wrapped in layer_preprocess / layer_postprocess, and the stack
  output is passed through layer_preprocess once more before being returned.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding. Indexed by "layer_%d"; only the
        self-attention sublayer receives the per-layer entry.
    name: a string

  Returns:
    y: a Tensors
  """
  x = decoder_input
  with tf.variable_scope(name):
    # `range` (not the Python-2-only `xrange`) keeps this loop working on
    # Python 3 without a compatibility shim; the layer count is small.
    for layer in range(hparams.num_decoder_layers or
                       hparams.num_hidden_layers):
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        with tf.variable_scope("self_attention"):
          y = common_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              cache=layer_cache)
          x = common_layers.layer_postprocess(x, y, hparams)
        if encoder_output is not None:
          with tf.variable_scope("encdec_attention"):
            # TODO(llion): Add caching.
            y = common_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout)
            x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams), hparams)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    return common_layers.layer_preprocess(x, hparams)
def apply_nas_layers(input_tensor,
                     left_inputs,
                     left_layers,
                     left_activations,
                     left_output_dims,
                     left_norms,
                     right_inputs,
                     right_layers,
                     right_activations,
                     right_output_dims,
                     right_norms,
                     combiner_functions,
                     final_combiner_function,
                     num_cells,
                     nonpadding,
                     layer_registry,
                     mask_future,
                     hparams,
                     var_scope,
                     encoder_decoder_attention_bias=None,
                     encoder_cell_outputs=None,
                     decoder_self_attention_bias=None,
                     final_layer_norm=True,
                     enforce_fixed_output_sizes=True):
  """Applies layers with NasNet search space style branching.

  Each cell starts from a single hidden state (the input tensor for the first
  cell, the previous cell's output afterwards). Every layer of the cell reads
  one earlier hidden state per live branch, applies that branch, combines the
  two branches, and appends the result as a new hidden state. Hidden states
  that no live branch consumed are combined into the cell output.

  Args:
    input_tensor: Input [batch_size, input_length, hidden_dim] sequence tensor.
    left_inputs: Int list of left branch hidden layer input indexes.
    left_layers: String list of left branch layers.
    left_activations: String list of left branch activations.
    left_output_dims: String list of left branch output dimensions.
    left_norms: String list of left branch norms.
    right_inputs: Int list of right branch hidden layer input indexes.
    right_layers: String list of right branch layers.
    right_activations: String list of right branch activations.
    right_output_dims: String list of right branch output dimensions.
    right_norms: String list of right branch norms.
    combiner_functions: String list of branch combining functions.
    final_combiner_function: String. The final combiner function that combines
      all the unused hidden layers in a cell.
    num_cells: The number of cells. This is the number of times the given
      layers will be repeated.
    nonpadding: Tensor with 1s at all nonpadding time step positions and 0s
      everywhere else.
    layer_registry: The LayerRegistry that holds all valid layers.
    mask_future: Whether or not to mask future sequence values.
    hparams: Hyperparameters for the model.
    var_scope: The variable scope name.
    encoder_decoder_attention_bias: The attention bias for decoder attending to
      `encoder_output`.
    encoder_cell_outputs: List of tensors. The encoder cell outputs, listed in
      order.
    decoder_self_attention_bias: The self attention bias for decoders. This
      needs to be set for decoders.
    final_layer_norm: Whether or not to apply a final layer_norm to the output
      of the model.
    enforce_fixed_output_sizes: Whether or not to automatically resize output
      dimensions to match the input dimension if `should_alter_output_dim()`
      returns True.

  Raises:
    ValueError: When branching inputs are not of the same length.
    ValueError: If item in left_norms is not LAYER_NORM_KEY or NO_NORM_KEY.
    ValueError: If item in right_norms is not LAYER_NORM_KEY or NO_NORM_KEY.

  Returns:
    Output of applied layers and list of each cell's outputs in order.
  """
  if not (len(left_inputs) == len(left_layers) == len(left_activations) ==
          len(left_output_dims) == len(left_norms) == len(right_inputs) ==
          len(right_layers) == len(right_activations) ==
          len(right_output_dims) == len(right_norms) ==
          len(combiner_functions)):
    raise ValueError("All branching inputs must be of the same length.")

  cell_output = None

  # Hidden-state indexes consumed by a live (non-dead) branch. The indexes are
  # normalized with int() exactly as the cell loop below does before indexing
  # `cell_hidden_states`; otherwise non-int index values (e.g. strings) would
  # never compare equal to the integer positions and every state would be
  # treated as unused.
  modified_left_inputs = [
      int(left_input)
      for left_input, left_layer in zip(left_inputs, left_layers)
      if left_layer != DEAD_BRANCH_KEY
  ]
  modified_right_inputs = [
      int(right_input)
      for right_input, right_layer in zip(right_inputs, right_layers)
      if right_layer != DEAD_BRANCH_KEY
  ]
  used_hidden_states = set(modified_left_inputs + modified_right_inputs)
  # A cell with N layers has N + 1 hidden states (h_0 is the cell input).
  unused_cell_hidden_states = [
      i for i in range(len(left_inputs) + 1) if i not in used_hidden_states
  ]
  assert unused_cell_hidden_states, (
      "Every hidden state is consumed; nothing is left for the cell output.")

  cell_outputs = []

  with tf.variable_scope(var_scope):
    dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "attention_dropout_broadcast_dims", "")))

    for cell_num in range(num_cells):
      # h_0 is the input tensor.
      # Keep a dict for layer norm states.
      if cell_output is not None:
        cell_hidden_states = [cell_output]
      else:
        cell_hidden_states = [input_tensor]
      layer_norm_dict = {}

      with tf.variable_scope("cell_%d" % cell_num):

        for i, (left_input, left_layer_name, left_activation_name,
                left_output_dim, left_norm, right_input, right_layer_name,
                right_activation_name, right_output_dim, right_norm,
                combiner) in enumerate(
                    zip(left_inputs, left_layers, left_activations,
                        left_output_dims, left_norms, right_inputs,
                        right_layers, right_activations, right_output_dims,
                        right_norms, combiner_functions)):
          left_input = int(left_input)
          right_input = int(right_input)
          with tf.variable_scope("layer_%d" % i):

            # At least one of the two branches must be live.
            assert not (left_layer_name == DEAD_BRANCH_KEY and
                        right_layer_name == DEAD_BRANCH_KEY)

            if left_layer_name != DEAD_BRANCH_KEY:
              left_raw_input_tensor = cell_hidden_states[left_input]
              left_input_dim = left_raw_input_tensor.shape.as_list()[-1]
              if should_alter_output_dim(left_layer_name,
                                         enforce_fixed_output_sizes,
                                         left_input_dim, left_output_dim):
                left_output_dim = left_input_dim

              # First process the left branch.
              left_tensor = _apply_nas_branch(
                  norm=left_norm,
                  layer_norm_dict=layer_norm_dict,
                  hidden_states=cell_hidden_states,
                  nonpadding=nonpadding,
                  hparams=hparams,
                  input_index=left_input,
                  layer_name=left_layer_name,
                  activation_name=left_activation_name,
                  layer_registry=layer_registry,
                  output_dim=left_output_dim,
                  branch_scope_name="left_%s" % str(i),
                  mask_future=mask_future,
                  dropout_broadcast_dims=dropout_broadcast_dims,
                  encoder_decoder_attention_bias=encoder_decoder_attention_bias,
                  encoder_cell_outputs=encoder_cell_outputs,
                  decoder_self_attention_bias=decoder_self_attention_bias,
                  cell_number=cell_num)

            if right_layer_name != DEAD_BRANCH_KEY:
              right_raw_input_tensor = cell_hidden_states[right_input]
              right_input_dim = right_raw_input_tensor.shape.as_list()[-1]
              if should_alter_output_dim(right_layer_name,
                                         enforce_fixed_output_sizes,
                                         right_input_dim, right_output_dim):
                right_output_dim = right_input_dim

              # Next process the right branch.
              right_tensor = _apply_nas_branch(
                  norm=right_norm,
                  layer_norm_dict=layer_norm_dict,
                  hidden_states=cell_hidden_states,
                  nonpadding=nonpadding,
                  hparams=hparams,
                  input_index=right_input,
                  layer_name=right_layer_name,
                  activation_name=right_activation_name,
                  layer_registry=layer_registry,
                  output_dim=right_output_dim,
                  branch_scope_name="right_%s" % str(i),
                  mask_future=mask_future,
                  dropout_broadcast_dims=dropout_broadcast_dims,
                  encoder_decoder_attention_bias=encoder_decoder_attention_bias,
                  encoder_cell_outputs=encoder_cell_outputs,
                  decoder_self_attention_bias=decoder_self_attention_bias,
                  cell_number=cell_num)

            # Combine the branches. A dead branch contributes nothing, so the
            # live branch's tensor is used directly.
            if left_layer_name == DEAD_BRANCH_KEY:
              hidden_tensor = right_tensor
            elif right_layer_name == DEAD_BRANCH_KEY:
              hidden_tensor = left_tensor
            else:
              hidden_tensor = COMBINER_FUNCTIONS[combiner]().combine_tensors(
                  [left_tensor, right_tensor])
            cell_hidden_states.append(hidden_tensor)

        states_to_combine = [
            cell_hidden_states[j] for j in unused_cell_hidden_states
        ]
        cell_output = COMBINER_FUNCTIONS[final_combiner_function](
        ).combine_tensors(states_to_combine)
        cell_outputs.append(cell_output)

  if final_layer_norm:
    final_output = common_layers.layer_preprocess(cell_output, hparams)
    # Distinct loop name so the last cell's `cell_output` is not shadowed.
    cell_outputs = [
        common_layers.layer_preprocess(single_cell_output, hparams)
        for single_cell_output in cell_outputs
    ]
    return final_output, cell_outputs
  else:
    return cell_output, cell_outputs
def transformer_bidirectional_joint_decoder(left_decoder_output,
                                            right_decoder_output,
                                            encoder_output,
                                            encoder_decoder_attention_bias,
                                            hparams,
                                            cache=None,
                                            decode_loop_step=None,
                                            name="decoder",
                                            nonpadding=None,
                                            save_weights_to=None,
                                            make_image_summary=True,
                                            losses=None):
  """Joint layers applied on top of the summed left and right decoder outputs.

  The two decoder outputs are added element-wise. The sum then goes through
  `hparams.num_bidirectional_decoder_joint_layers` layers, each made of
  encoder-decoder attention (skipped when `encoder_output` is None) followed
  by a feed-forward sublayer. The result is passed through layer_preprocess
  before being returned.

  Args:
    left_decoder_output: a Tensor, added to `right_decoder_output`.
    right_decoder_output: a Tensor, added to `left_decoder_output`.
    encoder_output: a Tensor, or None to skip encoder-decoder attention.
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
      attentions, used for fast decoding. Indexed by "joint_layer_%d".
    decode_loop_step: An integer, step number of the decoding loop. Only used
      for inference on TPU.
    name: a string
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This is used to mask out
      padding in convolutional layers.  We generally only need this mask for
      "packed" datasets, because for ordinary datasets, no padding is ever
      followed by nonpadding.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: optional list onto which to append extra training losses

  Returns:
    y: a Tensors
  """
  joint_state = left_decoder_output + right_decoder_output
  broadcast_dims_spec = getattr(hparams, "attention_dropout_broadcast_dims",
                                "")
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          broadcast_dims_spec))

  with tf.variable_scope(name):
    for layer_index in range(hparams.num_bidirectional_decoder_joint_layers):
      layer_name = "joint_layer_%d" % layer_index
      if cache is None:
        layer_cache = None
      else:
        layer_cache = cache[layer_name]

      with tf.variable_scope(layer_name):
        if encoder_output is not None:
          with tf.variable_scope("encdec_attention"):
            normed_query = common_layers.layer_preprocess(joint_state, hparams)
            key_channels = (
                hparams.attention_key_channels or hparams.hidden_size)
            value_channels = (
                hparams.attention_value_channels or hparams.hidden_size)
            attended = common_attention.multihead_attention(
                normed_query,
                encoder_output,
                encoder_decoder_attention_bias,
                key_channels,
                value_channels,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                cache=layer_cache,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d"))
            joint_state = common_layers.layer_postprocess(
                joint_state, attended, hparams)

        with tf.variable_scope("ffn"):
          normed_ffn_input = common_layers.layer_preprocess(
              joint_state, hparams)
          ffn_output = transformer_ffn_layer(
              normed_ffn_input,
              hparams,
              conv_padding="LEFT",
              nonpadding_mask=nonpadding,
              losses=losses,
              cache=layer_cache,
              decode_loop_step=decode_loop_step)
          joint_state = common_layers.layer_postprocess(
              joint_state, ffn_output, hparams)

    # The residual stream is a sum of un-normalized sublayer outputs and can
    # grow large, so when layer_preprocess normalizes, the final output is
    # normalized as well.
    return common_layers.layer_preprocess(joint_state, hparams)
def transformer_revnet_decoder(decoder_input,
                               encoder_output,
                               decoder_self_attention_bias,
                               encoder_decoder_attention_bias,
                               hparams,
                               name="decoder"):
  """A stack of transformer layers.

  The decoder input is split in two halves along the last axis and run through
  `tf.contrib.layers.rev_block`, with `f` (attention) and `g` (feed-forward)
  as the two residual functions. Both functions temporarily halve
  `hparams.hidden_size` while they build their sublayers and restore it
  afterwards, even if building the sublayer raises.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    name: a string

  Returns:
    y: a Tensors
  """

  def f(x, side_input):
    """f(x) for reversible layer, self-attention and enc-dec attention."""
    decoder_self_attention_bias = side_input[0]
    encoder_decoder_attention_bias = side_input[1]
    encoder_output = side_input[2]

    old_hid_size = hparams.hidden_size
    hparams.hidden_size = old_hid_size // 2
    # try/finally: `hparams` is shared with the caller, so the original
    # hidden size must be restored even when a sublayer fails to build.
    try:
      with tf.variable_scope("self_attention"):
        y = common_attention.multihead_attention(
            common_layers.layer_preprocess(x, hparams),
            None,
            decoder_self_attention_bias,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout)
        y = common_layers.layer_postprocess(x, y, hparams)
        if encoder_output is not None:
          # NOTE(review): this branch preprocesses and post-processes `x`, so
          # the self-attention result held in `y` above is replaced rather
          # than consumed -- confirm this is intended before changing it.
          with tf.variable_scope("encdec_attention"):
            y = common_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout)
            y = common_layers.layer_postprocess(x, y, hparams)
    finally:
      hparams.hidden_size = old_hid_size
    return y

  def g(x):
    """g(x) for reversible layer, feed-forward layer."""
    old_hid_size = hparams.hidden_size
    hparams.hidden_size = old_hid_size // 2
    # Same restore-on-failure guarantee as in f().
    try:
      with tf.variable_scope("ffn"):
        y = transformer.transformer_ffn_layer(
            common_layers.layer_preprocess(x, hparams), hparams)
        y = common_layers.layer_postprocess(x, y, hparams)
    finally:
      hparams.hidden_size = old_hid_size
    return y

  # Split into the two streams consumed by the reversible block.
  x1, x2 = tf.split(decoder_input, 2, axis=-1)

  with tf.variable_scope(name):
    y1, y2 = tf.contrib.layers.rev_block(
        x1,
        x2,
        f,
        g,
        num_layers=hparams.num_hidden_layers,
        f_side_input=[
            decoder_self_attention_bias, encoder_decoder_attention_bias,
            encoder_output
        ],
        is_training=hparams.mode == tf.estimator.ModeKeys.TRAIN)
    y = tf.concat([y1, y2], axis=-1)

  return common_layers.layer_preprocess(y, hparams)
def evolved_transformer_decoder(decoder_input,
                                encoder_output,
                                decoder_self_attention_bias,
                                encoder_decoder_attention_bias,
                                hparams,
                                cache=None,
                                decode_loop_step=None,
                                name="decoder",
                                nonpadding=None,
                                save_weights_to=None,
                                make_image_summary=True,
                                losses=None):
  """Evolved Transformer decoder. See arxiv.org/abs/1901.11117 for more details.

  Each layer is built from these sub-blocks, in order:
    1. A self-attention branch (head count from `_capped_double_heads`) and,
       when `encoder_output` is not None, a parallel encoder-attention branch;
       the two are summed with the residual.
    2. Two separable-convolution branches (kernel sizes 11 and 7) whose sum is
       fed to a final separable convolution (kernel size 7).
    3. A self-attention sublayer.
    4. An encoder-attention sublayer (only when `encoder_output` is not None).
    5. Two dense layers, the first with a swish activation.

  Args:
    decoder_input: a Tensor.
    encoder_output: a Tensor.
    decoder_self_attention_bias: bias Tensor for self-attention (see
      common_attention.attention_bias()).
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias()).
    hparams: hyperparameters for model.
    cache: dict, containing tensors which are the results of previous
      layers, used for fast decoding.
    decode_loop_step: An integer, step number of the decoding loop. Only used
      for inference on TPU.
    name: a string.
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This is used to mask out
      padding in convolutional layers.  We generally only need this mask for
      "packed" datasets, because for ordinary datasets, no padding is ever
      followed by nonpadding.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: Not supported.

  Returns:
    Decoder output tensor.
  """
  # `losses` is accepted for signature compatibility only and is discarded.
  del losses

  num_trainable_top_decoder_layers = hparams.get(
      "num_trainable_top_decoder_layers", -1)  # -1 means train all weights.

  # When only some top decoder layers are trainable, no gradient flows back
  # into the encoder output.
  # NOTE(review): this runs before the `encoder_output is not None` checks
  # below -- confirm callers never combine a None encoder_output with
  # num_trainable_top_decoder_layers >= 0.
  if num_trainable_top_decoder_layers >= 0:
    encoder_output = tf.stop_gradient(encoder_output)

  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))

  with tf.variable_scope(name):
    hidden_state = decoder_input

    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
    for layer in range(num_layers):
      # Cut the gradient at the input of the first trainable layer so that
      # only the top `num_trainable_top_decoder_layers` layers are updated.
      if num_trainable_top_decoder_layers == num_layers - layer:
        hidden_state = tf.stop_gradient(hidden_state)
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):

        # Sub-block 1a: self-attention branch (result kept in `left_state`).
        with tf.variable_scope(_SIXTEEN_HEAD_ATTENTION_NAME):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          attention_cache = layer_cache[
              _SIXTEEN_HEAD_ATTENTION_NAME] if layer_cache is not None else None
          left_state = common_attention.multihead_attention(
              hidden_state,
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              _capped_double_heads(hparams.num_heads),
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              cache=attention_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              decode_loop_step=decode_loop_step,
              vars_3d=hparams.get("attention_variables_3d"),
              activation_dtype=hparams.get("activation_dtype", "float32"),
              weight_dtype=hparams.get("weight_dtype", "float32"))

        # Sub-block 1b: encoder-attention branch on the same preprocessed
        # input; both branches get dropout and are summed with the residual.
        if encoder_output is not None:
          with tf.variable_scope(_FIRST_ATTEND_TO_ENCODER_NAME):
            attention_cache = (
                layer_cache[_FIRST_ATTEND_TO_ENCODER_NAME]
                if layer_cache is not None else None)
            right_state = common_attention.multihead_attention(
                hidden_state,
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                cache=attention_cache,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d"),
                activation_dtype=hparams.get("activation_dtype", "float32"),
                weight_dtype=hparams.get("weight_dtype", "float32"))

            left_state = tf.nn.dropout(left_state,
                                       1 - hparams.layer_prepostprocess_dropout)
            right_state = tf.nn.dropout(
                right_state, 1 - hparams.layer_prepostprocess_dropout)

            hidden_state = residual_state + left_state + right_state

        else:
          # No encoder: only the self-attention branch joins the residual.
          hidden_state = common_layers.layer_postprocess(
              residual_state, left_state, hparams)

        # Sub-block 2: convolution branches.
        with tf.variable_scope(_CONV_BRANCHES_NAME):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          if nonpadding is not None:
            # Mask padding from conv layers.
            mask = tf.tile(
                tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size])
            hidden_state *= mask

          if layer_cache:
            if decode_loop_step is None:
              # Append the new state to the cached states along axis 1, keep
              # the last _DECODER_LEFT_CONV_PADDING + 1 positions, and store
              # that window back into the cache.
              hidden_state = layer_cache[
                  _CONV_BRANCHES_FIRST_LAYER_NAME] = tf.concat(
                      [
                          layer_cache[_CONV_BRANCHES_FIRST_LAYER_NAME],
                          hidden_state
                      ],
                      axis=1)[:, -1 * _DECODER_LEFT_CONV_PADDING - 1:, :]
              left_state = hidden_state
              # The right branch reads the trailing part of the same window.
              right_state = hidden_state[:, _DECODER_LEFT_CONV_PADDING -
                                         _DECODER_RIGHT_CONV_PADDING:, :]
            else:
              # Inplace update is required for inference on TPU.
              # Inplace_ops only supports inplace_update on the first dimension.
              tmp = tf.transpose(
                  layer_cache[_CONV_BRANCHES_FIRST_LAYER_NAME], perm=[1, 0, 2])
              tmp = tf.expand_dims(tmp, axis=1)
              tmp = inplace_ops.alias_inplace_update(
                  tmp,
                  decode_loop_step * tf.shape(hidden_state)[1] +
                  _DECODER_LEFT_CONV_PADDING,
                  tf.transpose(hidden_state, perm=[1, 0, 2]))
              tmp = tf.squeeze(tmp, axis=1)
              hidden_state = layer_cache[
                  _CONV_BRANCHES_FIRST_LAYER_NAME] = tf.transpose(
                      tmp, perm=[1, 0, 2])

              # Slice the windows each conv branch needs out of the updated
              # cache, starting at the current decode step.
              batch_size = hidden_state.shape.as_list()[0]
              left_state = tf.slice(hidden_state, [0, decode_loop_step, 0], [
                  batch_size, _DECODER_LEFT_CONV_PADDING + 1,
                  hparams.hidden_size
              ])
              right_state = tf.slice(hidden_state, [
                  0, decode_loop_step + _DECODER_LEFT_CONV_PADDING -
                  _DECODER_RIGHT_CONV_PADDING, 0
              ], [
                  batch_size, _DECODER_RIGHT_CONV_PADDING + 1,
                  hparams.hidden_size
              ])

          else:  # No caching.
            # Pad only the front of axis 1 for each "VALID" convolution.
            left_state = tf.pad(
                hidden_state,
                paddings=[[0, 0], [_DECODER_LEFT_CONV_PADDING, 0], [0, 0]])
            right_state = tf.pad(
                hidden_state,
                paddings=[[0, 0], [_DECODER_RIGHT_CONV_PADDING, 0], [0, 0]])

          left_output_dim = int(hparams.hidden_size * 2)
          separable_conv_11x1 = tf.layers.SeparableConv1D(
              left_output_dim,
              11,
              padding="VALID",
              name="separable_conv11x1",
              activation=tf.nn.relu)
          left_state = separable_conv_11x1.apply(left_state)
          left_state = tf.nn.dropout(left_state,
                                     1 - hparams.layer_prepostprocess_dropout)

          right_output_dim = int(hparams.hidden_size / 2)
          separable_conv_7x1_1 = tf.layers.SeparableConv1D(
              right_output_dim, 7, padding="VALID", name="separable_conv_7x1_1")
          right_state = separable_conv_7x1_1.apply(right_state)
          right_state = tf.nn.dropout(right_state,
                                      1 - hparams.layer_prepostprocess_dropout)
          # Zero-pad the last axis of the narrower right branch up to
          # left_output_dim so the two branches can be added.
          right_state = tf.pad(
              right_state,
              [[0, 0], [0, 0], [0, left_output_dim - right_output_dim]],
              constant_values=0)

          hidden_state = left_state + right_state

          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          if nonpadding is not None:
            # Mask padding from conv layers.
            mask = tf.tile(
                tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size * 2])
            hidden_state *= mask

          if layer_cache:
            if decode_loop_step is None:
              # Same rolling-window cache update as above, for the input of
              # the final convolution.
              hidden_state = layer_cache[
                  _CONV_BRANCHES_SECOND_LAYER_NAME] = tf.concat(
                      [
                          layer_cache[_CONV_BRANCHES_SECOND_LAYER_NAME],
                          hidden_state
                      ],
                      axis=1)[:, -1 * _DECODER_FINAL_CONV_PADDING - 1:, :]

            else:
              # Inplace update is required for inference on TPU.
              # Inplace_ops only supports inplace_update on the first dimension.
              tmp = tf.transpose(
                  layer_cache[_CONV_BRANCHES_SECOND_LAYER_NAME], perm=[1, 0, 2])
              tmp = tf.expand_dims(tmp, axis=1)
              tmp = inplace_ops.alias_inplace_update(
                  tmp, (decode_loop_step + _DECODER_FINAL_CONV_PADDING) *
                  tf.shape(hidden_state)[1],
                  tf.transpose(hidden_state, perm=[1, 0, 2]))
              tmp = tf.squeeze(tmp, axis=1)
              hidden_state = layer_cache[
                  _CONV_BRANCHES_SECOND_LAYER_NAME] = tf.transpose(
                      tmp, perm=[1, 0, 2])

              batch_size = hidden_state.shape.as_list()[0]
              hidden_state = tf.slice(hidden_state, [0, decode_loop_step, 0], [
                  batch_size, _DECODER_FINAL_CONV_PADDING + 1,
                  hparams.hidden_size * 2
              ])
          else:
            hidden_state = tf.pad(
                hidden_state,
                paddings=[[0, 0], [_DECODER_FINAL_CONV_PADDING, 0], [0, 0]])

          # Final convolution projects back to hparams.hidden_size channels.
          separable_conv_7x1_2 = tf.layers.SeparableConv1D(
              hparams.hidden_size,
              7,
              padding="VALID",
              name="separable_conv_7x1_2")
          hidden_state = separable_conv_7x1_2.apply(hidden_state)

          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

        # Sub-block 3: self-attention with hparams.num_heads heads.
        with tf.variable_scope(_VANILLA_ATTENTION_NAME):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          attention_cache = layer_cache[
              _VANILLA_ATTENTION_NAME] if layer_cache is not None else None
          hidden_state = common_attention.multihead_attention(
              hidden_state,
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              cache=attention_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              decode_loop_step=decode_loop_step,
              vars_3d=hparams.get("attention_variables_3d"),
              activation_dtype=hparams.get("activation_dtype", "float32"),
              weight_dtype=hparams.get("weight_dtype", "float32"))
          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

        # Sub-block 4: second encoder attention, skipped without an encoder.
        if encoder_output is not None:
          with tf.variable_scope(_SECOND_ATTEND_TO_ENCODER_NAME):
            residual_state = hidden_state
            hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

            attention_cache = (
                layer_cache[_SECOND_ATTEND_TO_ENCODER_NAME]
                if layer_cache is not None else None)
            hidden_state = common_attention.multihead_attention(
                hidden_state,
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                cache=attention_cache,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d"),
                activation_dtype=hparams.get("activation_dtype", "float32"),
                weight_dtype=hparams.get("weight_dtype", "float32"))
            hidden_state = common_layers.layer_postprocess(
                residual_state, hidden_state, hparams)

        # Sub-block 5: dense expansion to 4x hidden_size (swish), then
        # projection back to hidden_size.
        with tf.variable_scope("dense_layers"):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          hidden_state = tf.layers.dense(
              hidden_state,
              int(hparams.hidden_size * 4),
              activation=tf.nn.swish)
          hidden_state = tf.nn.dropout(hidden_state,
                                       1 - hparams.layer_prepostprocess_dropout)

          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          hidden_state = tf.layers.dense(hidden_state, hparams.hidden_size)
          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

    decoder_output = common_layers.layer_preprocess(hidden_state, hparams)
    # With zero trainable decoder layers, the whole decoder output is frozen.
    if num_trainable_top_decoder_layers == 0:
      decoder_output = tf.stop_gradient(decoder_output)
    return decoder_output
def ffn(x, hparams, name):
  """Applies one feed-forward sublayer to `x` under variable scope `name`.

  The input goes through layer_preprocess, then
  transformer.transformer_ffn_layer, and the result is merged back with the
  original `x` by layer_postprocess.

  Args:
    x: input passed to layer_preprocess and layer_postprocess.
    hparams: hyperparameters forwarded to every call.
    name: variable scope name for the sublayer.

  Returns:
    The output of common_layers.layer_postprocess.
  """
  with tf.variable_scope(name):
    normalized_input = common_layers.layer_preprocess(x, hparams)
    ffn_output = transformer.transformer_ffn_layer(normalized_input, hparams)
    return common_layers.layer_postprocess(x, ffn_output, hparams)
def evolved_transformer_encoder(encoder_input,
                                encoder_self_attention_bias,
                                hparams,
                                name="encoder",
                                nonpadding=None,
                                save_weights_to=None,
                                make_image_summary=True,
                                losses=None,
                                attn_bias_for_padding=None):
    """Build the Evolved Transformer encoder stack.

    See arxiv.org/abs/1901.11117 for more details.

    Every layer chains four sub-blocks, each wrapped in
    `layer_preprocess` / `layer_postprocess`: a gated linear unit, a pair
    of parallel branches (dense and 3x1 conv) followed by a 9x1 separable
    conv, multi-head self-attention, and a two-layer dense block.

    Note: Pad remover is not supported.

    Args:
      encoder_input: a Tensor.
      encoder_self_attention_bias: bias Tensor for self-attention (see
        common_attention.attention_bias()).
      hparams: hyperparameters for model.
      name: a string, used as the outermost variable scope.
      nonpadding: optional Tensor with shape [batch_size, encoder_length]
        marking the positions that are not padding. Supply it explicitly
        (as is done for "packed" datasets); when omitted it is derived
        from the attention bias. In this function it is used to zero out
        padded positions ahead of the convolutional layers.
      save_weights_to: an optional dictionary to capture attention weights
        for visualization; the weights tensor will be appended there under
        a string key created from the variable scope (including name).
      make_image_summary: Whether to make an attention image summary.
      losses: Not used.
      attn_bias_for_padding: Padded attention bias in case a unidirectional
        encoder is being used where future attention is masked. When given,
        padding is inferred from it instead of from
        `encoder_self_attention_bias`.

    Returns:
      Tensor encoder output.
    """
    del losses  # Accepted for interface compatibility only.

    broadcast_dims = common_layers.comma_separated_string_to_integer_list(
        getattr(hparams, "attention_dropout_broadcast_dims", ""))
    state = encoder_input

    with tf.variable_scope(name):
        if nonpadding is None:
            # No explicit mask: recover padding from an attention bias,
            # preferring the dedicated padding bias when one was supplied.
            bias_for_padding = (
                encoder_self_attention_bias
                if attn_bias_for_padding is None else attn_bias_for_padding)
            padding = common_attention.attention_bias_to_padding(
                bias_for_padding)
            nonpadding = 1.0 - padding
        else:
            # `padding` is not read again in this function.
            padding = 1.0 - nonpadding

        num_layers = hparams.num_encoder_layers or hparams.num_hidden_layers
        for layer_index in range(num_layers):
            with tf.variable_scope("layer_%d" % layer_index):

                with tf.variable_scope("gated_linear_unit"):
                    shortcut = state
                    normed = common_layers.layer_preprocess(state, hparams)

                    # Linear projection modulated element-wise by a sigmoid
                    # gate computed from the same input.
                    linear_part = layers().Dense(hparams.hidden_size)(normed)
                    gate_part = layers().Dense(
                        hparams.hidden_size,
                        activation=tf.nn.sigmoid)(normed)

                    state = common_layers.layer_postprocess(
                        shortcut, linear_part * gate_part, hparams)

                with tf.variable_scope("conv_branches"):
                    shortcut = state
                    normed = common_layers.layer_preprocess(state, hparams)

                    # Zero out padded positions ahead of the conv layers.
                    pad_mask = tf.tile(
                        tf.expand_dims(nonpadding, 2),
                        [1, 1, hparams.hidden_size])
                    normed *= pad_mask

                    # Left branch: dense layer with 4x hidden_size units.
                    wide_dim = int(hparams.hidden_size * 4)
                    dense_branch = layers().Dense(
                        wide_dim, activation=tf.nn.relu)(normed)
                    dense_branch = tf.nn.dropout(
                        dense_branch,
                        1 - hparams.layer_prepostprocess_dropout)

                    # Right branch: width-3 conv with hidden_size / 2 filters.
                    narrow_dim = int(hparams.hidden_size / 2)
                    conv_branch = layers().Conv1D(
                        narrow_dim,
                        3,
                        padding="SAME",
                        name="standard_conv_3x1",
                        activation=tf.nn.relu)(normed)
                    conv_branch = tf.nn.dropout(
                        conv_branch,
                        1 - hparams.layer_prepostprocess_dropout)

                    # Zero-pad the end of the last axis by the difference in
                    # widths so the branch can be added to the dense branch.
                    conv_branch = tf.pad(
                        conv_branch,
                        [[0, 0], [0, 0], [0, wide_dim - narrow_dim]],
                        constant_values=0)

                    merged = common_layers.layer_preprocess(
                        dense_branch + conv_branch, hparams)

                    # Mask padding again before the separable conv.
                    pad_mask = tf.tile(
                        tf.expand_dims(nonpadding, 2), [1, 1, wide_dim])
                    merged *= pad_mask

                    merged = layers().SeparableConv1D(
                        narrow_dim,
                        9,
                        padding="SAME",
                        name="separable_conv_9x1")(merged)
                    # Zero-pad the last axis back up by
                    # hidden_size - narrow_dim ahead of the residual step.
                    merged = tf.pad(
                        merged,
                        [[0, 0], [0, 0],
                         [0, hparams.hidden_size - narrow_dim]],
                        constant_values=0)

                    state = common_layers.layer_postprocess(
                        shortcut, merged, hparams)

                with tf.variable_scope("self_attention"):
                    shortcut = state
                    normed = common_layers.layer_preprocess(state, hparams)

                    key_depth = (
                        hparams.attention_key_channels or hparams.hidden_size)
                    value_depth = (
                        hparams.attention_value_channels
                        or hparams.hidden_size)
                    attention_options = dict(
                        attention_type=hparams.self_attention_type,
                        max_relative_position=hparams.max_relative_position,
                        heads_share_relative_embedding=(
                            hparams.heads_share_relative_embedding),
                        add_relative_to_values=hparams.add_relative_to_values,
                        save_weights_to=save_weights_to,
                        make_image_summary=make_image_summary,
                        dropout_broadcast_dims=broadcast_dims,
                        max_length=hparams.get("max_length"),
                        vars_3d=hparams.get("attention_variables_3d"),
                        activation_dtype=hparams.get(
                            "activation_dtype", "float32"),
                        weight_dtype=hparams.get("weight_dtype", "float32"))

                    # The second argument is None, as in the original call.
                    attended = common_attention.multihead_attention(
                        normed,
                        None,
                        encoder_self_attention_bias,
                        key_depth,
                        value_depth,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        **attention_options)

                    state = common_layers.layer_postprocess(
                        shortcut, attended, hparams)

                with tf.variable_scope("dense_layers"):
                    shortcut = state
                    normed = common_layers.layer_preprocess(state, hparams)

                    expanded = layers().Dense(
                        int(hparams.hidden_size * 4),
                        activation=tf.nn.relu)(normed)
                    expanded = tf.nn.dropout(
                        expanded, 1 - hparams.layer_prepostprocess_dropout)

                    projected = layers().Dense(hparams.hidden_size)(expanded)
                    state = common_layers.layer_postprocess(
                        shortcut, projected, hparams)

        # When layer_preprocess performs normalization, the final output
        # needs it too: it is the sum of a whole stack of unnormalized
        # layer outputs and can otherwise grow very large.
        return common_layers.layer_preprocess(state, hparams)