Example #1
    def _testStepWithScheduledOutputTrainingHelper(  # pylint:disable=invalid-name
            self, sampling_probability, use_next_inputs_fn,
            use_auxiliary_inputs):
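        """Checks a single decoder step with ScheduledOutputTrainingHelper."""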
        sequence_length = [3, 4, 3, 1, 0]
        batch_size = 5
        max_time = 8
        input_depth = 7
        cell_depth = input_depth
        if use_auxiliary_inputs:
            auxiliary_input_depth = 4
            auxiliary_inputs = np.random.randn(
                batch_size, max_time, auxiliary_input_depth).astype(np.float32)
        else:
            auxiliary_inputs = None

        with self.session(use_gpu=True) as sess:
            inputs = np.random.randn(batch_size, max_time,
                                     input_depth).astype(np.float32)
            cell = tf.nn.rnn_cell.LSTMCell(cell_depth)
            sampling_probability = tf.constant(sampling_probability)

            if use_next_inputs_fn:

                def next_inputs_fn(outputs):
                    # Use deterministic function for test.
                    samples = tf.argmax(outputs, axis=1)
                    return tf.one_hot(samples, cell_depth, dtype=tf.float32)
            else:
                next_inputs_fn = None

            helper = seq2seq.ScheduledOutputTrainingHelper(
                inputs=inputs,
                sequence_length=sequence_length,
                sampling_probability=sampling_probability,
                time_major=False,
                next_inputs_fn=next_inputs_fn,
                auxiliary_inputs=auxiliary_inputs)

            my_decoder = seq2seq.BasicDecoder(cell=cell,
                                              helper=helper,
                                              initial_state=cell.zero_state(
                                                  dtype=tf.float32,
                                                  batch_size=batch_size))

            output_size = my_decoder.output_size
            output_dtype = my_decoder.output_dtype
            self.assertEqual(
                seq2seq.BasicDecoderOutput(cell_depth, tf.TensorShape([])),
                output_size)
            self.assertEqual(seq2seq.BasicDecoderOutput(tf.float32, tf.int32),
                             output_dtype)

            (first_finished, first_inputs,
             first_state) = my_decoder.initialize()
            (step_outputs, step_state, step_next_inputs,
             step_finished) = my_decoder.step(tf.constant(0), first_inputs,
                                              first_state)

            if use_next_inputs_fn:
                output_after_next_inputs_fn = next_inputs_fn(
                    step_outputs.rnn_output)

            batch_size_t = my_decoder.batch_size

            self.assertIsInstance(first_state, tf.nn.rnn_cell.LSTMStateTuple)
            self.assertIsInstance(step_state, tf.nn.rnn_cell.LSTMStateTuple)
            self.assertIsInstance(step_outputs, seq2seq.BasicDecoderOutput)
            self.assertEqual((batch_size, cell_depth),
                             step_outputs[0].get_shape())
            self.assertEqual((batch_size, ), step_outputs[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             first_state[1].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[0].get_shape())
            self.assertEqual((batch_size, cell_depth),
                             step_state[1].get_shape())

            sess.run(tf.global_variables_initializer())

            fetches = {
                "batch_size": batch_size_t,
                "first_finished": first_finished,
                "first_inputs": first_inputs,
                "first_state": first_state,
                "step_outputs": step_outputs,
                "step_state": step_state,
                "step_next_inputs": step_next_inputs,
                "step_finished": step_finished
            }
            if use_next_inputs_fn:
                fetches["output_after_next_inputs_fn"] = (
                    output_after_next_inputs_fn)

            sess_results = sess.run(fetches)

            self.assertAllEqual([False, False, False, False, True],
                                sess_results["first_finished"])
            self.assertAllEqual([False, False, False, True, True],
                                sess_results["step_finished"])

            sample_ids = sess_results["step_outputs"].sample_id
            self.assertEqual(output_dtype.sample_id, sample_ids.dtype)
            batch_where_not_sampling = np.where(np.logical_not(sample_ids))
            batch_where_sampling = np.where(sample_ids)

            auxiliary_inputs_to_concat = (
                auxiliary_inputs[:, 1] if use_auxiliary_inputs
                else np.empty((batch_size, 0), dtype=np.float32))

            expected_next_sampling_inputs = np.concatenate(
                (sess_results["output_after_next_inputs_fn"]
                 [batch_where_sampling] if use_next_inputs_fn else
                 sess_results["step_outputs"].rnn_output[batch_where_sampling],
                 auxiliary_inputs_to_concat[batch_where_sampling]),
                axis=-1)
            self.assertAllClose(
                sess_results["step_next_inputs"][batch_where_sampling],
                expected_next_sampling_inputs)

            self.assertAllClose(
                sess_results["step_next_inputs"][batch_where_not_sampling],
                np.concatenate(
                    (np.squeeze(inputs[batch_where_not_sampling, 1], axis=0),
                     auxiliary_inputs_to_concat[batch_where_not_sampling]),
                    axis=-1))
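
The test above drives the decoder by hand through initialize() and a single step(). In normal use the same helper/decoder pair is run to completion with seq2seq.dynamic_decode. The sketch below is a minimal illustration of that pattern under the same TF 1.x contrib API; the concrete sizes and the sampling_probability value are arbitrary and not taken from the test.

import numpy as np
import tensorflow as tf
from tensorflow.contrib import seq2seq

batch_size, max_time, depth = 5, 8, 7
inputs = np.random.randn(batch_size, max_time, depth).astype(np.float32)
sequence_length = [3, 4, 3, 1, 2]

# The cell's output depth matches the input depth so that rnn_output can be
# fed back as the next input when the helper decides to sample.
cell = tf.nn.rnn_cell.LSTMCell(depth)
helper = seq2seq.ScheduledOutputTrainingHelper(
    inputs=inputs,
    sequence_length=sequence_length,
    sampling_probability=0.5)
decoder = seq2seq.BasicDecoder(
    cell=cell,
    helper=helper,
    initial_state=cell.zero_state(batch_size=batch_size, dtype=tf.float32))

# dynamic_decode steps the decoder until every sequence is finished and
# returns the stacked outputs, the final cell state, and the decoded lengths.
final_outputs, final_state, final_lengths = seq2seq.dynamic_decode(decoder)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    outputs = sess.run(final_outputs)
    print(outputs.rnn_output.shape)  # (5, 4, 7): batch-major, longest length 4
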
Example #2
  def reconstruction_loss(self, x_input, x_target, x_length, z=None,
                          c_input=None):
    """Reconstruction loss calculation.

    Args:
      x_input: Batch of decoder input sequences for teacher forcing, sized
        `[batch_size, max(x_length), output_depth]`.
      x_target: Batch of expected output sequences to compute loss against,
        sized `[batch_size, max(x_length), output_depth]`.
      x_length: Length of input/output sequences, sized `[batch_size]`.
      z: (Optional) Latent vectors. Required if model is conditional. Sized
        `[batch_size, z_size]`.
      c_input: (Optional) Batch of control sequences, sized
        `[batch_size, max(x_length), control_depth]`. Required if conditioning
        on control sequences.

    Returns:
      r_loss: The reconstruction loss for each sequence in the batch.
      metric_map: Map from metric name to tf.metrics return values for logging.
      decode_results: The LstmDecodeResults.
    """
    batch_size = int(x_input.shape[0])

    has_z = z is not None
    z = tf.zeros([batch_size, 0]) if z is None else z
    repeated_z = tf.tile(
        tf.expand_dims(z, axis=1), [1, tf.shape(x_input)[1], 1])

    has_control = c_input is not None
    if c_input is None:
      c_input = tf.zeros([batch_size, tf.shape(x_input)[1], 0])

    sampling_probability_static = tf.get_static_value(
        self._sampling_probability)
    if sampling_probability_static == 0.0:
      # Use teacher forcing.
      x_input = tf.concat([x_input, repeated_z, c_input], axis=2)
      helper = contrib_seq2seq.TrainingHelper(x_input, x_length)
    else:
      # Use scheduled sampling.
      if has_z or has_control:
        auxiliary_inputs = tf.zeros([batch_size, tf.shape(x_input)[1], 0])
        if has_z:
          auxiliary_inputs = tf.concat([auxiliary_inputs, repeated_z], axis=2)
        if has_control:
          auxiliary_inputs = tf.concat([auxiliary_inputs, c_input], axis=2)
      else:
        auxiliary_inputs = None
      helper = contrib_seq2seq.ScheduledOutputTrainingHelper(
          inputs=x_input,
          sequence_length=x_length,
          auxiliary_inputs=auxiliary_inputs,
          sampling_probability=self._sampling_probability,
          next_inputs_fn=self._sample)

    decode_results = self._decode(
        z, helper=helper, input_shape=helper.inputs.shape[2:])
    flat_x_target = flatten_maybe_padded_sequences(x_target, x_length)
    flat_rnn_output = flatten_maybe_padded_sequences(
        decode_results.rnn_output, x_length)
    r_loss, metric_map = self._flat_reconstruction_loss(
        flat_x_target, flat_rnn_output)

    # Sum loss over sequences.
    cum_x_len = tf.concat([(0,), tf.cumsum(x_length)], axis=0)
    r_losses = []
    for i in range(batch_size):
      b, e = cum_x_len[i], cum_x_len[i + 1]
      r_losses.append(tf.reduce_sum(r_loss[b:e]))
    r_loss = tf.stack(r_losses)

    return r_loss, metric_map, decode_results
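
The final loop above converts the flat per-step losses back into one loss per sequence by slicing with cumulative lengths: the flattened loss vector holds the unpadded steps of all sequences back to back. A small numpy-only sketch of that indexing, with made-up lengths and loss values:

import numpy as np

x_length = np.array([3, 2, 4])    # steps per sequence
flat_loss = np.arange(1.0, 10.0)  # 3 + 2 + 4 = 9 per-step losses

cum_x_len = np.concatenate([[0], np.cumsum(x_length)])  # [0, 3, 5, 9]
per_sequence = np.array([flat_loss[b:e].sum()
                         for b, e in zip(cum_x_len[:-1], cum_x_len[1:])])
print(per_sequence)  # [ 6.  9. 30.] -> 1+2+3, 4+5, 6+7+8+9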