def _add_attend_to_encoder_cache(cache, attention_name, hparams, num_layers,
                                 key_channels, value_channels,
                                 vars_3d_num_heads, scope_prefix,
                                 encoder_output):
  """Add attend-to-encoder layers to cache."""
  for layer in range(num_layers):
    layer_name = "layer_%d" % layer
    with tf.variable_scope("%sdecoder/%s/%s/multihead_attention" %
                           (scope_prefix, layer_name, attention_name)):
      k_encdec = common_attention.compute_attention_component(
          encoder_output, key_channels, name="k",
          vars_3d_num_heads=vars_3d_num_heads)
      k_encdec = common_attention.split_heads(k_encdec, hparams.num_heads)
      v_encdec = common_attention.compute_attention_component(
          encoder_output, value_channels, name="v",
          vars_3d_num_heads=vars_3d_num_heads)
      v_encdec = common_attention.split_heads(v_encdec, hparams.num_heads)
      cache[layer_name][attention_name] = {
          "k_encdec": k_encdec,
          "v_encdec": v_encdec
      }
  return cache
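# A hypothetical usage sketch for _add_attend_to_encoder_cache; the attention
# name and scope prefix below are illustrative assumptions, and the cache is
# expected to already hold one sub-dict per decoder layer (see
# _init_transformer_cache below). The channel/layer derivations mirror the
# ones used elsewhere in this file.
def _add_attend_to_encoder_cache_usage_sketch(cache, hparams, encoder_output):
  return _add_attend_to_encoder_cache(
      cache=cache,
      attention_name="attend_to_encoder",  # Assumed name, for illustration.
      hparams=hparams,
      num_layers=hparams.num_decoder_layers or hparams.num_hidden_layers,
      key_channels=hparams.attention_key_channels or hparams.hidden_size,
      value_channels=hparams.attention_value_channels or hparams.hidden_size,
      vars_3d_num_heads=0,
      scope_prefix="body/",  # Assumed prefix, for illustration.
      encoder_output=encoder_output)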
def transformer_pointer_prediction_layer(feature_name,
                                         encoder_output,
                                         x,
                                         encoder_decoder_attention_bias,
                                         hparams,
                                         features,
                                         loss_mask,
                                         layer_collection=None):
  """Layer that predicts the start or end token position.

  Args:
    feature_name: 'targets_start_token' or 'targets_end_token'.
    encoder_output: [batch_size, input_length, hidden_size] tensor with
      encoder outputs.
    x: [batch_size, target_length, 1, hidden_size] tensor with decoder
      outputs.
    encoder_decoder_attention_bias: [batch_size, input_length, target_length]
      attention mask.
    hparams: Hyperparameters.
    features: Feature dictionary.
    loss_mask: [batch_size, target_length] mask for loss computation.
    layer_collection: Layer collection.

  Returns:
    (x, logits, loss)
  """
  if isinstance(encoder_output, list):
    pointer_encoder_output = encoder_output[1]
    encoder_output = sum(encoder_output)
  else:
    pointer_encoder_output = encoder_output
  with tf.variable_scope("%s_prediction" % feature_name):
    x = maybe_flatten4d3d(x)
    encoder_decoder_attention_bias = common_layers.flatten4d3d(
        encoder_decoder_attention_bias)
    q = common_attention.compute_attention_component(x, hparams.hidden_size)
    k = common_attention.compute_attention_component(encoder_output,
                                                     hparams.hidden_size)
    # Scaled dot-product attention.
    scalar = tf.rsqrt(tf.to_float(common_layers.shape_list(q)[2]))
    logits = tf.matmul(q * scalar, k, transpose_b=True)
    logits += encoder_decoder_attention_bias
    labels = features["%s_raw" % feature_name]
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels)
    loss = tf.reduce_sum(xent * loss_mask)
    pointer_out = gather_2d(pointer_encoder_output, labels)
    y = common_layers.layer_preprocess(
        pointer_out, hparams, layer_collection=layer_collection)
    x = common_layers.layer_postprocess(x, y, hparams)
    return x, logits, loss
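# A minimal sketch of the batched gather that the `gather_2d` helper used
# above (defined elsewhere) is assumed to perform: for every example in the
# batch, pick the encoder states at the given token positions. The name and
# shapes here are illustrative assumptions, not the original implementation.
def _gather_2d_sketch(params, indices):
  """Gathers params[b, indices[b, i], :] -> [batch, num_indices, depth]."""
  batch_size = common_layers.shape_list(params)[0]
  num_indices = common_layers.shape_list(indices)[1]
  batch_pos = tf.tile(
      tf.expand_dims(tf.range(batch_size), 1), [1, num_indices])
  gather_indices = tf.stack([batch_pos, tf.to_int32(indices)], axis=2)
  return tf.gather_nd(params, gather_indices)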
def attn(image_feat, query, hparams, name="attn"): """Attention on image feature with question as query.""" with tf.variable_scope(name, "attn", values=[image_feat, query]): attn_dim = hparams.attn_dim num_glimps = hparams.num_glimps num_channels = common_layers.shape_list(image_feat)[-1] image_feat = common_layers.flatten4d3d(image_feat) query = tf.expand_dims(query, 1) image_proj = common_attention.compute_attention_component( image_feat, attn_dim, name="image_proj") query_proj = common_attention.compute_attention_component( query, attn_dim, name="query_proj") h = tf.nn.relu(image_proj + query_proj) h_proj = common_attention.compute_attention_component(h, num_glimps, name="h_proj") p = tf.nn.softmax(h_proj, axis=1) image_ave = tf.matmul(image_feat, p, transpose_a=True) image_ave = tf.reshape(image_ave, [-1, num_channels * num_glimps]) return image_ave
def attn(image_feat, query, hparams, name="attn"): """Attention on image feature with question as query.""" with tf.variable_scope(name, "attn", values=[image_feat, query]): attn_dim = hparams.attn_dim num_glimps = hparams.num_glimps num_channels = common_layers.shape_list(image_feat)[-1] if len(common_layers.shape_list(image_feat)) == 4: image_feat = common_layers.flatten4d3d(image_feat) query = tf.expand_dims(query, 1) image_proj = common_attention.compute_attention_component( image_feat, attn_dim, name="image_proj") query_proj = common_attention.compute_attention_component( query, attn_dim, name="query_proj") h = tf.nn.relu(image_proj + query_proj) h_proj = common_attention.compute_attention_component( h, num_glimps, name="h_proj") p = tf.nn.softmax(h_proj, axis=1) image_ave = tf.matmul(image_feat, p, transpose_a=True) image_ave = tf.reshape(image_ave, [-1, num_channels*num_glimps]) return image_ave
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        shared_rel=False,
                        max_relative_position=None,
                        image_shapes=None,
                        attention_type="dot_product",
                        block_length=128,
                        block_width=128,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        gap_size=0,
                        num_memory_blocks=2,
                        name="multihead_attention",
                        save_weights_to=None,
                        make_image_summary=True,
                        dropout_broadcast_dims=None,
                        max_length=None,
                        vars_3d=False,
                        scale_dotproduct=True,
                        **kwargs):
  """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    shared_rel: boolean to share relative embeddings
    max_relative_position: Maximum distance between inputs to generate
      unique relation embeddings for. Only relevant when using
      "dot_product_relative" attention.
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    attention_type: a string, either "dot_product", "dot_product_relative",
      "local_mask_right", "local_unmasked", "masked_dilated_1d",
      "unmasked_dilated_1d", "graph", or any attention function with the
      signature (query, key, value, **kwargs)
    block_length: an integer - relevant for "local_mask_right"
    block_width: an integer - relevant for "local_unmasked"
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and
      values to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID": no
      padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID": no
      padding.
    cache: dict containing Tensors which are the results of previous
      attentions, used for fast decoding. Expects the dict to contain two
      keys ('k' and 'v'); for the initial call the values for these keys
      should be empty Tensors of the appropriate shape.
          'k' [batch_size, 0, key_channels]
          'v' [batch_size, 0, value_channels]
    gap_size: Integer option for dilated attention to indicate spacing
      between memory blocks.
    num_memory_blocks: Integer option to indicate how many memory blocks to
      look at.
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims: an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      Saves memory.
    max_length: an integer - needed by relative attention
    vars_3d: use 3-dimensional variables for input/output transformations
    scale_dotproduct: whether to normalize the attention product.
    **kwargs (dict): Parameters for the attention function

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of
    the shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided, in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim].
    Optionally returns an additional loss parameter (ex: load balance loss for
    the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  vars_3d_num_heads = num_heads if vars_3d else 0
  with tf.variable_scope(name, default_name="multihead_attention",
                         values=[query_antecedent, memory_antecedent]):

    if cache is None or memory_antecedent is None:
      q, k, v = common_attention.compute_qkv(
          query_antecedent, memory_antecedent,
          total_key_depth, total_value_depth,
          q_filter_width, kv_filter_width,
          q_padding, kv_padding,
          vars_3d_num_heads=vars_3d_num_heads)
    if cache is not None:
      if attention_type != "dot_product":
        # TODO(petershaw): Support caching when using relative position
        # representations, i.e. "dot_product_relative" attention.
        raise NotImplementedError(
            "Caching is not guaranteed to work with attention types other than"
            " dot_product.")
      if bias is None:
        raise ValueError("Bias required for caching. See function docstring "
                         "for details.")

      if memory_antecedent is not None:
        # Encoder-Decoder Attention Cache
        q = common_attention.compute_attention_component(
            query_antecedent, total_key_depth, q_filter_width, q_padding, "q",
            vars_3d_num_heads=vars_3d_num_heads)
        k = cache["k_encdec"]
        v = cache["v_encdec"]
      else:
        k = common_attention.split_heads(k, num_heads)
        v = common_attention.split_heads(v, num_heads)
        decode_loop_step = kwargs.get("decode_loop_step")
        if decode_loop_step is None:
          k = cache["k"] = tf.concat([cache["k"], k], axis=2)
          v = cache["v"] = tf.concat([cache["v"], v], axis=2)
        else:
          # Inplace update is required for inference on TPU.
          # Inplace_ops only supports inplace_update on the first dimension.
          # The performance of current implementation is better than updating
          # the tensor by adding the result of matmul(one_hot,
          # update_in_current_step).
          tmp_k = tf.transpose(cache["k"], perm=[2, 0, 1, 3])
          tmp_k = inplace_ops.alias_inplace_update(
              tmp_k, decode_loop_step, tf.squeeze(k, axis=2))
          k = cache["k"] = tf.transpose(tmp_k, perm=[1, 2, 0, 3])
          tmp_v = tf.transpose(cache["v"], perm=[2, 0, 1, 3])
          tmp_v = inplace_ops.alias_inplace_update(
              tmp_v, decode_loop_step, tf.squeeze(v, axis=2))
          v = cache["v"] = tf.transpose(tmp_v, perm=[1, 2, 0, 3])

    q = common_attention.split_heads(q, num_heads)
    if cache is None:
      k = common_attention.split_heads(k, num_heads)
      v = common_attention.split_heads(v, num_heads)

    key_depth_per_head = total_key_depth // num_heads
    if not vars_3d:
      if scale_dotproduct:
        q *= key_depth_per_head**-0.5

    additional_returned_value = None
    if callable(attention_type):  # Generic way to extend multihead_attention
      x = attention_type(q, k, v, **kwargs)
      if isinstance(x, tuple):
        x, additional_returned_value = x  # Unpack
    elif attention_type == "dot_product":
      x = common_attention.dot_product_attention(
          q, k, v, bias, dropout_rate, image_shapes,
          save_weights_to=save_weights_to,
          make_image_summary=make_image_summary,
          dropout_broadcast_dims=dropout_broadcast_dims)
    elif attention_type == "dot_product_relative":
      x = common_attention.dot_product_attention_relative(
          q, k, v, bias, max_relative_position, dropout_rate, image_shapes,
          make_image_summary=make_image_summary)
    elif attention_type == "dot_product_relative_v2":
      x = common_attention.dot_product_self_attention_relative_v2(
          q, k, v, bias, max_length, dropout_rate, image_shapes,
          make_image_summary=make_image_summary,
          dropout_broadcast_dims=dropout_broadcast_dims)
    elif attention_type == "local_within_block_mask_right":
      x = common_attention.masked_within_block_local_attention_1d(
          q, k, v, block_length=block_length)
    elif attention_type == "rel_local_mask_right":
      x = common_attention.masked_rel_local_attention_1d(
          q, k, v, block_length=block_length,
          make_image_summary=make_image_summary,
          dropout_rate=dropout_rate,
          share_rel_embed=shared_rel)
    elif attention_type == "local_mask_right":
      x = common_attention.masked_local_attention_1d(
          q, k, v, block_length=block_length,
          make_image_summary=make_image_summary)
    elif attention_type == "local_unmasked":
      x = common_attention.local_attention_1d(
          q, k, v, block_length=block_length, filter_width=block_width)
    elif attention_type == "masked_dilated_1d":
      x = common_attention.masked_dilated_self_attention_1d(
          q, k, v, block_length, block_width, gap_size, num_memory_blocks)
    else:
      assert attention_type == "unmasked_dilated_1d"
      x = common_attention.dilated_self_attention_1d(
          q, k, v, block_length, block_width, gap_size, num_memory_blocks)
    x = common_attention.combine_heads(x)

    # Set last dim specifically.
    x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

    if vars_3d:
      o_var = tf.get_variable(
          "o", [num_heads, total_value_depth // num_heads, output_depth])
      o_var = tf.cast(o_var, x.dtype)
      o_var = tf.reshape(o_var, [total_value_depth, output_depth])
      x = tf.tensordot(x, o_var, axes=1)
    else:
      x = common_layers.dense(
          x, output_depth, use_bias=False, name="output_transform")
    if additional_returned_value is not None:
      return x, additional_returned_value
    return x
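# A hypothetical usage sketch of multihead_attention for (non-cached) decoder
# self-attention; the depths, head count and dropout rate are illustrative
# assumptions, not values taken from the surrounding code.
def _self_attention_usage_sketch(decoder_input):
  """decoder_input: [batch, length, hidden]; returns [batch, length, 512]."""
  length = common_layers.shape_list(decoder_input)[1]
  # Causal mask so each position only attends to itself and earlier positions.
  bias = common_attention.attention_bias_lower_triangle(length)
  return multihead_attention(
      query_antecedent=decoder_input,
      memory_antecedent=None,  # None selects self-attention.
      bias=bias,
      total_key_depth=512,
      total_value_depth=512,
      output_depth=512,
      num_heads=8,
      dropout_rate=0.1)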
def _init_transformer_cache(cache, hparams, batch_size, attention_init_length,
                            encoder_output, encoder_decoder_attention_bias,
                            scope_prefix):
  """Create the initial cache for TransformerTag fast decoding."""
  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
  vars_3d_num_heads = (
      hparams.num_heads if hparams.get('attention_variables_3d') else 0)

  if cache is None:
    cache = {}
  cache.update({
      'layer_%d' % layer: {
          'k':
              common_attention.split_heads(
                  tf.zeros([batch_size, attention_init_length, key_channels]),
                  hparams.num_heads),
          'v':
              common_attention.split_heads(
                  tf.zeros(
                      [batch_size, attention_init_length, value_channels]),
                  hparams.num_heads),
      } for layer in range(num_layers)
  })
  if hparams.ffn_layer not in ['dense_relu_dense', 'conv_hidden_relu']:
    for layer in range(num_layers):
      cache['layer_%d' % layer]['f'] = tf.zeros(
          [batch_size, 0, hparams.hidden_size])

  if encoder_output is not None:
    for layer in range(num_layers):
      layer_name = 'layer_%d' % layer
      with tf.variable_scope(
          '%sdecoder/%s/encdec_attention/multihead_attention' %
          (scope_prefix, layer_name)):
        k_encdec = common_attention.compute_attention_component(
            encoder_output, key_channels, name='k',
            vars_3d_num_heads=vars_3d_num_heads)
        k_encdec = common_attention.split_heads(k_encdec, hparams.num_heads)
        v_encdec = common_attention.compute_attention_component(
            encoder_output, value_channels, name='v',
            vars_3d_num_heads=vars_3d_num_heads)
        v_encdec = common_attention.split_heads(v_encdec, hparams.num_heads)
        cache[layer_name]['k_encdec'] = k_encdec
        cache[layer_name]['v_encdec'] = v_encdec

    cache['encoder_output'] = encoder_output
    cache['encoder_decoder_attention_bias'] = encoder_decoder_attention_bias
  return cache
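# A hypothetical sketch of how the initial decoding cache above could be
# created before fast decoding; the scope prefix and the choice of
# attention_init_length=0 (empty key/value history) are illustrative
# assumptions.
def _init_transformer_cache_usage_sketch(hparams, encoder_output,
                                         encoder_decoder_attention_bias):
  batch_size = common_layers.shape_list(encoder_output)[0]
  return _init_transformer_cache(
      cache=None,
      hparams=hparams,
      batch_size=batch_size,
      attention_init_length=0,  # Start with an empty key/value history.
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      scope_prefix='body/')  # Assumed prefix, for illustration.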
def fast_decode(encoder_output,
                encoder_decoder_attention_bias,
                symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                eos_id=beam_search.EOS_ID,
                batch_size=None,
                force_decode_length=False):
  """Fast decoding (greedy or beam search) that also returns encoder outputs."""
  if encoder_output is not None:
    batch_size = common_layers.shape_list(encoder_output)[0]

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
  vars_3d_num_heads = (
      hparams.num_heads if hparams.get("attention_variables_3d") else 0)

  cache = {
      "layer_%d" % layer: {
          "k":
              common_attention.split_heads(
                  tf.zeros([batch_size, 0, key_channels]), hparams.num_heads),
          "v":
              common_attention.split_heads(
                  tf.zeros([batch_size, 0, value_channels]),
                  hparams.num_heads),
          "f":
              tf.zeros([batch_size, 0, hparams.hidden_size]),
      } for layer in range(num_layers)
  }

  if encoder_output is not None:
    for layer in range(num_layers):
      layer_name = "layer_%d" % layer
      with tf.variable_scope(
          "body/decoder/%s/encdec_attention/multihead_attention" % layer_name):
        k_encdec = common_attention.compute_attention_component(
            encoder_output, key_channels, name="k",
            vars_3d_num_heads=vars_3d_num_heads)
        k_encdec = common_attention.split_heads(k_encdec, hparams.num_heads)
        v_encdec = common_attention.compute_attention_component(
            encoder_output, value_channels, name="v",
            vars_3d_num_heads=vars_3d_num_heads)
        v_encdec = common_attention.split_heads(v_encdec, hparams.num_heads)
        cache[layer_name]["k_encdec"] = k_encdec
        cache[layer_name]["v_encdec"] = v_encdec

    cache["encoder_output"] = encoder_output
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  if beam_size > 1:  # Beam Search
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)
    decoded_ids, scores = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        eos_id=eos_id,
        stop_early=(top_beams == 1))

    if top_beams == 1:
      decoded_ids = decoded_ids[:, 0, 1:]
      scores = scores[:, 0]
    else:
      decoded_ids = decoded_ids[:, :top_beams, 1:]
      scores = scores[:, :top_beams]
  else:  # Greedy

    def inner_loop(i, hit_eos, next_id, decoded_ids, cache, log_prob):
      """One step of greedy decoding."""
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      log_probs = common_layers.log_prob_from_logits(logits)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = common_layers.sample_with_temperature(logits, temperature)
      hit_eos |= tf.equal(next_id, eos_id)

      log_prob_indices = tf.stack(
          [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
      log_prob += tf.gather_nd(log_probs, log_prob_indices)

      next_id = tf.expand_dims(next_id, axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, hit_eos, next_id, decoded_ids, cache, log_prob

    def is_not_finished(i, hit_eos, *_):
      finished = i >= decode_length
      if not force_decode_length:
        finished |= tf.reduce_all(hit_eos)
      return tf.logical_not(finished)

    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    hit_eos = tf.fill([batch_size], False)
    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
    initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)
    _, _, _, decoded_ids, _, log_prob = tf.while_loop(
        is_not_finished,
        inner_loop, [
            tf.constant(0), hit_eos, next_id, decoded_ids, cache,
            initial_log_prob
        ],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            nest.map_structure(beam_search.get_state_shape_invariants, cache),
            tf.TensorShape([None]),
        ])
    scores = log_prob

  ##### Modified #####
  # Added encoder outputs to prediction dictionary.
  return {
      "outputs": decoded_ids,
      "encoder_outputs": encoder_output,
      "scores": scores
  }
def fast_decode(encoder_output,
                encoder_decoder_attention_bias,
                symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                eos_id=beam_search.EOS_ID,
                batch_size=None,
                force_decode_length=False,
                cache=None):
  """Given encoder output and a symbols to logits function, does fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    encoder_output: Output from encoder.
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention.
    symbols_to_logits_fn: Incremental decoding; function mapping triple
      `(ids, step, cache)` to symbol logits.
    hparams: run hyperparameters.
    decode_length: an integer. How many additional timesteps to decode.
    vocab_size: Output vocabulary size.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha, the
      stronger the preference for longer translations.
    eos_id: End-of-sequence symbol in beam search.
    batch_size: an integer scalar - must be passed if there is no input.
    force_decode_length: bool, whether to force the full decode length, or if
      False, stop when all beams hit eos_id.
    cache: optional dict of precomputed decoding states to extend; a new cache
      is created if None.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if top_beams == 1 or
            [batch_size, top_beams, <= decode_length] otherwise
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }

  Raises:
    NotImplementedError: If beam size > 1 with partial targets.
  """
  if encoder_output is not None:
    batch_size = common_layers.shape_list(encoder_output)[0]

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

  if cache is None:
    cache = dict()
  cache.update({
      "layer_%d" % layer: {
          "k":
              common_attention.split_heads(
                  tf.zeros([batch_size, 0, key_channels]), hparams.num_heads),
          "v":
              common_attention.split_heads(
                  tf.zeros([batch_size, 0, value_channels]),
                  hparams.num_heads),
          "f":
              tf.zeros([batch_size, 0, hparams.hidden_size]),
      } for layer in range(num_layers)
  })

  if encoder_output is not None:
    for layer in range(num_layers):
      layer_name = "layer_%d" % layer
      with tf.variable_scope(
          "body/decoder/%s/encdec_attention/multihead_attention" % layer_name):
        k_encdec = common_attention.compute_attention_component(
            encoder_output, key_channels, name="k")
        k_encdec = common_attention.split_heads(k_encdec, hparams.num_heads)
        v_encdec = common_attention.compute_attention_component(
            encoder_output, value_channels, name="v")
        v_encdec = common_attention.split_heads(v_encdec, hparams.num_heads)
        cache[layer_name]["k_encdec"] = k_encdec
        cache[layer_name]["v_encdec"] = v_encdec

    cache["encoder_output"] = encoder_output
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  return common.fast_decode(symbols_to_logits_fn, hparams, decode_length,
                            vocab_size, beam_size, top_beams, alpha, eos_id,
                            batch_size, force_decode_length, cache)
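# A hypothetical symbols_to_logits_fn sketch matching the interface used by
# the fast_decode functions above: it maps (ids, step, cache) to
# (logits, cache). The toy embedding/softmax body and the vocabulary and
# hidden sizes are illustrative assumptions, not the real decoder step.
def _symbols_to_logits_fn_sketch(ids, i, cache):
  """Maps (ids, step, cache) -> (logits, cache) for one decode step."""
  del i  # The position index is unused in this toy sketch.
  vocab_size = 32000  # Illustrative assumption.
  hidden_size = 512   # Illustrative assumption.
  with tf.variable_scope("toy_logits", reuse=tf.AUTO_REUSE):
    embedding = tf.get_variable("embedding", [vocab_size, hidden_size])
  last_ids = ids[:, -1]                   # Most recent symbol per sequence.
  state = tf.gather(embedding, last_ids)  # [batch, hidden_size]
  logits = tf.matmul(state, embedding, transpose_b=True)  # [batch, vocab_size]
  return logits, cache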