def compute_energy_with_filter(hidden, state, prev_weights, attention_filters, attention_filter_length,
                               **kwargs):
    time_steps = tf.shape(hidden)[1]
    attn_size = hidden.get_shape()[3].value
    batch_size = tf.shape(hidden)[0]

    filter_shape = [attention_filter_length * 2 + 1, 1, 1, attention_filters]
    filter_ = get_variable_unsafe('filter', filter_shape)
    u = get_variable_unsafe('U', [attention_filters, attn_size])

    prev_weights = tf.reshape(prev_weights, tf.stack([batch_size, time_steps, 1, 1]))
    conv = tf.nn.conv2d(prev_weights, filter_, [1, 1, 1, 1], 'SAME')
    shape = tf.stack([tf.multiply(batch_size, time_steps), attention_filters])
    conv = tf.reshape(conv, shape)
    z = tf.matmul(conv, u)
    z = tf.reshape(z, tf.stack([batch_size, time_steps, 1, attn_size]))

    y = linear_unsafe(state, attn_size, True)
    y = tf.reshape(y, [-1, 1, 1, attn_size])

    k = get_variable_unsafe('W', [attn_size, attn_size])

    # dot product between tensors requires reshaping
    hidden = tf.reshape(hidden, tf.stack([tf.multiply(batch_size, time_steps), attn_size]))
    f = tf.matmul(hidden, k)
    f = tf.reshape(f, tf.stack([batch_size, time_steps, 1, attn_size]))

    v = get_variable_unsafe('V', [attn_size])
    s = f + y + z

    return tf.reduce_sum(v * tf.tanh(s), [2, 3])
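# For reference, the energy computed above appears to follow the location-aware scoring of
# Chorowski et al. (2015): convolutional features of the previous attention weights are added
# to the usual content-based terms before the tanh. In LaTeX notation (my reading of the code,
# not an equation taken from the project):
#
#   f_i     = F * \alpha_{i-1}                                  (convolution over source positions)
#   e_{i,j} = v^T \tanh( W h_j + W' s_i + U f_{i,j} )
#
# where h_j are the encoder annotations, s_i the decoder state, and \alpha_{i-1} the previous
# attention weights.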
def compute_energy(hidden, state, attn_size, **kwargs):
    input_size = hidden.get_shape()[3].value
    batch_size = tf.shape(hidden)[0]
    time_steps = tf.shape(hidden)[1]

    # initializer = tf.random_normal_initializer(stddev=0.001)   # same as Bahdanau et al.
    initializer = None

    y = linear_unsafe(state, attn_size, True, scope='W_a', initializer=initializer)
    y = tf.reshape(y, [-1, 1, attn_size])

    k = get_variable_unsafe('U_a', [input_size, attn_size], initializer=initializer)

    # dot product between tensors requires reshaping
    hidden = tf.reshape(hidden, tf.stack([tf.multiply(batch_size, time_steps), input_size]))
    f = tf.matmul(hidden, k)
    f = tf.reshape(f, tf.stack([batch_size, time_steps, attn_size]))

    v = get_variable_unsafe('v_a', [attn_size])
    s = f + y

    return tf.reduce_sum(v * tf.tanh(s), [2])
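# A minimal standalone sketch (plain TF 1.x ops, illustrative shapes) of the step that presumably
# surrounds `compute_energy` inside the project's attention helper: mask the energies of padded
# positions, normalize them with a softmax, and take the weighted sum of the encoder annotations
# as the context vector. The function name and shapes below are assumptions, not project code.
import tensorflow as tf

def attention_weights_and_context(energies, annotations, lengths):
    # energies: (batch, time), annotations: (batch, time, attn_size), lengths: (batch,)
    mask = tf.sequence_mask(lengths, maxlen=tf.shape(energies)[1], dtype=tf.float32)
    energies = energies * mask + (1.0 - mask) * -1e9        # exclude positions beyond the sequence length
    weights = tf.nn.softmax(energies)                       # attention distribution, shape (batch, time)
    context = tf.reduce_sum(tf.expand_dims(weights, 2) * annotations, axis=1)
    return weights, context                                 # context vector: (batch, attn_size)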
def compute_energy(hidden, state, name, **kwargs):
    attn_size = hidden.get_shape()[3].value
    batch_size = tf.shape(hidden)[0]
    time_steps = tf.shape(hidden)[1]

    y = linear_unsafe(state, attn_size, True, scope=name)
    y = tf.reshape(y, [-1, 1, 1, attn_size])

    k = get_variable_unsafe('W_{}'.format(name), [attn_size, attn_size])

    # dot product between tensors requires reshaping
    hidden = tf.reshape(hidden, tf.pack([tf.mul(batch_size, time_steps), attn_size]))
    f = tf.matmul(hidden, k)
    f = tf.reshape(f, tf.pack([batch_size, time_steps, 1, attn_size]))

    v = get_variable_unsafe('V_{}'.format(name), [attn_size])
    s = f + y

    return tf.reduce_sum(v * tf.tanh(s), [2, 3])
def attention_decoder(targets, initial_state, attention_states, encoders, decoder, encoder_input_length,
                      decoder_input_length=None, dropout=None, feed_previous=0.0, feed_argmax=True, **kwargs):
    """
    :param targets: tensor of shape (output_length, batch_size)
    :param initial_state: initial state of the decoder (usually the final state of the encoder), as a tensor
      of shape (batch_size, initial_state_size). This state is mapped to the correct state size for the decoder.
    :param attention_states: list of tensors of shape (batch_size, input_length, encoder_cell_size),
      usually the encoder outputs (one tensor for each encoder).
    :param encoders: configuration of the encoders
    :param decoder: configuration of the decoder
    :param decoder_input_length:
    :param dropout: scalar tensor or None, specifying the keep probability (1 - dropout)
    :param feed_previous: scalar tensor corresponding to the probability of using the previous decoder output
      instead of the ground truth as input for the decoder (1 when decoding, between 0 and 1 when training)
    :return:
      outputs of the decoder as a tensor of shape (batch_size, output_length, decoder_cell_size)
      attention weights as a tensor of shape (output_length, encoders, batch_size, input_length)
    """
    # TODO: dropout instead of keep probability
    assert decoder.cell_size % 2 == 0, 'cell size must be a multiple of 2'   # because of maxout

    decoder_inputs = targets[:-1, :]   # starts with BOS

    if decoder.get('embedding') is not None:
        initializer = decoder.embedding
        embedding_shape = None
    else:
        initializer = None
        embedding_shape = [decoder.vocab_size, decoder.embedding_size]

    with tf.device('/cpu:0'):
        embedding = get_variable_unsafe('embedding_{}'.format(decoder.name), shape=embedding_shape,
                                        initializer=initializer)

    if decoder.use_lstm:
        cell = BasicLSTMCell(decoder.cell_size, state_is_tuple=False)
    else:
        cell = GRUCell(decoder.cell_size, initializer=orthogonal_initializer())

    if dropout is not None:
        cell = DropoutWrapper(cell, input_keep_prob=dropout)

    if decoder.layers > 1:
        cell = MultiRNNCell([cell] * decoder.layers, residual_connections=decoder.residual_connections)

    with tf.variable_scope('decoder_{}'.format(decoder.name)):
        def embed(input_):
            if embedding is not None:
                return tf.nn.embedding_lookup(embedding, input_)
            else:
                return input_

        hidden_states = [tf.expand_dims(states, 2) for states in attention_states]
        attention_ = functools.partial(multi_attention, hidden_states=hidden_states, encoders=encoders,
                                       encoder_input_length=encoder_input_length)

        input_shape = tf.shape(decoder_inputs)
        time_steps = input_shape[0]
        batch_size = input_shape[1]

        output_size = decoder.vocab_size
        state_size = cell.state_size

        if initial_state is not None:
            if dropout is not None:
                initial_state = tf.nn.dropout(initial_state, dropout)
            state = tf.nn.tanh(
                linear_unsafe(initial_state, state_size, True, scope='initial_state_projection')
            )
        else:
            # if no initial state is given, initialize with zeroes (this is the case for MIXER)
            state = tf.zeros([batch_size, state_size], dtype=tf.float32)

        sequence_length = decoder_input_length
        if sequence_length is not None:
            sequence_length = tf.to_int32(sequence_length)
            min_sequence_length = tf.reduce_min(sequence_length)
            max_sequence_length = tf.reduce_max(sequence_length)

        time = tf.constant(0, dtype=tf.int32, name='time')
        zero_output = tf.zeros(tf.stack([batch_size, cell.output_size]), tf.float32)

        proj_outputs = tf.TensorArray(dtype=tf.float32, size=time_steps, clear_after_read=False)
        decoder_outputs = tf.TensorArray(dtype=tf.float32, size=time_steps)
        inputs = tf.TensorArray(dtype=tf.int64, size=time_steps, clear_after_read=False).unstack(
            tf.cast(decoder_inputs, tf.int64))
        samples = tf.TensorArray(dtype=tf.int64, size=time_steps, clear_after_read=False)
        states = tf.TensorArray(dtype=tf.float32, size=time_steps)

        attn_lengths = [tf.shape(states)[1] for states in attention_states]
        weights = tf.TensorArray(dtype=tf.float32, size=time_steps)
        initial_weights = [tf.zeros(tf.stack([batch_size, length])) for length in attn_lengths]

        output = tf.zeros(tf.stack([batch_size, cell.output_size]), dtype=tf.float32)
        initial_input = embed(inputs.read(0))   # first symbol is BOS

        def _time_step(time, input_, state, output, proj_outputs, decoder_outputs, samples, states, weights,
                       prev_weights):
            context_vector, new_weights = attention_(state, prev_weights=prev_weights)
            weights = weights.write(time, new_weights)

            # FIXME use `output` or `state` here?
            output_ = linear_unsafe([state, input_, context_vector], decoder.cell_size, False, scope='maxout')
            output_ = tf.reduce_max(tf.reshape(output_, tf.stack([batch_size, decoder.cell_size // 2, 2])),
                                    axis=2)
            output_ = linear_unsafe(output_, decoder.embedding_size, False, scope='softmax0')
            decoder_outputs = decoder_outputs.write(time, output_)

            output_ = linear_unsafe(output_, output_size, True, scope='softmax1')
            proj_outputs = proj_outputs.write(time, output_)

            argmax = lambda: tf.argmax(output_, 1)
            softmax = lambda: tf.squeeze(tf.multinomial(tf.log(tf.nn.softmax(output_)), num_samples=1),
                                         axis=1)
            target = lambda: inputs.read(time + 1)

            sample = tf.case([
                (tf.logical_and(time < time_steps - 1, tf.random_uniform([]) >= feed_previous), target),
                (tf.logical_not(feed_argmax), softmax)],
                default=argmax)   # default case is useful for beam-search

            sample.set_shape([None])
            sample = tf.stop_gradient(sample)
            samples = samples.write(time, sample)

            input_ = embed(sample)
            x = tf.concat([input_, context_vector], 1)
            call_cell = lambda: unsafe_decorator(cell)(x, state)

            if sequence_length is not None:
                new_output, new_state = rnn._rnn_step(
                    time=time,
                    sequence_length=sequence_length,
                    min_sequence_length=min_sequence_length,
                    max_sequence_length=max_sequence_length,
                    zero_output=zero_output,
                    state=state,
                    call_cell=call_cell,
                    state_size=state_size,
                    skip_conditionals=True)
            else:
                new_output, new_state = call_cell()

            states = states.write(time, new_state)

            return (time + 1, input_, new_state, new_output, proj_outputs, decoder_outputs, samples, states,
                    weights, new_weights)

        # loop variables are given in the same order as the parameters of `_time_step`
        _, _, new_state, new_output, proj_outputs, decoder_outputs, samples, states, weights, _ = tf.while_loop(
            cond=lambda time, *_: time < time_steps,
            body=_time_step,
            loop_vars=(time, initial_input, state, output, proj_outputs, decoder_outputs, samples, states,
                       weights, initial_weights),
            parallel_iterations=decoder.parallel_iterations,
            swap_memory=decoder.swap_memory)

        proj_outputs = proj_outputs.stack()
        decoder_outputs = decoder_outputs.stack()
        samples = samples.stack()
        weights = weights.stack()   # batch_size, encoders, output time, input time
        states = states.stack()

        # weights = tf.Print(weights, [weights[:,0]], summarize=20)
        # tf.control_dependencies()

        beam_tensors = namedtuple('beam_tensors', 'state new_state output new_output')
        return (proj_outputs, weights, decoder_outputs, beam_tensors(state, new_state, output, new_output),
                samples, states)
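# A tiny standalone sketch of the maxout readout used in `_time_step` above: a linear layer of
# size cell_size followed by max-pooling over pairs of units, which halves the dimension (this is
# why `decoder.cell_size` must be even). Plain TF 1.x ops and illustrative shapes; `features`
# stands for the already-concatenated [state, input, context] vector.
import tensorflow as tf

def maxout_readout(features, cell_size):
    # features: (batch, input_dim) -> returns (batch, cell_size // 2)
    w = tf.get_variable('maxout_w', [features.get_shape()[1].value, cell_size])
    pre = tf.matmul(features, w)                    # (batch, cell_size)
    pre = tf.reshape(pre, [-1, cell_size // 2, 2])  # pair up adjacent units
    return tf.reduce_max(pre, axis=2)               # keep the larger unit of each pair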
def multi_encoder(encoder_inputs, encoders, encoder_input_length, dropout=None, **kwargs):
    """
    Build multiple encoders according to the configuration in `encoders`, reading from `encoder_inputs`.
    The result is a list of the outputs produced by those encoders (for each time-step), and their final state.

    :param encoder_inputs: list of tensors of shape (batch_size, input_length) (one tensor for each encoder)
    :param encoders: list of encoder configurations
    :param encoder_input_length: list of tensors of shape (batch_size) (one tensor for each encoder)
    :param dropout: scalar tensor or None, specifying the keep probability (1 - dropout)
    :return:
      encoder outputs: a list of tensors of shape (batch_size, input_length, encoder_cell_size)
      encoder state: concatenation of the final states of all encoders, tensor of shape
        (batch_size, sum_of_state_sizes)
    """
    assert len(encoder_inputs) == len(encoders)
    encoder_states = []
    encoder_outputs = []

    # create embeddings in the global scope (allows sharing between encoder and decoder)
    embedding_variables = []
    for encoder in encoders:
        # inputs are token ids, which need to be mapped to vectors (embeddings)
        if not encoder.binary:
            if encoder.get('embedding') is not None:
                initializer = encoder.embedding
                embedding_shape = None
            else:
                # initializer = tf.random_uniform_initializer(-math.sqrt(3), math.sqrt(3))
                initializer = None
                embedding_shape = [encoder.vocab_size, encoder.embedding_size]

            with tf.device('/cpu:0'):
                embedding = get_variable_unsafe('embedding_{}'.format(encoder.name), shape=embedding_shape,
                                                initializer=initializer)
            embedding_variables.append(embedding)
        else:   # do nothing: inputs are already vectors
            embedding_variables.append(None)

    for i, encoder in enumerate(encoders):
        with tf.variable_scope('encoder_{}'.format(encoder.name)):
            encoder_inputs_ = encoder_inputs[i]
            encoder_input_length_ = encoder_input_length[i]

            # TODO: use state_is_tuple=True
            if encoder.use_lstm:
                cell = BasicLSTMCell(encoder.cell_size, state_is_tuple=False)
            else:
                cell = GRUCell(encoder.cell_size, initializer=orthogonal_initializer())

            if dropout is not None:
                cell = DropoutWrapper(cell, input_keep_prob=dropout)

            embedding = embedding_variables[i]

            if embedding is not None or encoder.input_layers:
                batch_size = tf.shape(encoder_inputs_)[0]   # TODO: fix this time major stuff
                time_steps = tf.shape(encoder_inputs_)[1]

                if embedding is None:
                    size = encoder_inputs_.get_shape()[2].value
                    flat_inputs = tf.reshape(encoder_inputs_, [tf.multiply(batch_size, time_steps), size])
                else:
                    flat_inputs = tf.reshape(encoder_inputs_, [tf.multiply(batch_size, time_steps)])
                    flat_inputs = tf.nn.embedding_lookup(embedding, flat_inputs)

                if encoder.input_layers:
                    for j, size in enumerate(encoder.input_layers):
                        name = 'input_layer_{}'.format(j)
                        flat_inputs = tf.nn.tanh(linear_unsafe(flat_inputs, size, bias=True, scope=name))
                        if dropout is not None:
                            flat_inputs = tf.nn.dropout(flat_inputs, dropout)

                encoder_inputs_ = tf.reshape(
                    flat_inputs, tf.stack([batch_size, time_steps, flat_inputs.get_shape()[1].value]))

            # Contrary to Theano's RNN implementation, states after the sequence length are zero
            # (while Theano repeats last state)
            sequence_length = encoder_input_length_   # TODO

            parameters = dict(
                inputs=encoder_inputs_, sequence_length=sequence_length,
                time_pooling=encoder.time_pooling, pooling_avg=encoder.pooling_avg,
                dtype=tf.float32, swap_memory=encoder.swap_memory,
                parallel_iterations=encoder.parallel_iterations,
                residual_connections=encoder.residual_connections,
                trainable_initial_state=True
            )

            if encoder.bidir:
                encoder_outputs_, _, _ = multi_bidirectional_rnn_unsafe(
                    cells=[(cell, cell)] * encoder.layers, **parameters)
                # Like Bahdanau et al., we use the first annotation h_1 of the backward encoder
                encoder_state_ = encoder_outputs_[:, 0, encoder.cell_size:]
                # TODO: if multiple layers, combine last states with a Maxout layer
            else:
                encoder_outputs_, encoder_state_ = multi_rnn_unsafe(
                    cells=[cell] * encoder.layers, **parameters)
                encoder_state_ = encoder_outputs_[:, -1, :]

            encoder_outputs.append(encoder_outputs_)
            encoder_states.append(encoder_state_)

    encoder_state = tf.concat(encoder_states, 1)
    return encoder_outputs, encoder_state
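# A sketch of how the two builders above are presumably wired together into a full model
# (hypothetical glue code, not taken from the project): the encoder outputs become the decoder's
# attention states, and the concatenated final encoder state becomes its initial state. The input
# tensors and the `encoders`/`decoder` configuration objects are assumed to be built elsewhere.
def build_seq2seq_graph(encoder_inputs, targets, encoder_input_length, decoder_input_length,
                        encoders, decoder, dropout=None, feed_previous=0.0):
    encoder_outputs, encoder_state = multi_encoder(
        encoder_inputs, encoders, encoder_input_length, dropout=dropout)

    proj_outputs, attn_weights, decoder_outputs, beam_tensors, samples, states = attention_decoder(
        targets=targets,                      # (output_length, batch_size), starts with BOS
        initial_state=encoder_state,          # concatenated final encoder states
        attention_states=encoder_outputs,     # one tensor per encoder
        encoders=encoders, decoder=decoder,
        encoder_input_length=encoder_input_length,
        decoder_input_length=decoder_input_length,
        dropout=dropout,
        feed_previous=feed_previous)          # 1.0 when decoding, between 0 and 1 when training

    return proj_outputs, attn_weights, samples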
def multi_encoder(encoder_inputs, encoders, encoder_input_length, dropout=None, **kwargs):
    """
    Build multiple encoders according to the configuration in `encoders`, reading from `encoder_inputs`.
    The result is a list of the outputs produced by those encoders (for each time-step), and their final state.

    :param encoder_inputs: list of tensors of shape (batch_size, input_length) (one tensor for each encoder)
    :param encoders: list of encoder configurations
    :param encoder_input_length: list of tensors of shape (batch_size) (one tensor for each encoder)
    :param dropout: scalar tensor or None, specifying the keep probability (1 - dropout)
    :return:
      encoder outputs: a list of tensors of shape (batch_size, input_length, encoder_cell_size)
      encoder state: concatenation of the final states of all encoders, tensor of shape
        (batch_size, sum_of_state_sizes)
    """
    assert len(encoder_inputs) == len(encoders)
    encoder_states = []
    encoder_outputs = []

    # create embeddings in the global scope (allows sharing between encoder and decoder)
    embedding_variables = []
    for encoder in encoders:
        # inputs are token ids, which need to be mapped to vectors (embeddings)
        if not encoder.binary:
            if encoder.get('embedding') is not None:
                initializer = encoder.embedding
                embedding_shape = None
            else:
                initializer = tf.random_uniform_initializer(-math.sqrt(3), math.sqrt(3))
                embedding_shape = [encoder.vocab_size, encoder.embedding_size]

            with tf.device('/cpu:0'):
                embedding = get_variable_unsafe('embedding_{}'.format(encoder.name), shape=embedding_shape,
                                                initializer=initializer)
            embedding_variables.append(embedding)
        else:   # do nothing: inputs are already vectors
            embedding_variables.append(None)

    with tf.variable_scope('multi_encoder'):
        for i, encoder in enumerate(encoders):
            with tf.variable_scope(encoder.name):
                encoder_inputs_ = encoder_inputs[i]
                encoder_input_length_ = encoder_input_length[i]

                # TODO: use state_is_tuple=True
                if encoder.use_lstm:
                    cell = rnn_cell.BasicLSTMCell(encoder.cell_size, state_is_tuple=False)
                else:
                    cell = rnn_cell.GRUCell(encoder.cell_size)

                if dropout is not None:
                    cell = rnn_cell.DropoutWrapper(cell, input_keep_prob=dropout)

                embedding = embedding_variables[i]

                if embedding is not None or encoder.input_layers:
                    batch_size = tf.shape(encoder_inputs_)[0]   # TODO: fix this time major stuff
                    time_steps = tf.shape(encoder_inputs_)[1]

                    if embedding is None:
                        size = encoder_inputs_.get_shape()[2].value
                        flat_inputs = tf.reshape(encoder_inputs_, [tf.mul(batch_size, time_steps), size])
                    else:
                        flat_inputs = tf.reshape(encoder_inputs_, [tf.mul(batch_size, time_steps)])
                        flat_inputs = tf.nn.embedding_lookup(embedding, flat_inputs)

                    if encoder.input_layers:
                        for j, size in enumerate(encoder.input_layers):
                            name = 'input_layer_{}'.format(j)
                            flat_inputs = tf.nn.tanh(linear_unsafe(flat_inputs, size, bias=True, scope=name))
                            if dropout is not None:
                                flat_inputs = tf.nn.dropout(flat_inputs, dropout)

                    encoder_inputs_ = tf.reshape(
                        flat_inputs, tf.pack([batch_size, time_steps, flat_inputs.get_shape()[1].value]))

                sequence_length = encoder_input_length_

                parameters = dict(
                    inputs=encoder_inputs_, sequence_length=sequence_length,
                    time_pooling=encoder.time_pooling, pooling_avg=encoder.pooling_avg,
                    dtype=tf.float32, swap_memory=encoder.swap_memory,
                    parallel_iterations=encoder.parallel_iterations,
                    residual_connections=encoder.residual_connections
                )

                if encoder.bidir:
                    encoder_outputs_, _, encoder_state_ = multi_bidirectional_rnn_unsafe(
                        cells=[(cell, cell)] * encoder.layers, **parameters)
                else:
                    encoder_outputs_, encoder_state_ = multi_rnn_unsafe(
                        cells=[cell] * encoder.layers, **parameters)

                if encoder.bidir:   # map to correct output dimension
                    # there is no tensor product operation, so we need to flatten our tensor to
                    # a matrix to perform a dot product
                    shape = tf.shape(encoder_outputs_)
                    batch_size = shape[0]
                    time_steps = shape[1]
                    dim = encoder_outputs_.get_shape()[2]
                    outputs_ = tf.reshape(encoder_outputs_, tf.pack([tf.mul(batch_size, time_steps), dim]))
                    outputs_ = linear_unsafe(outputs_, cell.output_size, False, scope='bidir_projection')
                    encoder_outputs_ = tf.reshape(outputs_, tf.pack([batch_size, time_steps, cell.output_size]))

                encoder_outputs.append(encoder_outputs_)
                encoder_states.append(encoder_state_)

    encoder_state = tf.concat(1, encoder_states)
    return encoder_outputs, encoder_state
def beam_search_decoder(decoder_input, initial_state, attention_states, encoders, decoder,
                        output_projection=None, dropout=None, **kwargs):
    """
    Same as `attention_decoder`, except that it only performs one step of the decoder.

    :param decoder_input: tensor of size (batch_size), corresponding to the previous output of the decoder
    :return:
      current output of the decoder
      tuple of (state, new_state, attn_weights, new_attn_weights, attns, new_attns)
    """
    # TODO: code refactoring with `attention_decoder`
    if decoder.get('embedding') is not None:
        embedding_initializer = decoder.embedding
        embedding_shape = None
    else:
        embedding_initializer = None
        embedding_shape = [decoder.vocab_size, decoder.embedding_size]

    with tf.device('/cpu:0'):
        embedding = get_variable_unsafe('embedding_{}'.format(decoder.name), shape=embedding_shape,
                                        initializer=embedding_initializer)

    if decoder.use_lstm:
        cell = rnn_cell.BasicLSTMCell(decoder.cell_size, state_is_tuple=False)
    else:
        cell = rnn_cell.GRUCell(decoder.cell_size)

    if dropout is not None:
        cell = rnn_cell.DropoutWrapper(cell, input_keep_prob=dropout)

    if decoder.layers > 1:
        cell = MultiRNNCell([cell] * decoder.layers, residual_connections=decoder.residual_connections)

    if output_projection is None:
        output_size = decoder.vocab_size
    else:
        output_size = cell.output_size
        proj_weights = tf.convert_to_tensor(output_projection[0], dtype=tf.float32)
        proj_weights.get_shape().assert_is_compatible_with([cell.output_size, decoder.vocab_size])
        proj_biases = tf.convert_to_tensor(output_projection[1], dtype=tf.float32)
        proj_biases.get_shape().assert_is_compatible_with([decoder.vocab_size])

    with tf.variable_scope('decoder_{}'.format(decoder.name)):
        decoder_input = tf.nn.embedding_lookup(embedding, decoder_input)

        attn_lengths = [tf.shape(states)[1] for states in attention_states]
        attn_size = sum(states.get_shape()[2].value for states in attention_states)
        hidden_states = [tf.expand_dims(states, 2) for states in attention_states]
        attention_ = functools.partial(multi_attention, hidden_states=hidden_states, encoders=encoders)

        if dropout is not None:
            initial_state = tf.nn.dropout(initial_state, dropout)

        state = tf.nn.tanh(
            linear_unsafe(initial_state, cell.state_size, False, scope='initial_state_projection')
        )

        batch_size = tf.shape(decoder_input)[0]
        attn_weights = [tf.zeros(tf.pack([batch_size, length])) for length in attn_lengths]
        attns = tf.zeros(tf.pack([batch_size, attn_size]), dtype=tf.float32)

        cell_output, new_state = unsafe_decorator(cell)(decoder_input, state)
        new_attns, new_attn_weights = attention_(new_state, prev_weights=attn_weights)

        with tf.variable_scope('attention_output_projection'):
            output = linear_unsafe([cell_output, new_attns], output_size, True)

    beam_tensors = namedtuple('beam_tensors', 'state new_state attn_weights new_attn_weights attns new_attns')
    return output, beam_tensors(state, new_state, attn_weights, new_attn_weights, attns, new_attns)
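# A rough sketch of how the single-step decoder above could be driven by an outer beam-search
# loop (hypothetical driver code, not part of the project): run one step, turn the logits into
# log-probabilities, and keep the k best continuations. The surrounding code is assumed to feed
# the chosen tokens and states back into the graph between steps via `feed_dict`.
import numpy as np

def beam_step(session, output_op, log_probs_so_far, k=5, feed_dict=None):
    # output_op: the `output` tensor returned by `beam_search_decoder`
    # log_probs_so_far: numpy array of shape (beam_size,) with the current hypothesis scores
    logits = session.run(output_op, feed_dict=feed_dict)                      # (beam_size, vocab_size)
    log_probs = logits - np.logaddexp.reduce(logits, axis=1, keepdims=True)   # log-softmax
    scores = log_probs_so_far[:, None] + log_probs                            # (beam_size, vocab_size)
    flat = scores.reshape(-1)
    best = np.argpartition(-flat, k)[:k]                                      # k best continuations (unordered)
    beam_ids, token_ids = best // logits.shape[1], best % logits.shape[1]
    return beam_ids, token_ids, flat[best]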
def attention_decoder(decoder_inputs, initial_state, attention_states, encoders, decoder,
                      decoder_input_length=None, output_projection=None, dropout=None, feed_previous=0.0,
                      **kwargs):
    """
    :param decoder_inputs: tensor of shape (batch_size, output_length)
    :param initial_state: initial state of the decoder (usually the final state of the encoder), as a tensor
      of shape (batch_size, initial_state_size). This state is mapped to the correct state size for the decoder.
    :param attention_states: list of tensors of shape (batch_size, input_length, encoder_cell_size),
      usually the encoder outputs (one tensor for each encoder).
    :param encoders: configuration of the encoders
    :param decoder: configuration of the decoder
    :param decoder_input_length:
    :param output_projection: None if no softmax sampling, or tuple (weight matrix, bias vector)
    :param dropout: scalar tensor or None, specifying the keep probability (1 - dropout)
    :param feed_previous: scalar tensor corresponding to the probability of using the previous decoder output
      instead of the ground truth as input for the decoder (1 when decoding, between 0 and 1 when training)
    :return:
      outputs of the decoder as a tensor of shape (batch_size, output_length, decoder_cell_size)
      attention weights as a tensor of shape (output_length, encoders, batch_size, input_length)
    """
    # TODO: dropout instead of keep probability
    if decoder.get('embedding') is not None:
        embedding_initializer = decoder.embedding
        embedding_shape = None
    else:
        embedding_initializer = None
        embedding_shape = [decoder.vocab_size, decoder.embedding_size]

    with tf.device('/cpu:0'):
        embedding = get_variable_unsafe('embedding_{}'.format(decoder.name), shape=embedding_shape,
                                        initializer=embedding_initializer)

    if decoder.use_lstm:
        cell = rnn_cell.BasicLSTMCell(decoder.cell_size, state_is_tuple=False)
    else:
        cell = rnn_cell.GRUCell(decoder.cell_size)

    if dropout is not None:
        cell = rnn_cell.DropoutWrapper(cell, input_keep_prob=dropout)

    if decoder.layers > 1:
        cell = MultiRNNCell([cell] * decoder.layers, residual_connections=decoder.residual_connections)

    if output_projection is None:
        output_size = decoder.vocab_size
    else:
        output_size = cell.output_size
        proj_weights = tf.convert_to_tensor(output_projection[0], dtype=tf.float32)
        proj_weights.get_shape().assert_is_compatible_with([cell.output_size, decoder.vocab_size])
        proj_biases = tf.convert_to_tensor(output_projection[1], dtype=tf.float32)
        proj_biases.get_shape().assert_is_compatible_with([decoder.vocab_size])

    with tf.variable_scope('decoder_{}'.format(decoder.name)):
        def extract_argmax_and_embed(prev):
            if output_projection is not None:
                prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])
            prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
            emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol)
            return emb_prev

        if embedding is not None:
            time_steps = tf.shape(decoder_inputs)[0]
            batch_size = tf.shape(decoder_inputs)[1]
            flat_inputs = tf.reshape(decoder_inputs, [tf.mul(batch_size, time_steps)])
            flat_inputs = tf.nn.embedding_lookup(embedding, flat_inputs)
            decoder_inputs = tf.reshape(
                flat_inputs, tf.pack([time_steps, batch_size, flat_inputs.get_shape()[1].value]))

        attn_lengths = [tf.shape(states)[1] for states in attention_states]
        attn_size = sum(states.get_shape()[2].value for states in attention_states)
        hidden_states = [tf.expand_dims(states, 2) for states in attention_states]
        attention_ = functools.partial(multi_attention, hidden_states=hidden_states, encoders=encoders)

        if dropout is not None:
            initial_state = tf.nn.dropout(initial_state, dropout)

        state = tf.nn.tanh(
            linear_unsafe(initial_state, cell.state_size, False, scope='initial_state_projection')
        )

        sequence_length = decoder_input_length
        if sequence_length is not None:
            sequence_length = tf.to_int32(sequence_length)
            min_sequence_length = tf.reduce_min(sequence_length)
            max_sequence_length = tf.reduce_max(sequence_length)

        time = tf.constant(0, dtype=tf.int32, name='time')

        input_shape = tf.shape(decoder_inputs)
        time_steps = input_shape[0]
        batch_size = input_shape[1]
        state_size = cell.state_size

        zero_output = tf.zeros(tf.pack([batch_size, cell.output_size]), tf.float32)

        output_ta = tf.TensorArray(dtype=tf.float32, size=time_steps, clear_after_read=False)
        input_ta = tf.TensorArray(dtype=tf.float32, size=time_steps).unpack(decoder_inputs)
        attn_weights_ta = tf.TensorArray(dtype=tf.float32, size=time_steps)

        attention_weights = [tf.zeros(tf.pack([batch_size, length])) for length in attn_lengths]
        attns = tf.zeros(tf.pack([batch_size, attn_size]), dtype=tf.float32)
        attns.set_shape([None, attn_size])

        def _time_step(time, state, _, attn_weights, output_ta_t, attn_weights_ta_t):
            input_t = input_ta.read(time)

            r = tf.random_uniform([])
            input_t = tf.cond(tf.logical_and(time > 0, r < feed_previous),
                              lambda: tf.stop_gradient(extract_argmax_and_embed(output_ta_t.read(time - 1))),
                              lambda: input_t)
            # restore some shape information
            input_t.set_shape(decoder_inputs.get_shape()[1:])

            # the code from TensorFlow used a concatenation of input_t and attns as input here
            # TODO: evaluate the impact of this
            call_cell = lambda: unsafe_decorator(cell)(input_t, state)

            if sequence_length is not None:
                output, new_state = rnn._rnn_step(
                    time=time,
                    sequence_length=sequence_length,
                    min_sequence_length=min_sequence_length,
                    max_sequence_length=max_sequence_length,
                    zero_output=zero_output,
                    state=state,
                    call_cell=call_cell,
                    state_size=state_size,
                    skip_conditionals=True)
            else:
                output, new_state = call_cell()

            attn_weights_ta_t = attn_weights_ta_t.write(time, attn_weights)

            # using the decoder state instead of the decoder output in the attention model seems
            # to give much better results
            new_attns, new_attn_weights = attention_(new_state, prev_weights=attn_weights)

            with tf.variable_scope('attention_output_projection'):   # this can take a lot of memory
                output = linear_unsafe([output, new_attns], output_size, True)

            output_ta_t = output_ta_t.write(time, output)
            return time + 1, new_state, new_attns, new_attn_weights, output_ta_t, attn_weights_ta_t

        _, _, _, _, output_final_ta, attn_weights_final = tf.while_loop(
            cond=lambda time, *_: time < time_steps,
            body=_time_step,
            loop_vars=(time, state, attns, attention_weights, output_ta, attn_weights_ta),
            parallel_iterations=decoder.parallel_iterations,
            swap_memory=decoder.swap_memory)

        outputs = output_final_ta.pack()

        # shape (time_steps, encoders, batch_size, input_time_steps)
        attention_weights = tf.slice(attn_weights_final.pack(), [1, 0, 0, 0], [-1, -1, -1, -1])

        return outputs, attention_weights
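# A tiny standalone sketch of the scheduled-sampling choice made in `_time_step` above: with
# probability `feed_previous`, the decoder consumes its own previous prediction (embedded argmax
# of the previous logits); otherwise it consumes the ground-truth input. Plain TF 1.x ops,
# hypothetical function name and shapes.
import tensorflow as tf

def choose_decoder_input(ground_truth_t, previous_logits, embedding, feed_previous):
    # ground_truth_t: (batch, embed_size), previous_logits: (batch, vocab_size)
    prev_symbol = tf.stop_gradient(tf.argmax(previous_logits, 1))
    prev_embedded = tf.nn.embedding_lookup(embedding, prev_symbol)
    use_previous = tf.random_uniform([]) < feed_previous
    return tf.cond(use_previous, lambda: prev_embedded, lambda: ground_truth_t)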