def get_cell(input_size=None, reuse=False):
    cells = []

    for j in range(decoder.layers):
        input_size_ = input_size if j == 0 else decoder.cell_size

        if decoder.cell_type.lower() == 'lstm':
            cell = CellWrapper(BasicLSTMCell(decoder.cell_size, reuse=reuse))
        elif decoder.cell_type.lower() == 'dropoutgru':
            cell = DropoutGRUCell(decoder.cell_size, reuse=reuse, layer_norm=decoder.layer_norm,
                                  input_size=input_size_, input_keep_prob=decoder.rnn_input_keep_prob,
                                  state_keep_prob=decoder.rnn_state_keep_prob)
        else:
            cell = GRUCell(decoder.cell_size, reuse=reuse, layer_norm=decoder.layer_norm)

        if decoder.use_dropout and decoder.cell_type.lower() != 'dropoutgru':
            cell = DropoutWrapper(cell, input_keep_prob=decoder.rnn_input_keep_prob,
                                  output_keep_prob=decoder.rnn_output_keep_prob,
                                  state_keep_prob=decoder.rnn_state_keep_prob,
                                  variational_recurrent=decoder.pervasive_dropout,
                                  dtype=tf.float32, input_size=input_size_)
        cells.append(cell)

    if len(cells) == 1:
        return cells[0]
    else:
        return CellWrapper(MultiRNNCell(cells))
def cell():
    if encoder.use_lstm:
        cell = BasicLSTMCell(encoder.cell_size, state_is_tuple=False)
    else:
        cell = GRUCell(encoder.cell_size, initializer=orthogonal_initializer())

    if dropout is not None:
        cell = DropoutWrapper(cell, input_keep_prob=dropout)
    return cell
def get_cell(input_size=None, reuse=False):
    if encoder.cell_type.lower() == 'lstm':
        cell = CellWrapper(BasicLSTMCell(encoder.cell_size, reuse=reuse))
    elif encoder.cell_type.lower() == 'dropoutgru':
        cell = DropoutGRUCell(encoder.cell_size, reuse=reuse, layer_norm=encoder.layer_norm,
                              input_size=input_size, input_keep_prob=encoder.rnn_input_keep_prob,
                              state_keep_prob=encoder.rnn_state_keep_prob)
    else:
        cell = GRUCell(encoder.cell_size, reuse=reuse, layer_norm=encoder.layer_norm)

    if encoder.use_dropout and encoder.cell_type.lower() != 'dropoutgru':
        cell = DropoutWrapper(cell, input_keep_prob=encoder.rnn_input_keep_prob,
                              output_keep_prob=encoder.rnn_output_keep_prob,
                              state_keep_prob=encoder.rnn_state_keep_prob,
                              variational_recurrent=encoder.pervasive_dropout,
                              dtype=tf.float32, input_size=input_size)
    return cell
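# Hypothetical usage sketch (not part of the original code): the configuration fields that the
# cell factories above read, and an example call site. The AttrDict-style config object and the
# concrete values below are illustrative assumptions only.
#
#   encoder = AttrDict(cell_type='lstm', cell_size=512, embedding_size=512, layer_norm=False,
#                      use_dropout=True, rnn_input_keep_prob=0.8, rnn_output_keep_prob=0.8,
#                      rnn_state_keep_prob=1.0, pervasive_dropout=False)
#   cell = get_cell(input_size=encoder.embedding_size, reuse=False)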
def attention_decoder(targets, initial_state, attention_states, encoders, decoder, encoder_input_length,
                      decoder_input_length=None, dropout=None, feed_previous=0.0, feed_argmax=True, **kwargs):
    """
    :param targets: tensor of shape (output_length, batch_size)
    :param initial_state: initial state of the decoder (usually the final state of the encoder), as a tensor
      of shape (batch_size, initial_state_size). This state is mapped to the correct state size for the decoder.
    :param attention_states: list of tensors of shape (batch_size, input_length, encoder_cell_size),
      usually the encoder outputs (one tensor for each encoder).
    :param encoders: configuration of the encoders
    :param decoder: configuration of the decoder
    :param encoder_input_length: list of tensors of shape (batch_size) (one tensor for each encoder)
    :param decoder_input_length: tensor of shape (batch_size) containing the length of each target sequence
    :param dropout: scalar tensor or None, specifying the keep probability (1 - dropout)
    :param feed_previous: scalar tensor corresponding to the probability to use the previous decoder output
      instead of the ground truth as input for the decoder (1 when decoding, between 0 and 1 when training)
    :param feed_argmax: if True, feed the argmax of the previous output, otherwise sample from its softmax
    :return:
      outputs of the decoder as a tensor of shape (output_length, batch_size, vocab_size),
      attention weights as a tensor of shape (output_length, encoders, batch_size, input_length)
    """
    # TODO: dropout instead of keep probability
    assert decoder.cell_size % 2 == 0, 'cell size must be a multiple of 2'  # because of maxout

    decoder_inputs = targets[:-1, :]  # starts with BOS

    if decoder.get('embedding') is not None:
        initializer = decoder.embedding
        embedding_shape = None
    else:
        initializer = None
        embedding_shape = [decoder.vocab_size, decoder.embedding_size]

    with tf.device('/cpu:0'):
        embedding = get_variable_unsafe('embedding_{}'.format(decoder.name), shape=embedding_shape,
                                        initializer=initializer)

    if decoder.use_lstm:
        cell = BasicLSTMCell(decoder.cell_size, state_is_tuple=False)
    else:
        cell = GRUCell(decoder.cell_size, initializer=orthogonal_initializer())

    if dropout is not None:
        cell = DropoutWrapper(cell, input_keep_prob=dropout)

    if decoder.layers > 1:
        cell = MultiRNNCell([cell] * decoder.layers, residual_connections=decoder.residual_connections)

    with tf.variable_scope('decoder_{}'.format(decoder.name)):
        def embed(input_):
            if embedding is not None:
                return tf.nn.embedding_lookup(embedding, input_)
            else:
                return input_

        hidden_states = [tf.expand_dims(states, 2) for states in attention_states]
        attention_ = functools.partial(multi_attention, hidden_states=hidden_states, encoders=encoders,
                                       encoder_input_length=encoder_input_length)

        input_shape = tf.shape(decoder_inputs)
        time_steps = input_shape[0]
        batch_size = input_shape[1]

        output_size = decoder.vocab_size
        state_size = cell.state_size

        if initial_state is not None:
            if dropout is not None:
                initial_state = tf.nn.dropout(initial_state, dropout)
            state = tf.nn.tanh(
                linear_unsafe(initial_state, state_size, True, scope='initial_state_projection')
            )
        else:
            # if there is no initial state, initialize with zeros (this is the case for MIXER)
            state = tf.zeros([batch_size, state_size], dtype=tf.float32)

        sequence_length = decoder_input_length
        if sequence_length is not None:
            sequence_length = tf.to_int32(sequence_length)
            min_sequence_length = tf.reduce_min(sequence_length)
            max_sequence_length = tf.reduce_max(sequence_length)

        time = tf.constant(0, dtype=tf.int32, name='time')
        zero_output = tf.zeros(tf.stack([batch_size, cell.output_size]), tf.float32)

        proj_outputs = tf.TensorArray(dtype=tf.float32, size=time_steps, clear_after_read=False)
        decoder_outputs = tf.TensorArray(dtype=tf.float32, size=time_steps)
        inputs = tf.TensorArray(dtype=tf.int64, size=time_steps, clear_after_read=False).unstack(
            tf.cast(decoder_inputs, tf.int64))
        samples = tf.TensorArray(dtype=tf.int64, size=time_steps, clear_after_read=False)
        states = tf.TensorArray(dtype=tf.float32, size=time_steps)

        attn_lengths = [tf.shape(states)[1] for states in attention_states]
        weights = tf.TensorArray(dtype=tf.float32, size=time_steps)
        initial_weights = [tf.zeros(tf.stack([batch_size, length])) for length in attn_lengths]

        output = tf.zeros(tf.stack([batch_size, cell.output_size]), dtype=tf.float32)
        initial_input = embed(inputs.read(0))  # first symbol is BOS

        def _time_step(time, input_, state, output, proj_outputs, decoder_outputs, samples, states, weights,
                       prev_weights):
            context_vector, new_weights = attention_(state, prev_weights=prev_weights)
            weights = weights.write(time, new_weights)

            # FIXME: use `output` or `state` here?
            # maxout layer followed by the two output projections ('softmax0' and 'softmax1')
            output_ = linear_unsafe([state, input_, context_vector], decoder.cell_size, False, scope='maxout')
            output_ = tf.reduce_max(tf.reshape(output_, tf.stack([batch_size, decoder.cell_size // 2, 2])),
                                    axis=2)
            output_ = linear_unsafe(output_, decoder.embedding_size, False, scope='softmax0')
            decoder_outputs = decoder_outputs.write(time, output_)
            output_ = linear_unsafe(output_, output_size, True, scope='softmax1')
            proj_outputs = proj_outputs.write(time, output_)

            argmax = lambda: tf.argmax(output_, 1)
            softmax = lambda: tf.squeeze(tf.multinomial(tf.log(tf.nn.softmax(output_)), num_samples=1),
                                         axis=1)
            target = lambda: inputs.read(time + 1)

            sample = tf.case([
                (tf.logical_and(time < time_steps - 1, tf.random_uniform([]) >= feed_previous), target),
                (tf.logical_not(feed_argmax), softmax)],
                default=argmax)  # default case is useful for beam-search

            sample.set_shape([None])
            sample = tf.stop_gradient(sample)
            samples = samples.write(time, sample)

            input_ = embed(sample)
            x = tf.concat([input_, context_vector], 1)
            call_cell = lambda: unsafe_decorator(cell)(x, state)

            if sequence_length is not None:
                new_output, new_state = rnn._rnn_step(
                    time=time,
                    sequence_length=sequence_length,
                    min_sequence_length=min_sequence_length,
                    max_sequence_length=max_sequence_length,
                    zero_output=zero_output,
                    state=state,
                    call_cell=call_cell,
                    state_size=state_size,
                    skip_conditionals=True)
            else:
                new_output, new_state = call_cell()

            states = states.write(time, new_state)

            return (time + 1, input_, new_state, new_output, proj_outputs, decoder_outputs, samples, states,
                    weights, new_weights)

        _, _, new_state, new_output, proj_outputs, decoder_outputs, samples, states, weights, _ = tf.while_loop(
            cond=lambda time, *_: time < time_steps,
            body=_time_step,
            loop_vars=(time, initial_input, state, output, proj_outputs, decoder_outputs, samples, states,
                       weights, initial_weights),
            parallel_iterations=decoder.parallel_iterations,
            swap_memory=decoder.swap_memory)

        proj_outputs = proj_outputs.stack()
        decoder_outputs = decoder_outputs.stack()
        samples = samples.stack()
        weights = weights.stack()  # output time, encoders, batch_size, input time
        states = states.stack()

        beam_tensors = namedtuple('beam_tensors', 'state new_state output new_output')

        return (proj_outputs, weights, decoder_outputs, beam_tensors(state, new_state, output, new_output),
                samples, states)
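# Hypothetical usage note (not part of the original code): `feed_previous` above is a scalar
# probability, so both scheduled sampling and greedy decoding are expressed through it. The
# placeholder name below is an assumption for illustration.
#
#   feed_previous = tf.placeholder_with_default(0.0, shape=[], name='feed_previous')
#   # training with scheduled sampling: feed a value between 0 and 1
#   # greedy decoding: feed 1.0 (and keep feed_argmax=True)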
def multi_encoder(encoder_inputs, encoders, encoder_input_length, dropout=None, **kwargs):
    """
    Build multiple encoders according to the configuration in `encoders`, reading from `encoder_inputs`.
    The result is a list of the outputs produced by those encoders (for each time step), and their final state.

    :param encoder_inputs: list of tensors of shape (batch_size, input_length) (one tensor for each encoder)
    :param encoders: list of encoder configurations
    :param encoder_input_length: list of tensors of shape (batch_size) (one tensor for each encoder)
    :param dropout: scalar tensor or None, specifying the keep probability (1 - dropout)
    :return:
      encoder outputs: a list of tensors of shape (batch_size, input_length, encoder_cell_size),
      encoder state: concatenation of the final states of all encoders, tensor of shape
        (batch_size, sum_of_state_sizes)
    """
    assert len(encoder_inputs) == len(encoders)
    encoder_states = []
    encoder_outputs = []

    # create embeddings in the global scope (allows sharing between encoder and decoder)
    embedding_variables = []
    for encoder in encoders:
        # inputs are token ids, which need to be mapped to vectors (embeddings)
        if not encoder.binary:
            if encoder.get('embedding') is not None:
                initializer = encoder.embedding
                embedding_shape = None
            else:
                # initializer = tf.random_uniform_initializer(-math.sqrt(3), math.sqrt(3))
                initializer = None
                embedding_shape = [encoder.vocab_size, encoder.embedding_size]

            with tf.device('/cpu:0'):
                embedding = get_variable_unsafe('embedding_{}'.format(encoder.name), shape=embedding_shape,
                                                initializer=initializer)
            embedding_variables.append(embedding)
        else:
            # do nothing: inputs are already vectors
            embedding_variables.append(None)

    for i, encoder in enumerate(encoders):
        with tf.variable_scope('encoder_{}'.format(encoder.name)):
            encoder_inputs_ = encoder_inputs[i]
            encoder_input_length_ = encoder_input_length[i]

            # TODO: use state_is_tuple=True
            if encoder.use_lstm:
                cell = BasicLSTMCell(encoder.cell_size, state_is_tuple=False)
            else:
                cell = GRUCell(encoder.cell_size, initializer=orthogonal_initializer())

            if dropout is not None:
                cell = DropoutWrapper(cell, input_keep_prob=dropout)

            embedding = embedding_variables[i]

            if embedding is not None or encoder.input_layers:
                batch_size = tf.shape(encoder_inputs_)[0]   # TODO: fix this time-major stuff
                time_steps = tf.shape(encoder_inputs_)[1]

                if embedding is None:
                    size = encoder_inputs_.get_shape()[2].value
                    flat_inputs = tf.reshape(encoder_inputs_, [tf.multiply(batch_size, time_steps), size])
                else:
                    flat_inputs = tf.reshape(encoder_inputs_, [tf.multiply(batch_size, time_steps)])
                    flat_inputs = tf.nn.embedding_lookup(embedding, flat_inputs)

                if encoder.input_layers:
                    for j, size in enumerate(encoder.input_layers):
                        name = 'input_layer_{}'.format(j)
                        flat_inputs = tf.nn.tanh(linear_unsafe(flat_inputs, size, bias=True, scope=name))
                        if dropout is not None:
                            flat_inputs = tf.nn.dropout(flat_inputs, dropout)

                encoder_inputs_ = tf.reshape(flat_inputs,
                                             tf.stack([batch_size, time_steps,
                                                       flat_inputs.get_shape()[1].value]))

            # Contrary to Theano's RNN implementation, states after the sequence length are zero
            # (while Theano repeats the last state)
            sequence_length = encoder_input_length_   # TODO

            parameters = dict(
                inputs=encoder_inputs_, sequence_length=sequence_length,
                time_pooling=encoder.time_pooling, pooling_avg=encoder.pooling_avg,
                dtype=tf.float32, swap_memory=encoder.swap_memory,
                parallel_iterations=encoder.parallel_iterations,
                residual_connections=encoder.residual_connections,
                trainable_initial_state=True
            )

            if encoder.bidir:
                encoder_outputs_, _, _ = multi_bidirectional_rnn_unsafe(
                    cells=[(cell, cell)] * encoder.layers, **parameters)
                # Like Bahdanau et al., we use the first annotation h_1 of the backward encoder
                encoder_state_ = encoder_outputs_[:, 0, encoder.cell_size:]
                # TODO: if there are multiple layers, combine their last states with a maxout layer
            else:
                encoder_outputs_, encoder_state_ = multi_rnn_unsafe(
                    cells=[cell] * encoder.layers, **parameters)
                encoder_state_ = encoder_outputs_[:, -1, :]

            encoder_outputs.append(encoder_outputs_)
            encoder_states.append(encoder_state_)

    encoder_state = tf.concat(encoder_states, 1)
    return encoder_outputs, encoder_state
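# Hypothetical end-to-end sketch (not part of the original code): how `multi_encoder` and
# `attention_decoder` above fit together. The wrapper function itself, its argument names and the
# loss suggestion in the comments are illustrative assumptions, not this module's training code.
def encoder_decoder_sketch(encoder_inputs, targets, encoders, decoder,
                           encoder_input_length, decoder_input_length=None,
                           dropout=None, feed_previous=0.0):
    # run one encoder per input sequence and concatenate their final states
    attention_states, encoder_state = multi_encoder(
        encoder_inputs, encoders, encoder_input_length, dropout=dropout)
    # decode with attention over all encoder outputs
    proj_outputs, weights, decoder_outputs, beam_tensors, samples, states = attention_decoder(
        targets=targets, initial_state=encoder_state, attention_states=attention_states,
        encoders=encoders, decoder=decoder, encoder_input_length=encoder_input_length,
        decoder_input_length=decoder_input_length, dropout=dropout, feed_previous=feed_previous)
    # `proj_outputs` holds one vocabulary-sized projection per target time step; a typical loss
    # would be a (masked) sequence cross-entropy against `targets[1:, :]`.
    return proj_outputs, weights, samples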