# Imports assumed by this excerpt (TensorFlow 1.x internal modules).
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope as vs

# NOTE: `linear` is the module's fully connected helper (a bias-adding linear
# map over a list of 2D tensors, e.g. rnn_cell's internal _linear in TF 1.x)
# and is assumed to be defined or imported elsewhere in this file.


def attention_decoder(decoder_inputs, initial_state, attention_states, cell,
                      output_size=None, num_heads=1, loop_function=None,
                      dtype=dtypes.float32, scope=None):
  """RNN decoder with attention for the sequence-to-sequence model.

  Args:
    decoder_inputs: A list of 2D Tensors [batch_size x cell.input_size].
    initial_state: 2D Tensor [batch_size x cell.state_size].
    attention_states: 3D Tensor [batch_size x attn_length x attn_size].
    cell: RNNCell defining the cell function and size.
    output_size: Size of the output vectors; if None, we use cell.output_size.
    num_heads: Number of attention heads that read from attention_states.
    loop_function: If not None, this function will be applied to the i-th
      output in order to generate the (i+1)-th input, and decoder_inputs will
      be ignored, except for the first element ("GO" symbol). This can be used
      for decoding, but also for training to emulate
      http://arxiv.org/pdf/1506.03099v2.pdf.
      Signature -- loop_function(prev, i) = next
        * prev is a 2D Tensor of shape [batch_size x cell.output_size],
        * i is an integer, the step number (when advanced control is needed),
        * next is a 2D Tensor of shape [batch_size x cell.input_size].
    dtype: The dtype to use for the RNN initial state (default: tf.float32).
    scope: VariableScope for the created subgraph; default: "attention_decoder".

  Returns:
    outputs: A list of the same length as decoder_inputs of 2D Tensors of shape
      [batch_size x output_size]. These represent the generated outputs.
      Output i is computed from input i (which is either the i-th element of
      decoder_inputs or loop_function(output {i-1}, i)) as follows.
      First, we run the cell on a combination of the input and previous
      attention masks:
        cell_output, new_state = cell(linear(input, prev_attn), prev_state).
      Then, we calculate new attention masks:
        new_attn = softmax(V^T * tanh(W * attention_states + U * new_state))
      and then we calculate the output:
        output = linear(cell_output, new_attn).
    states: The state of the decoder cell at each time-step. This is a list
      of length len(decoder_inputs) + 1 -- the initial state followed by one
      item per time-step. Each item is a 2D Tensor of shape
      [batch_size x cell.state_size].

  Raises:
    ValueError: when num_heads is not positive, there are no inputs, or shapes
      of attention_states are not set.
  """
  if not decoder_inputs:
    raise ValueError("Must provide at least 1 input to attention decoder.")
  if num_heads < 1:
    raise ValueError("With less than 1 head, use a non-attention decoder.")
  # Both attn_length (dim 1) and attn_size (dim 2) are needed below.
  if not attention_states.get_shape()[1:3].is_fully_defined():
    raise ValueError("Shape[1] and [2] of attention_states must be known: %s"
                     % attention_states.get_shape())
  if output_size is None:
    output_size = cell.output_size

  with vs.variable_scope(scope or "attention_decoder"):
    batch_size = array_ops.shape(decoder_inputs[0])[0]  # Needed for reshaping.
    attn_length = attention_states.get_shape()[1].value
    attn_size = attention_states.get_shape()[2].value

    # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.
    hidden = array_ops.reshape(
        attention_states, [-1, attn_length, 1, attn_size])
    hidden_features = []
    v = []
    attention_vec_size = attn_size  # Size of query vectors for attention.
    for a in xrange(num_heads):
      k = vs.get_variable("AttnW_%d" % a,
                          [1, 1, attn_size, attention_vec_size])
      hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
      v.append(vs.get_variable("AttnV_%d" % a, [attention_vec_size]))

    states = [initial_state]

    def attention(query):
      """Put attention masks on hidden using hidden_features and query."""
      ds = []  # Results of attention reads will be stored here.
      for a in xrange(num_heads):
        with vs.variable_scope("Attention_%d" % a):
          y = linear(query, attention_vec_size, True)
          y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
          # Attention mask is a softmax of v^T * tanh(...).
          s = math_ops.reduce_sum(
              v[a] * math_ops.tanh(hidden_features[a] + y), [2, 3])
          mask = nn_ops.softmax(s)
          # Now calculate the attention-weighted vector d.
          d = math_ops.reduce_sum(
              array_ops.reshape(mask, [-1, attn_length, 1, 1]) * hidden,
              [1, 2])
          ds.append(array_ops.reshape(d, [-1, attn_size]))
      return ds

    outputs = []
    prev = None
    batch_attn_size = array_ops.stack([batch_size, attn_size])
    attns = [array_ops.zeros(batch_attn_size, dtype=dtype)
             for _ in xrange(num_heads)]
    for a in attns:  # Ensure the second shape of attention vectors is set.
      a.set_shape([None, attn_size])

    for i in xrange(len(decoder_inputs)):
      if i > 0:
        vs.get_variable_scope().reuse_variables()
      inp = decoder_inputs[i]
      # If loop_function is set, we use it instead of decoder_inputs.
      if loop_function is not None and prev is not None:
        with vs.variable_scope("loop_function", reuse=True):
          inp = array_ops.stop_gradient(loop_function(prev, i))
      # Merge input and previous attentions into one vector of the right size.
      input_size = inp.get_shape().with_rank(2)[1]
      if input_size.value is None:
        raise ValueError("Could not infer input size from input: %s" % inp.name)
      x = linear([inp] + attns, input_size, True)
      # Run the RNN.
      cell_output, new_state = cell(x, states[-1])
      states.append(new_state)
      query = new_state
      # Flatten the state dimensions in multi-layer LSTMs: concatenate all
      # layers' (c, h) pairs into a single [batch_size, k] query tensor.
      if isinstance(new_state, tuple) and isinstance(new_state[0], tuple):
        query = array_ops.transpose(array_ops.concat(new_state, axis=0),
                                    [1, 0, 2])
        query = array_ops.reshape(
            query, [-1, int(query.get_shape()[1] * query.get_shape()[2])])
      # Run the attention mechanism.
      attns = attention(query)
      with vs.variable_scope("AttnOutputProjection"):
        output = linear([cell_output] + attns, output_size, True)
      if loop_function is not None:
        # We do not propagate gradients over the loop function.
        prev = array_ops.stop_gradient(output)
      outputs.append(output)

  return outputs, states
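# --- Minimal usage sketch (not part of the original module) ---
# This shows, under stated assumptions, how attention_decoder might be wired
# up in TF 1.x graph mode with a single-layer LSTM whose state is a single 2D
# tensor (state_is_tuple=False), matching the docstring's initial_state shape.
# The names `enc_outputs`, `dec_inputs`, and the sizes below are illustrative
# placeholders, not part of the original code.
import tensorflow as tf

batch, enc_steps, dec_steps, hidden = 32, 10, 5, 128

# Encoder outputs stacked to [batch_size x attn_length x attn_size].
enc_outputs = tf.placeholder(tf.float32, [batch, enc_steps, hidden])
# Decoder inputs: a list of dec_steps tensors of shape [batch_size x hidden].
dec_inputs = [tf.placeholder(tf.float32, [batch, hidden])
              for _ in range(dec_steps)]

cell = tf.nn.rnn_cell.BasicLSTMCell(hidden, state_is_tuple=False)
init_state = cell.zero_state(batch, tf.float32)  # 2D [batch_size x 2*hidden]

with tf.variable_scope("decoder"):
  outputs, states = attention_decoder(
      dec_inputs, init_state, enc_outputs, cell,
      output_size=hidden, num_heads=1)
# `outputs` is a list of dec_steps tensors of shape [batch_size x hidden];
# `states` is the initial state followed by the cell state after each step.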