def testMaybeSplitSequenceLengths(self):
  """Checks splitting, pass-through, and validation of sequence lengths."""
  with self.test_session():
    num_splits = 4
    total_length = 8

    # An unsplit [batch] length vector is divided evenly across segments.
    unsplit_lengths = tf.constant([8, 0, 8], tf.int32)
    expected = np.array([[2, 2, 2, 2],
                         [0, 0, 0, 0],
                         [2, 2, 2, 2]])
    actual = lstm_utils.maybe_split_sequence_lengths(
        unsplit_lengths, num_splits, total_length).eval()
    self.assertAllEqual(expected, actual)

    # An already-split [batch, num_splits] matrix passes through unchanged.
    presplit = np.array(
        [[0, 2, 1, 2], [0, 0, 0, 0], [1, 1, 1, 1]], np.int32)
    actual = lstm_utils.maybe_split_sequence_lengths(
        tf.constant(presplit), num_splits, total_length).eval()
    self.assertAllEqual(presplit, actual)

    # Unsplit lengths that are neither 0 nor `total_length` are rejected.
    with self.assertRaises(tf.errors.InvalidArgumentError):
      bad_lengths = tf.constant([8, 0, 7])
      lstm_utils.maybe_split_sequence_lengths(
          bad_lengths, num_splits, total_length).eval()

    # Pre-split lengths with a segment longer than the segment size are
    # rejected.
    with self.assertRaises(tf.errors.InvalidArgumentError):
      bad_presplit = np.array(
          [[0, 2, 3, 1], [0, 0, 0, 0], [1, 1, 1, 1]], np.int32)
      lstm_utils.maybe_split_sequence_lengths(
          tf.constant(bad_presplit), num_splits, total_length).eval()
def testMaybeSplitSequenceLengths(self):
  # NOTE(review): this appears to be a near-duplicate of another
  # `testMaybeSplitSequenceLengths` definition in this file. If both live in
  # the same class, Python keeps only the later definition and silently
  # shadows the earlier one — confirm and remove one copy.
  with self.test_session():
    # Test unsplit: a [batch] vector of total lengths should be divided
    # evenly into `num_splits` segments.
    sequence_length = tf.constant([8, 0, 8], tf.int32)
    num_splits = 4
    total_length = 8
    expected_split_length = np.array([[2, 2, 2, 2],
                                      [0, 0, 0, 0],
                                      [2, 2, 2, 2]])
    split_length = lstm_utils.maybe_split_sequence_lengths(
        sequence_length, num_splits, total_length).eval()
    self.assertAllEqual(expected_split_length, split_length)
    # Test already split: a [batch, num_splits] matrix passes through
    # unchanged.
    presplit_length = np.array([[0, 2, 1, 2],
                                [0, 0, 0, 0],
                                [1, 1, 1, 1]], np.int32)
    split_length = lstm_utils.maybe_split_sequence_lengths(
        tf.constant(presplit_length), num_splits, total_length).eval()
    self.assertAllEqual(presplit_length, split_length)
    # Test invalid total length: an unsplit length must be 0 or
    # `total_length`.
    with self.assertRaises(tf.errors.InvalidArgumentError):
      sequence_length = tf.constant([8, 0, 7])
      lstm_utils.maybe_split_sequence_lengths(
          sequence_length, num_splits, total_length).eval()
    # Test invalid segment length: a pre-split segment may not exceed
    # total_length / num_splits.
    with self.assertRaises(tf.errors.InvalidArgumentError):
      presplit_length = np.array([[0, 2, 3, 1],
                                  [0, 0, 0, 0],
                                  [1, 1, 1, 1]], np.int32)
      lstm_utils.maybe_split_sequence_lengths(
          tf.constant(presplit_length), num_splits, total_length).eval()
def encode(self, sequence, sequence_length):
  """Hierarchically encodes the input sequences into a single embedding.

  Each sequence should be padded per-segment. For example, a sequence with
  three segments [1, 2, 3], [4, 5], [6, 7, 8 ,9] and a `max_seq_len` of 12
  should be input as `sequence = [1, 2, 3, 0, 4, 5, 0, 0, 6, 7, 8, 9]` with
  `sequence_length = [3, 2, 4]`.

  Args:
    sequence: A batch of (padded) sequences, sized
      `[batch_size, max_seq_len, input_depth]`.
    sequence_length: A batch of sequence lengths. May be sized
      `[batch_size, level_lengths[0]]` or `[batch_size]`. If the latter,
      each length must either equal `max_seq_len` or 0. In this case, the
      segment lengths are assumed to be constant and the total length will
      be evenly divided amongst the segments.

  Returns:
    embedding: A batch of embeddings, sized `[batch_size, N]`.
  """
  batch_size = sequence.shape[0].value
  # Expand a [batch_size] length vector into per-segment lengths when needed.
  sequence_length = lstm_utils.maybe_split_sequence_lengths(
      sequence_length, np.prod(self._level_lengths[1:]), self._total_length)

  for level, (num_splits, h_encoder) in enumerate(
      self._hierarchical_encoders):
    segments = tf.split(sequence, num_splits, axis=1)
    if level == 0:
      # The bottom level consumes the caller-provided lengths.
      lengths = sequence_length
    else:
      # Above the bottom level, each segment is a full embedding sequence.
      lengths = tf.fill([batch_size, num_splits], segments[0].shape[1])
    per_segment_lengths = tf.unstack(lengths, axis=1)
    embeddings = []
    for segment, segment_length in zip(segments, per_segment_lengths):
      embeddings.append(h_encoder.encode(segment, segment_length))
    sequence = tf.stack(embeddings, axis=1)

  # By the top level the sequence should have collapsed to a single step.
  with tf.control_dependencies(
      [tf.assert_equal(tf.shape(sequence)[1], 1)]):
    return sequence[:, 0]
def reconstruction_loss(self, x_input, x_target, x_length, z=None,
                        c_input=None):
  """Reconstruction loss calculation.

  Args:
    x_input: Batch of decoder input sequences of concatenated segments for
      teacher forcing, sized `[batch_size, max_seq_len, output_depth]`.
    x_target: Batch of expected output sequences to compute loss against,
      sized `[batch_size, max_seq_len, output_depth]`.
    x_length: Length of input/output sequences, sized
      `[batch_size, level_lengths[0]]` or `[batch_size]`. If the latter,
      each length must either equal `max_seq_len` or 0. In this case, the
      segment lengths are assumed to be constant and the total length will
      be evenly divided amongst the segments.
    z: (Optional) Latent vectors. Required if model is conditional. Sized
      `[n, z_size]`.
    c_input: Batch of control sequences. Incompatible with this decoder.

  Returns:
    r_loss: The reconstruction loss for each sequence in the batch.
    metric_map: Map from metric name to tf.metrics return values for logging.
    decode_results: The LstmDecodeResults.

  Raises:
    ValueError: If `c_input` is provided.
  """
  if c_input is not None:
    raise ValueError(
        'Control sequence unsupported in HierarchicalLstmDecoder.')

  batch_size = x_input.shape[0].value

  x_length = lstm_utils.maybe_split_sequence_lengths(
      x_length, np.prod(self._level_lengths[:-1]), self._total_length)

  def _reshape_to_hierarchy(t):
    """Reshapes `t` so that its initial dimensions match the hierarchy."""
    # Exclude the final, core decoder length.
    level_lengths = self._level_lengths[:-1]
    t_shape = t.shape.as_list()
    t_rank = len(t_shape)
    hier_shape = [batch_size] + level_lengths
    if t_rank == 3:
      hier_shape += [-1] + t_shape[2:]
    elif t_rank != 2:
      # We only expect rank-2 for lengths and rank-3 for sequences.
      raise ValueError('Unexpected shape for tensor: %s' % t)
    hier_t = tf.reshape(t, hier_shape)
    # Move the batch dimension to after the hierarchical dimensions.
    num_levels = len(level_lengths)
    # BUGFIX: in Python 3, `range()` returns an immutable range object with
    # no `insert`/`pop` methods; it must be materialized as a list before
    # the permutation can be edited in place.
    perm = list(range(len(hier_shape)))
    perm.insert(num_levels, perm.pop(0))
    return tf.transpose(hier_t, perm)

  hier_input = _reshape_to_hierarchy(x_input)
  hier_target = _reshape_to_hierarchy(x_target)
  hier_length = _reshape_to_hierarchy(x_length)

  # Collects (r_loss, metric_map, decode_results) tuples, one per segment,
  # appended by `base_train_fn` as the hierarchical decode visits each leaf.
  loss_outputs = []

  def base_train_fn(embedding, hier_index):
    """Base function for training hierarchical decoder."""
    split_size = self._level_lengths[-1]
    split_input = hier_input[hier_index]
    split_target = hier_target[hier_index]
    split_length = hier_length[hier_index]

    res = self._core_decoder.reconstruction_loss(
        split_input, split_target, split_length, embedding)
    loss_outputs.append(res)
    decode_results = res[-1]

    if self._hierarchical_encoder:
      # Get the approximate "sample" from the model.
      # Start with the inputs the RNN saw (excluding the start token).
      samples = decode_results.rnn_input[:, 1:]
      # Pad to be the max length.
      samples = tf.pad(
          samples,
          [(0, 0), (0, split_size - tf.shape(samples)[1]), (0, 0)])
      samples.set_shape([batch_size, split_size, self._output_depth])
      # Set the final value based on the target, since the scheduled sampling
      # helper does not sample the final value.
      samples = lstm_utils.set_final(
          samples,
          split_length,
          lstm_utils.get_final(split_target, split_length, time_major=False),
          time_major=False)
      # Return the re-encoded sample.
      return self._hierarchical_encoder.level(0).encode(
          sequence=samples,
          sequence_length=split_length)
    elif self._disable_autoregression:
      return None
    else:
      return tf.concat(nest.flatten(decode_results.final_state), axis=-1)

  z = tf.zeros([batch_size, 0]) if z is None else z
  self._hierarchical_decode(z, base_train_fn)

  # Accumulate the split sequence losses.
  r_losses, metric_maps, decode_results = zip(*loss_outputs)

  # Merge the metric maps by passing through renamed values and taking the
  # mean across the splits.
  merged_metric_map = {}
  for metric_name in metric_maps[0]:
    metric_values = []
    for i, m in enumerate(metric_maps):
      merged_metric_map['segment/%03d/%s' % (i, metric_name)] = m[metric_name]
      metric_values.append(m[metric_name][0])
    merged_metric_map[metric_name] = (
        tf.reduce_mean(metric_values), tf.no_op())

  return (tf.reduce_sum(r_losses, axis=0),
          merged_metric_map,
          self._merge_decode_results(decode_results))