def testGetFinal(self):
    """Check `lstm_utils.get_final` for batch-major and time-major inputs.

    Each batch row should yield the step at index `length - 1`; the row
    with length 0 is expected to yield its first step.
    """
    batch_major = np.arange(40).reshape(4, 5, 2)
    time_major_input = batch_major.transpose(1, 0, 2)
    lengths = np.array([0, 1, 2, 5])
    expected_values = np.array([[0, 1], [10, 11], [22, 23], [38, 39]])

    with self.test_session():
        # Same data in both layouts must give the same final steps.
        for inputs, is_time_major in ((batch_major, False),
                                      (time_major_input, True)):
            final = lstm_utils.get_final(
                inputs, lengths, time_major=is_time_major)
            self.assertAllEqual(expected_values, final.eval())
def testGetFinal(self):
    """Verify that `get_final` returns each row's last valid step.

    Row lengths are 0, 1, 2 and 5 over five steps, so the expected steps
    are indices 0, 0, 1 and 4. The check runs once on the batch-major
    array and once on its time-major transpose.
    """
    with self.test_session():
        seqs = np.arange(40).reshape((4, 5, 2))
        seq_lengths = np.array([0, 1, 2, 5])
        want = np.array([[0, 1], [10, 11], [22, 23], [38, 39]])

        got_batch_major = lstm_utils.get_final(
            seqs, seq_lengths, time_major=False).eval()
        self.assertAllEqual(want, got_batch_major)

        # Swap the batch and time axes and ask for time-major handling.
        seqs_time_major = np.transpose(seqs, [1, 0, 2])
        got_time_major = lstm_utils.get_final(
            seqs_time_major, seq_lengths, time_major=True).eval()
        self.assertAllEqual(want, got_time_major)
def encode(self, sequence, sequence_length):
    """Encode a batch with the stacked bidirectional LSTM.

    Args:
      sequence: Batch-major input tensor (it is passed to the dynamic RNN
        with `time_major=False`, and transposed with `[1, 0, 2]` before
        being fed to the cuDNN layers).
      sequence_length: Per-example sequence lengths.

    Returns:
      The final forward output and the final backward output of the last
      layer, concatenated along axis 1.
    """
    fw_layers, bw_layers = self._cells

    if not self._use_cudnn:
        _, fw_states, bw_states = rnn.stack_bidirectional_dynamic_rnn(
            fw_layers,
            bw_layers,
            sequence,
            sequence_length=sequence_length,
            time_major=False,
            dtype=tf.float32,
            scope=self._name_or_scope)
        # Read the outputs (h) from the states: the backward outputs in the
        # returned output tensor are reversed back to the input order.
        final_fw = fw_states[-1][-1].h
        final_bw = bw_states[-1][-1].h
        return tf.concat([final_fw, final_bw], 1)

    # Stacked bidirectional LSTM for variable-length sequences, built by
    # hand because the CudnnLSTM layer does not support them.
    def _reverse_valid_steps(time_major_tensor):
        return tf.reverse_sequence(
            time_major_tensor, sequence_length, seq_axis=0, batch_axis=1)

    layer_input = tf.transpose(sequence, [1, 0, 2])
    for fw_lstm, bw_lstm in zip(fw_layers, bw_layers):
        fw_out, _ = fw_lstm(layer_input, training=self._is_training)
        reversed_input = _reverse_valid_steps(layer_input)
        bw_out, _ = bw_lstm(reversed_input, training=self._is_training)
        bw_out = _reverse_valid_steps(bw_out)
        # The next layer consumes both directions side by side.
        layer_input = tf.concat([fw_out, bw_out], axis=2)

    final_fw = lstm_utils.get_final(fw_out, sequence_length)
    # bw_out is already back in input order, so its final output sits at
    # the first time step.
    final_bw = bw_out[0]
    return tf.concat([final_fw, final_bw], 1)
def encode(self, sequence, sequence_length):
    """Encode a batch with the LSTM and return each example's final output.

    Args:
      sequence: Batch-major input tensor; it is transposed with `[1, 0, 2]`
        to time-major before being fed to the LSTM.
      sequence_length: Per-example sequence lengths.

    Returns:
      The LSTM output at the last valid step of every example, as selected
      by `lstm_utils.get_final`.
    """
    # Convert to time-major.
    sequence = tf.transpose(sequence, [1, 0, 2])
    if self._use_cudnn:
        outputs, _ = self._cudnn_lstm(sequence, training=self._is_training)
    else:
        outputs, _ = tf.nn.dynamic_rnn(
            self._cell,
            sequence,
            sequence_length,
            dtype=tf.float32,
            time_major=True,
            scope=self._name_or_scope)
    # Select the output at each example's own last step rather than
    # `outputs[-1]`: `dynamic_rnn` zeroes outputs past `sequence_length`, so
    # indexing the last time step would return all zeros for any example
    # shorter than the padded length. This also keeps both branches
    # consistent.
    # NOTE(review): `get_final` may need a statically known batch size on
    # `sequence_length` - confirm for callers of the non-cuDNN path.
    return lstm_utils.get_final(outputs, sequence_length)
def base_train_fn(embedding, hier_index):
    """Run the core decoder's training loss on one segment of the hierarchy.

    The result of `reconstruction_loss` is appended to the enclosing
    `loss_outputs` list as a side effect.

    Args:
      embedding: Passed through to the core decoder's `reconstruction_loss`.
      hier_index: Index selecting this segment's entry from `hier_input`,
        `hier_target`, `hier_length` and (when present) `hier_control`.

    Returns:
      With a hierarchical encoder: the level-0 re-encoding of an approximate
      sample. Without one: `None` if autoregression is disabled, otherwise
      the flattened final decoder state concatenated along the last axis.
    """
    max_len = self._level_lengths[-1]
    seg_input = hier_input[hier_index]
    seg_target = hier_target[hier_index]
    seg_length = hier_length[hier_index]
    if hier_control is not None:
        seg_control = hier_control[hier_index]
    else:
        seg_control = None

    loss_result = self._core_decoder.reconstruction_loss(
        seg_input, seg_target, seg_length, embedding, seg_control)
    loss_outputs.append(loss_result)
    decoded = loss_result[-1]

    if not self._hierarchical_encoder:
        if self._disable_autoregression:
            return None
        return tf.concat(tf.nest.flatten(decoded.final_state), axis=-1)

    # Build the approximate "sample" from the model, starting from the
    # inputs the RNN saw with the start token dropped.
    approx = decoded.rnn_input[:, 1:]
    # Pad along the time axis up to the maximum segment length.
    pad_amount = max_len - tf.shape(approx)[1]
    approx = tf.pad(approx, [(0, 0), (0, pad_amount), (0, 0)])
    approx.set_shape([batch_size, max_len, self._output_depth])
    # The scheduled sampling helper does not sample the final value, so
    # take it from the target instead.
    target_final = lstm_utils.get_final(
        seg_target, seg_length, time_major=False)
    approx = lstm_utils.set_final(
        approx, seg_length, target_final, time_major=False)
    # Return the re-encoded sample.
    return self._hierarchical_encoder.level(0).encode(
        sequence=approx, sequence_length=seg_length)