Example 1
  def testMaybeSplitSequenceLengths(self):
    with self.test_session():
      # Test unsplit.
      sequence_length = tf.constant([8, 0, 8], tf.int32)
      num_splits = 4
      total_length = 8
      expected_split_length = np.array([[2, 2, 2, 2],
                                        [0, 0, 0, 0],
                                        [2, 2, 2, 2]])
      split_length = lstm_utils.maybe_split_sequence_lengths(
          sequence_length, num_splits, total_length).eval()
      self.assertAllEqual(expected_split_length, split_length)

      # Test already split.
      presplit_length = np.array([[0, 2, 1, 2],
                                  [0, 0, 0, 0],
                                  [1, 1, 1, 1]], np.int32)
      split_length = lstm_utils.maybe_split_sequence_lengths(
          tf.constant(presplit_length), num_splits, total_length).eval()
      self.assertAllEqual(presplit_length, split_length)

      # Test invalid total length.
      with self.assertRaises(tf.errors.InvalidArgumentError):
        sequence_length = tf.constant([8, 0, 7])
        lstm_utils.maybe_split_sequence_lengths(
            sequence_length, num_splits, total_length).eval()

      # Test invalid segment length.
      with self.assertRaises(tf.errors.InvalidArgumentError):
        presplit_length = np.array([[0, 2, 3, 1],
                                    [0, 0, 0, 0],
                                    [1, 1, 1, 1]], np.int32)
        lstm_utils.maybe_split_sequence_lengths(
            tf.constant(presplit_length), num_splits, total_length).eval()
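
The test above pins down the behaviour of `lstm_utils.maybe_split_sequence_lengths`: a rank-1 `[batch_size]` tensor whose entries are each 0 or `total_length` is divided evenly into `num_splits` segments, a rank-2 `[batch_size, num_splits]` tensor is returned unchanged, and other inputs fail with `InvalidArgumentError`. A minimal NumPy sketch of that contract follows; `split_lengths_reference` is a hypothetical name (not the Magenta implementation), and it raises `ValueError` where the TF op raises `InvalidArgumentError`.

import numpy as np

def split_lengths_reference(sequence_length, num_splits, total_length):
  """Pure-NumPy sketch of the behaviour the test above expects."""
  sequence_length = np.asarray(sequence_length)
  if sequence_length.ndim == 1:
    # Unsplit case: every length must be 0 or total_length.
    if not np.all(np.isin(sequence_length, [0, total_length])):
      raise ValueError('lengths must be 0 or total_length')
    segment = total_length // num_splits
    return np.where(
        sequence_length[:, None] > 0,
        np.full((len(sequence_length), num_splits), segment),
        0).astype(np.int32)
  # Already split: no segment may exceed total_length / num_splits.
  if np.any(sequence_length > total_length // num_splits):
    raise ValueError('segment length exceeds total_length / num_splits')
  return sequence_length

# Matches the first expected value in the test:
# split_lengths_reference([8, 0, 8], 4, 8)
#   -> [[2, 2, 2, 2], [0, 0, 0, 0], [2, 2, 2, 2]]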
Example 2
  def encode(self, sequence, sequence_length):
    """Hierarchically encodes the input sequences, returning a single embedding.

    Each sequence should be padded per-segment. For example, a sequence with
    three segments [1, 2, 3], [4, 5], [6, 7, 8, 9] and a `max_seq_len` of 12
    should be input as `sequence = [1, 2, 3, 0, 4, 5, 0, 0, 6, 7, 8, 9]` with
    `sequence_length = [3, 2, 4]`.

    Args:
      sequence: A batch of (padded) sequences, sized
        `[batch_size, max_seq_len, input_depth]`.
      sequence_length: A batch of sequence lengths. May be sized
        `[batch_size, level_lengths[0]]` or `[batch_size]`. If the latter,
        each length must either equal `max_seq_len` or 0. In this case, the
        segment lengths are assumed to be constant and the total length will be
        evenly divided amongst the segments.

    Returns:
      embedding: A batch of embeddings, sized `[batch_size, N]`.
    """
    batch_size = sequence.shape[0].value
    sequence_length = lstm_utils.maybe_split_sequence_lengths(
        sequence_length, np.prod(self._level_lengths[1:]),
        self._total_length)

    for level, (num_splits, h_encoder) in enumerate(
        self._hierarchical_encoders):
      split_seqs = tf.split(sequence, num_splits, axis=1)
      # In the first level, we use the input `sequence_length`. After that,
      # we use the full embedding sequences.
      sequence_length = (
          sequence_length if level == 0 else
          tf.fill([batch_size, num_splits], split_seqs[0].shape[1]))
      split_lengths = tf.unstack(sequence_length, axis=1)
      embeddings = [
          h_encoder.encode(s, l) for s, l in zip(split_seqs, split_lengths)]
      sequence = tf.stack(embeddings, axis=1)

    with tf.control_dependencies([tf.assert_equal(tf.shape(sequence)[1], 1)]):
      return sequence[:, 0]
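
The per-segment padding convention described in the docstring can be produced with a small helper. Below is a sketch that assumes 1-D token sequences for readability (the real inputs carry an extra `input_depth` axis); `pad_segments` is a hypothetical name, not part of `lstm_utils`.

import numpy as np

def pad_segments(segments, max_seq_len):
  """Concatenates segments with per-segment zero padding (hypothetical helper)."""
  num_segments = len(segments)
  segment_len = max_seq_len // num_segments
  sequence, sequence_length = [], []
  for seg in segments:
    assert len(seg) <= segment_len
    sequence.extend(list(seg) + [0] * (segment_len - len(seg)))
    sequence_length.append(len(seg))
  return np.array(sequence), np.array(sequence_length)

# Reproduces the docstring example:
# pad_segments([[1, 2, 3], [4, 5], [6, 7, 8, 9]], 12)
#   -> ([1, 2, 3, 0, 4, 5, 0, 0, 6, 7, 8, 9], [3, 2, 4])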
Example 3
  def reconstruction_loss(self, x_input, x_target, x_length, z=None,
                          c_input=None):
    """Reconstruction loss calculation.

    Args:
      x_input: Batch of decoder input sequences of concatenated segments for
        teacher forcing, sized `[batch_size, max_seq_len, output_depth]`.
      x_target: Batch of expected output sequences to compute loss against,
        sized `[batch_size, max_seq_len, output_depth]`.
      x_length: Length of input/output sequences, sized
        `[batch_size, level_lengths[0]]` or `[batch_size]`. If the latter,
        each length must either equal `max_seq_len` or 0. In this case, the
        segment lengths are assumed to be constant and the total length will be
        evenly divided amongst the segments.
      z: (Optional) Latent vectors. Required if model is conditional. Sized
        `[n, z_size]`.
      c_input: Batch of control sequences. Incompatible with this decoder.

    Returns:
      r_loss: The reconstruction loss for each sequence in the batch.
      metric_map: Map from metric name to tf.metrics return values for logging.
      decode_results: The LstmDecodeResults.

    Raises:
      ValueError: If `c_input` is provided.
    """
    if c_input is not None:
      raise ValueError(
          'Control sequence unsupported in HierarchicalLstmDecoder.')

    batch_size = x_input.shape[0].value

    x_length = lstm_utils.maybe_split_sequence_lengths(
        x_length, np.prod(self._level_lengths[:-1]), self._total_length)

    def _reshape_to_hierarchy(t):
      """Reshapes `t` so that its initial dimensions match the hierarchy."""
      # Exclude the final, core decoder length.
      level_lengths = self._level_lengths[:-1]
      t_shape = t.shape.as_list()
      t_rank = len(t_shape)
      hier_shape = [batch_size] + level_lengths
      if t_rank == 3:
        hier_shape += [-1] + t_shape[2:]
      elif t_rank != 2:
        # We only expect rank-2 for lengths and rank-3 for sequences.
        raise ValueError('Unexpected shape for tensor: %s' % t)
      hier_t = tf.reshape(t, hier_shape)
      # Move the batch dimension to after the hierarchical dimensions.
      num_levels = len(level_lengths)
      perm = list(range(len(hier_shape)))
      perm.insert(num_levels, perm.pop(0))
      return tf.transpose(hier_t, perm)

    hier_input = _reshape_to_hierarchy(x_input)
    hier_target = _reshape_to_hierarchy(x_target)
    hier_length = _reshape_to_hierarchy(x_length)

    loss_outputs = []

    def base_train_fn(embedding, hier_index):
      """Base function for training hierarchical decoder."""
      split_size = self._level_lengths[-1]
      split_input = hier_input[hier_index]
      split_target = hier_target[hier_index]
      split_length = hier_length[hier_index]

      res = self._core_decoder.reconstruction_loss(
          split_input, split_target, split_length, embedding)
      loss_outputs.append(res)
      decode_results = res[-1]

      if self._hierarchical_encoder:
        # Get the approximate "sample" from the model.
        # Start with the inputs the RNN saw (excluding the start token).
        samples = decode_results.rnn_input[:, 1:]
        # Pad to be the max length.
        samples = tf.pad(
            samples,
            [(0, 0), (0, split_size - tf.shape(samples)[1]), (0, 0)])
        samples.set_shape([batch_size, split_size, self._output_depth])
        # Set the final value based on the target, since the scheduled sampling
        # helper does not sample the final value.
        samples = lstm_utils.set_final(
            samples,
            split_length,
            lstm_utils.get_final(split_target, split_length, time_major=False),
            time_major=False)
        # Return the re-encoded sample.
        return self._hierarchical_encoder.level(0).encode(
            sequence=samples,
            sequence_length=split_length)
      elif self._disable_autoregression:
        return None
      else:
        return tf.concat(nest.flatten(decode_results.final_state), axis=-1)

    z = tf.zeros([batch_size, 0]) if z is None else z
    self._hierarchical_decode(z, base_train_fn)

    # Accumulate the split sequence losses.
    r_losses, metric_maps, decode_results = zip(*loss_outputs)

    # Merge the metric maps by passing through renamed values and taking the
    # mean across the splits.
    merged_metric_map = {}
    for metric_name in metric_maps[0]:
      metric_values = []
      for i, m in enumerate(metric_maps):
        merged_metric_map['segment/%03d/%s' % (i, metric_name)] = m[metric_name]
        metric_values.append(m[metric_name][0])
      merged_metric_map[metric_name] = (
          tf.reduce_mean(metric_values), tf.no_op())

    return (tf.reduce_sum(r_losses, axis=0),
            merged_metric_map,
            self._merge_decode_results(decode_results))
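
The trickiest step above is `_reshape_to_hierarchy`, which turns a `[batch_size, total_length, depth]` tensor into `[level_0, ..., batch_size, segment_length, depth]` so that `hier_input[hier_index]` is the batch of inputs for a single segment. A NumPy re-enactment of the same reshape and transpose, assuming `level_lengths = [2, 4]` purely for illustration (one hierarchical level of 2 segments over a core decoder length of 4):

import numpy as np

batch_size, depth = 3, 5
level_lengths = [2, 4]
x = np.zeros([batch_size, np.prod(level_lengths), depth])

# Exclude the final, core decoder length, as in _reshape_to_hierarchy.
hier_shape = [batch_size] + level_lengths[:-1] + [-1, depth]  # [3, 2, -1, 5]
hier_x = x.reshape(hier_shape)

# Move the batch dimension to after the hierarchical dimensions.
perm = list(range(len(hier_shape)))
perm.insert(len(level_lengths[:-1]), perm.pop(0))  # [1, 0, 2, 3]
hier_x = hier_x.transpose(perm)

# hier_x[i] is now the batch for segment i, which is what base_train_fn
# receives as hier_input[hier_index].
assert hier_x.shape == (2, 3, 4, 5)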