def _decode(self, z, helper, input_shape, max_length=None):
  """Decodes the given batch of latent vectors, which may be 0-length.

  Args:
    z: Batch of latent vectors, sized `[batch_size, z_size]`, where `z_size`
      may be 0 for unconditioned decoding.
    helper: A seq2seq.Helper to use.
    input_shape: The shape of each model input vector passed to the decoder.
    max_length: (Optional) The maximum iterations to decode.

  Returns:
    results: The LstmDecodeResults.
  """
  initial_state = lstm_utils.initial_cell_state_from_embedding(
      self._dec_cell, z, name='decoder/z_to_initial_state')

  decoder = lstm_utils.Seq2SeqLstmDecoder(
      self._dec_cell,
      helper,
      initial_state=initial_state,
      input_shape=input_shape,
      output_layer=self._output_layer)
  final_output, final_state, final_lengths = contrib_seq2seq.dynamic_decode(
      decoder,
      maximum_iterations=max_length,
      swap_memory=True,
      scope='decoder')
  results = lstm_utils.LstmDecodeResults(
      rnn_input=final_output.rnn_input[:, :, :self._output_depth],
      rnn_output=final_output.rnn_output,
      samples=final_output.sample_id,
      final_state=final_state,
      final_sequence_lengths=final_lengths)
  return results
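
# Sketch (an assumption, not part of the original source): one way `_decode`
# might be driven for teacher-forced reconstruction during training, using
# `tf.contrib.seq2seq.TrainingHelper` to feed the ground-truth inputs back in
# at each step. `x_input` (`[batch_size, max_time, depth]`) and `x_length`
# (`[batch_size]`) are illustrative names, not from the original.
def _reconstruction_sketch(self, z, x_input, x_length):
  """Hypothetical helper: teacher-forced decode of `x_input` given `z`."""
  helper = contrib_seq2seq.TrainingHelper(
      inputs=x_input, sequence_length=x_length)
  decode_results = self._decode(
      z, helper=helper, input_shape=x_input.shape[2:])
  # Logits over the output space, sized `[batch_size, max_time, output_depth]`.
  return decode_results.rnn_output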
def _testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN(  # pylint:disable=invalid-name
    self, use_sequence_length):
  sequence_length = [3, 4, 3, 1, 0]
  batch_size = 5
  max_time = 8
  input_depth = 7
  cell_depth = 10
  max_out = max(sequence_length)

  with self.session(use_gpu=True) as sess:
    inputs = np.random.randn(batch_size, max_time,
                             input_depth).astype(np.float32)

    cell = tf.nn.rnn_cell.LSTMCell(cell_depth)
    zero_state = cell.zero_state(dtype=tf.float32, batch_size=batch_size)
    helper = seq2seq.TrainingHelper(inputs, sequence_length)
    my_decoder = seq2seq.BasicDecoder(
        cell=cell, helper=helper, initial_state=zero_state)

    # Match the variable scope of dynamic_rnn below so we end up
    # using the same variables
    with tf.variable_scope("root") as scope:
      final_decoder_outputs, final_decoder_state, _ = seq2seq.dynamic_decode(
          my_decoder,
          # impute_finished=True ensures outputs and final state
          # match those of dynamic_rnn called with sequence_length not None
          impute_finished=use_sequence_length,
          scope=scope)

    with tf.variable_scope(scope, reuse=True) as scope:
      final_rnn_outputs, final_rnn_state = tf.nn.dynamic_rnn(
          cell,
          inputs,
          sequence_length=sequence_length if use_sequence_length else None,
          initial_state=zero_state,
          scope=scope)

    sess.run(tf.global_variables_initializer())
    sess_results = sess.run({
        "final_decoder_outputs": final_decoder_outputs,
        "final_decoder_state": final_decoder_state,
        "final_rnn_outputs": final_rnn_outputs,
        "final_rnn_state": final_rnn_state
    })

    # Decoder only runs out to max_out; ensure values are identical
    # to dynamic_rnn, which also zeros out outputs and passes along state.
    self.assertAllClose(sess_results["final_decoder_outputs"].rnn_output,
                        sess_results["final_rnn_outputs"][:, 0:max_out, :])
    if use_sequence_length:
      self.assertAllClose(sess_results["final_decoder_state"],
                          sess_results["final_rnn_state"])
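
# Sketch (an assumption, not in the original excerpt): the underscore-prefixed
# helper above is parameterized on `use_sequence_length`, so it would be
# exercised by concrete test cases like these; the method names are
# illustrative.
def testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNNWithSeqLen(self):
  self._testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN(
      use_sequence_length=True)

def testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNNNoSeqLen(self):
  self._testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN(
      use_sequence_length=False)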
def _get_state(self, inputs, lengths=None, initial_state=None):
  """Computes the state of the RNN-NADE (NADE bias parameters and RNN state).

  Args:
    inputs: A batch of sequences to compute the state from, sized
      `[batch_size, max(lengths), num_dims]` or `[batch_size, num_dims]`.
    lengths: The length of each sequence, sized `[batch_size]`.
    initial_state: An RnnNadeStateTuple, the initial state of the RNN-NADE, or
      None if the zero state should be used.

  Returns:
    final_state: An RnnNadeStateTuple, the final state of the RNN-NADE.
  """
  batch_size = int(inputs.shape[0])

  if lengths is None:
    lengths = tf.tile(tf.shape(inputs)[1:2], [batch_size])
  if initial_state is None:
    initial_rnn_state = self._get_rnn_zero_state(batch_size)
  else:
    initial_rnn_state = initial_state.rnn_state

  helper = contrib_seq2seq.TrainingHelper(
      inputs=inputs, sequence_length=lengths)

  decoder = contrib_seq2seq.BasicDecoder(
      cell=self._rnn_cell,
      helper=helper,
      initial_state=initial_rnn_state,
      output_layer=self._fc_layer)

  final_outputs, final_rnn_state = contrib_seq2seq.dynamic_decode(
      decoder)[0:2]

  # Flatten time dimension.
  final_outputs_flat = magenta.common.flatten_maybe_padded_sequences(
      final_outputs.rnn_output, lengths)

  b_enc, b_dec = tf.split(
      final_outputs_flat, [self._nade.num_hidden, self._nade.num_dims],
      axis=1)

  return RnnNadeStateTuple(b_enc, b_dec, final_rnn_state)
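
# Sketch (an assumption, not in the original excerpt): `_get_state` can prime
# the RNN-NADE on a batch of primer sequences and then be re-entered one frame
# at a time, threading the returned `RnnNadeStateTuple` back in as
# `initial_state`. The `tf.expand_dims` adds a length-1 time dimension so the
# TrainingHelper inside `_get_state` can consume a single frame; the names
# below are illustrative.
def _prime_and_step_sketch(self, primer, primer_lengths, next_frame):
  """Hypothetical helper: state after the primer, advanced by one frame."""
  primed_state = self._get_state(primer, lengths=primer_lengths)
  return self._get_state(
      tf.expand_dims(next_frame, axis=1), initial_state=primed_state)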
def _testDynamicDecodeRNN(self, time_major, maximum_iterations=None):  # pylint:disable=invalid-name
  sequence_length = [3, 4, 3, 1, 0]
  batch_size = 5
  max_time = 8
  input_depth = 7
  cell_depth = 10
  max_out = max(sequence_length)

  with self.session(use_gpu=True) as sess:
    if time_major:
      inputs = np.random.randn(max_time, batch_size,
                               input_depth).astype(np.float32)
    else:
      inputs = np.random.randn(batch_size, max_time,
                               input_depth).astype(np.float32)
    cell = tf.nn.rnn_cell.LSTMCell(cell_depth)
    helper = seq2seq.TrainingHelper(
        inputs, sequence_length, time_major=time_major)
    my_decoder = seq2seq.BasicDecoder(
        cell=cell,
        helper=helper,
        initial_state=cell.zero_state(
            dtype=tf.float32, batch_size=batch_size))

    final_outputs, final_state, final_sequence_length = (
        seq2seq.dynamic_decode(
            my_decoder,
            output_time_major=time_major,
            maximum_iterations=maximum_iterations))

    def _t(shape):
      if time_major:
        return (shape[1], shape[0]) + shape[2:]
      return shape

    self.assertIsInstance(final_outputs, seq2seq.BasicDecoderOutput)
    self.assertIsInstance(final_state, tf.nn.rnn_cell.LSTMStateTuple)

    self.assertEqual(
        (batch_size,), tuple(final_sequence_length.get_shape().as_list()))
    self.assertEqual(
        _t((batch_size, None, cell_depth)),
        tuple(final_outputs.rnn_output.get_shape().as_list()))
    self.assertEqual(
        _t((batch_size, None)),
        tuple(final_outputs.sample_id.get_shape().as_list()))

    sess.run(tf.global_variables_initializer())
    sess_results = sess.run({
        "final_outputs": final_outputs,
        "final_state": final_state,
        "final_sequence_length": final_sequence_length,
    })

    # Mostly a smoke test
    time_steps = max_out
    expected_length = sequence_length
    if maximum_iterations is not None:
      time_steps = min(max_out, maximum_iterations)
      expected_length = [min(x, maximum_iterations) for x in expected_length]
    self.assertEqual(
        _t((batch_size, time_steps, cell_depth)),
        sess_results["final_outputs"].rnn_output.shape)
    self.assertEqual(
        _t((batch_size, time_steps)),
        sess_results["final_outputs"].sample_id.shape)
    self.assertCountEqual(expected_length,
                          sess_results["final_sequence_length"])
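
# Sketch (an assumption, not in the original excerpt): concrete test cases
# that would exercise the parameterized helper above across both tensor
# layouts and a capped iteration count; the method names are illustrative.
def testDynamicDecodeRNNBatchMajor(self):
  self._testDynamicDecodeRNN(time_major=False)

def testDynamicDecodeRNNTimeMajor(self):
  self._testDynamicDecodeRNN(time_major=True)

def testDynamicDecodeRNNOneMaxIter(self):
  # With maximum_iterations=1, outputs are truncated to a single time step and
  # every nonzero expected length collapses to 1.
  self._testDynamicDecodeRNN(time_major=True, maximum_iterations=1)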