Example #1
    def reconstruction_loss(self, x_input, x_target, x_length, z=None):
        """Reconstruction loss calculation.

        Args:
          x_input: Batch of decoder input sequences for teacher forcing, sized
            `[batch_size, max(x_length), output_depth]`.
          x_target: Batch of expected output sequences to compute loss against,
            sized `[batch_size, max(x_length), output_depth]`.
          x_length: Length of input/output sequences, sized `[batch_size]`.
          z: (Optional) Latent vectors. Required if model is conditional. Sized
            `[n, z_size]`.

        Returns:
          r_loss: The reconstruction loss for each sequence in the batch.
          metric_map: Map from metric name to tf.metrics return values for logging.
          truths: Ground truth labels.
          predictions: Predicted labels.
          final_state: The final states of the decoder, or None if using Cudnn.
        """
        batch_size = x_input.shape[0].value

        has_z = z is not None
        z = tf.zeros([batch_size, 0]) if z is None else z
        repeated_z = tf.tile(tf.expand_dims(z, axis=1),
                             [1, tf.shape(x_input)[1], 1])

        sampling_probability_static = tensor_util.constant_value(
            self._sampling_probability)
        if sampling_probability_static == 0.0:
            # Use teacher forcing.
            x_input = tf.concat([x_input, repeated_z], axis=2)
            helper = seq2seq.TrainingHelper(x_input, x_length)
        else:
            # Use scheduled sampling.
            helper = seq2seq.ScheduledOutputTrainingHelper(
                inputs=x_input,
                sequence_length=x_length,
                auxiliary_inputs=repeated_z if has_z else None,
                sampling_probability=self._sampling_probability,
                next_inputs_fn=self._sample)

        decoder_outputs, final_state = self._decode(z,
                                                    helper=helper,
                                                    x_input=x_input)
        flat_x_target = flatten_maybe_padded_sequences(x_target, x_length)
        flat_rnn_output = flatten_maybe_padded_sequences(
            decoder_outputs.rnn_output, x_length)
        r_loss, metric_map, truths, predictions = self._flat_reconstruction_loss(
            flat_x_target, flat_rnn_output)

        # Sum loss over sequences.
        cum_x_len = tf.concat([(0, ), tf.cumsum(x_length)], axis=0)
        r_losses = []
        for i in range(batch_size):
            b, e = cum_x_len[i], cum_x_len[i + 1]
            r_losses.append(tf.reduce_sum(r_loss[b:e]))
        r_loss = tf.stack(r_losses)

        return r_loss, metric_map, truths, predictions, final_state
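The switch above between plain teacher forcing (TrainingHelper) and scheduled sampling (ScheduledOutputTrainingHelper with auxiliary latent inputs) can be reproduced in a small standalone graph. Below is a minimal sketch, assuming the TF 1.x tf.contrib.seq2seq API the snippet appears to target; the tensor sizes, the LSTM cell, and the identity next_inputs_fn are illustrative assumptions, not part of the original model.

import tensorflow as tf
from tensorflow.contrib import seq2seq

batch_size, max_len, depth, z_size, units = 8, 16, 32, 4, 64  # illustrative sizes

x_input = tf.placeholder(tf.float32, [batch_size, max_len, depth])
x_length = tf.placeholder(tf.int32, [batch_size])
z = tf.placeholder(tf.float32, [batch_size, z_size])

# tile z across time so it can be appended to every decoder input step
repeated_z = tf.tile(tf.expand_dims(z, axis=1), [1, max_len, 1])

cell = tf.nn.rnn_cell.LSTMCell(units)

helper = seq2seq.ScheduledOutputTrainingHelper(
    inputs=x_input,
    sequence_length=x_length,
    auxiliary_inputs=repeated_z,             # concatenated to every next input
    sampling_probability=0.5,                # chance of feeding back the model's own output
    next_inputs_fn=lambda outputs: outputs)  # identity; a real model might sample here

decoder = seq2seq.BasicDecoder(
    cell=cell,
    helper=helper,
    initial_state=cell.zero_state(batch_size, tf.float32),
    output_layer=tf.layers.Dense(depth))     # projects back to the input depth

decoder_outputs, final_state, _ = seq2seq.dynamic_decode(decoder)

With sampling_probability set to 0 this behaves like a plain TrainingHelper, which is why the snippet above special-cases that value and skips the sampling machinery entirely.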
Example #2
    def _build_decoder(self, encoder_states, target_sequence, keep_prob,
                       sampling_prob, attention_mechanism):
        """Define decoder architecture.
        """
        # stack the recurrent layers into a single multi-layer decoder cell
        decoder_cell = self._multi_cell(num_units=self.num_units,
                                        num_layers=self.num_layers,
                                        keep_prob=keep_prob)

        # connect attention to decoder
        attention_layer_size = self.num_units
        decoder = seq2seq.AttentionWrapper(
            cell=decoder_cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=attention_layer_size)

        # prepend a zero start symbol and drop the last frame so the decoder
        # inputs are the targets shifted right by one step
        decoder_raw_seq = target_sequence[:, :-1]
        prefix = tf.fill([tf.shape(target_sequence)[0], 1, self.target_depth],
                         0.0)
        decoder_input_seq = tf.concat([prefix, decoder_raw_seq], axis=1)

        # the model uses fixed-length target sequences, so tile the configured
        # length across the batch dimension
        decoder_sequence_length = tf.tile([self.target_length],
                                          [tf.shape(target_sequence)[0]])

        # scheduled sampling: with probability sampling_prob, feed the decoder
        # output to the next time step instead of the ground-truth target values
        helper = seq2seq.ScheduledOutputTrainingHelper(
            inputs=decoder_input_seq,
            sequence_length=decoder_sequence_length,
            sampling_probability=sampling_prob)

        # output layer
        projection_layer = Dense(units=self.target_depth, use_bias=True)

        # initialize the attention wrapper state from the encoder's final state
        initial_state = decoder.zero_state(
            tf.shape(target_sequence)[0], tf.float32)
        initial_state = initial_state.clone(cell_state=encoder_states)

        # wrapper for decoder
        decoder = seq2seq.BasicDecoder(cell=decoder,
                                       helper=helper,
                                       initial_state=initial_state,
                                       output_layer=projection_layer)

        # build the unrolled graph of the recurrent neural network
        outputs, decoder_state, _sequence_lengths = seq2seq.dynamic_decode(
            decoder=decoder, maximum_iterations=self.target_length)

        return (outputs, decoder_state)
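At inference time there is no target_sequence to sample from, so a common companion to the method above is a decoder driven purely by its own outputs via seq2seq.InferenceHelper. The method below is a hypothetical sketch under that assumption; it reuses the class attributes from the snippet (self._multi_cell, self.num_units, self.target_depth, self.target_length) and is not part of the original code, and the all-False end_fn (stopping only via maximum_iterations) is an illustrative choice.

    def _build_inference_decoder(self, encoder_states, attention_mechanism,
                                 batch_size):
        """Hypothetical inference counterpart: decode from the model's own outputs."""
        decoder_cell = self._multi_cell(num_units=self.num_units,
                                        num_layers=self.num_layers,
                                        keep_prob=1.0)
        decoder = seq2seq.AttentionWrapper(
            cell=decoder_cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=self.num_units)

        # feed each projected output frame back in as the next input
        helper = seq2seq.InferenceHelper(
            sample_fn=lambda outputs: outputs,
            sample_shape=[self.target_depth],
            sample_dtype=tf.float32,
            start_inputs=tf.zeros([batch_size, self.target_depth]),
            end_fn=lambda sample_ids: tf.tile([False], [batch_size]))

        initial_state = decoder.zero_state(batch_size, tf.float32)
        initial_state = initial_state.clone(cell_state=encoder_states)

        decoder = seq2seq.BasicDecoder(
            cell=decoder,
            helper=helper,
            initial_state=initial_state,
            output_layer=Dense(units=self.target_depth, use_bias=True))
        outputs, decoder_state, _ = seq2seq.dynamic_decode(
            decoder, maximum_iterations=self.target_length)
        return (outputs, decoder_state)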
Example #3
    def build_predict_op(self):
        with tf.variable_scope('predict'):
            decoder_cell = self.decoder_cell
            targets = self.targets
            sequence_lengths = self.training_seq_lens
            # sampling_probability=1.0: always feed the decoder's own output
            # (passed through projection_layer) back as the next input, making
            # prediction fully autoregressive
            predict_helper = seq2seq.ScheduledOutputTrainingHelper(
                targets,
                sequence_lengths,
                sampling_probability=1.0,
                next_input_layer=self.projection_layer)
            decoder = seq2seq.BasicDecoder(decoder_cell, predict_helper,
                                           self.get_zero_state())
            output, _, _ = seq2seq.dynamic_decode(decoder,
                                                  output_time_major=True)

            self.predictions = self.projection_layer(output.rnn_output)
Example #4
    def build_train_op(self):
        with tf.variable_scope('training'):
            decoder_cell = self.decoder_cell
            targets = self.targets
            sequence_lengths = self.training_seq_lens
            # scheduled sampling: with probability sampling_rate, feed the
            # decoder's own output back as the next input during training
            training_helper = seq2seq.ScheduledOutputTrainingHelper(
                targets,
                sequence_lengths,
                sampling_probability=self.sampling_rate,
                next_input_layer=self.projection_layer)
            decoder = seq2seq.BasicDecoder(decoder_cell,
                                           helper=training_helper,
                                           initial_state=self.get_zero_state())
            output, _, _ = seq2seq.dynamic_decode(decoder,
                                                  output_time_major=True)
            predictions = self.projection_layer.apply(output.rnn_output)
        # decoder outputs are time major, so transpose the targets to match
        time_major = tf.transpose(targets, perm=[1, 0, 2])
        x_entropy = tf.nn.softmax_cross_entropy_with_logits(
            labels=time_major, logits=predictions, name='CrossEntropy')
        loss = tf.reduce_mean(x_entropy)
        self.train_op = tf.contrib.layers.optimize_loss(
            loss, tf.contrib.framework.get_global_step(), 0.001, 'Adam')
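The snippet reads self.sampling_rate but does not show how that value is produced. A common choice, sketched below under the assumption of the same TF 1.x API, is to anneal the sampling probability from 0 (pure teacher forcing) toward 1 (fully autoregressive) as training progresses; the decay constants are illustrative, and the name sampling_rate is simply what this model would assign to self.sampling_rate in its setup code.

import tensorflow as tf

# hypothetical schedule (not part of the original code): this value could be
# assigned to self.sampling_rate before build_train_op() is called
global_step = tf.train.get_or_create_global_step()
sampling_rate = 1.0 - tf.train.exponential_decay(
    learning_rate=1.0,    # reused as a generic decaying scalar, not a learning rate
    global_step=global_step,
    decay_steps=10000,    # illustrative
    decay_rate=0.96)

ScheduledOutputTrainingHelper accepts either a Python float or a scalar tensor for sampling_probability, so an annealed tensor like this can be passed in directly.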