def build_graph(self):
  """Constructs the portion of the graph that belongs to this model."""

  tf.logging.info('Initializing melody RNN graph for scope %s', self.scope)

  with self.graph.as_default():
    with tf.device(lambda op: ''):
      with tf.variable_scope(self.scope):
        # Make an LSTM cell with the number and size of layers specified in
        # hparams.
        if self.note_rnn_type == 'basic_rnn':
          self.cell = events_rnn_graph.make_rnn_cell(
              self.hparams.rnn_layer_sizes)
        else:
          self.cell = rl_tuner_ops.make_rnn_cell(self.hparams.rnn_layer_sizes)

        # Shape of melody_sequence is batch size, melody length, number of
        # output note actions.
        self.melody_sequence = tf.placeholder(
            tf.float32,
            [None, None, self.hparams.one_hot_length],
            name='melody_sequence')
        self.lengths = tf.placeholder(tf.int32, [None], name='lengths')
        self.initial_state = tf.placeholder(
            tf.float32, [None, self.cell.state_size], name='initial_state')

        if self.training_file_list is not None:
          # Set up a tf queue to read melodies from the training data
          # tfrecord.
          (self.train_sequence,
           self.train_labels,
           self.train_lengths) = sequence_example_lib.get_padded_batch(
               self.training_file_list, self.hparams.batch_size,
               self.hparams.one_hot_length)

        # A closure is used so that this part of the graph can be re-run in
        # multiple places, such as __call__.
        def run_network_on_melody(m_seq,
                                  lens,
                                  initial_state,
                                  swap_memory=True,
                                  parallel_iterations=1):
          """Internal function that defines the RNN network structure.

          Args:
            m_seq: A batch of melody sequences of one-hot notes.
            lens: Lengths of the melody sequences.
            initial_state: Vector representing the initial state of the RNN.
            swap_memory: If True, swap tensors between GPU and CPU during the
              RNN computation, trading some speed for lower GPU memory use.
              Passed to tf.nn.dynamic_rnn.
            parallel_iterations: Argument to tf.nn.dynamic_rnn.
          Returns:
            Flattened logits from the network and the final RNN state.
          """
          outputs, final_state = tf.nn.dynamic_rnn(
              self.cell,
              m_seq,
              sequence_length=lens,
              initial_state=initial_state,
              swap_memory=swap_memory,
              parallel_iterations=parallel_iterations)

          outputs_flat = tf.reshape(
              outputs, [-1, self.hparams.rnn_layer_sizes[-1]])
          if self.note_rnn_type == 'basic_rnn':
            linear_layer = tf.contrib.layers.linear
          else:
            linear_layer = tf.contrib.layers.legacy_linear
          logits_flat = linear_layer(
              outputs_flat, self.hparams.one_hot_length)
          return logits_flat, final_state

        (self.logits, self.state_tensor) = run_network_on_melody(
            self.melody_sequence, self.lengths, self.initial_state)
        self.softmax = tf.nn.softmax(self.logits)

        self.run_network_on_melody = run_network_on_melody

      if self.training_file_list is not None:
        # Does not recreate the model architecture but rather uses it to feed
        # data from the training queue through the model.
        with tf.variable_scope(self.scope, reuse=True):
          zero_state = self.cell.zero_state(
              batch_size=self.hparams.batch_size, dtype=tf.float32)

          (self.train_logits, self.train_state) = run_network_on_melody(
              self.train_sequence, self.train_lengths, zero_state)
          self.train_softmax = tf.nn.softmax(self.train_logits)
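# Minimal usage sketch (illustrative, not part of the original source): shows
# how the placeholders and tensors defined in build_graph() could be fed in a
# TF 1.x session to get next-note probabilities for a single one-hot note.
# Assumes `model` is an object on which build_graph() has already run, exposing
# the attributes created above (graph, hparams, cell, melody_sequence, lengths,
# initial_state, softmax, state_tensor); the helper name and variable names
# below are hypothetical.
import numpy as np
import tensorflow as tf


def sample_next_note_probs(model):
  """Runs one step of the melody RNN on a single one-hot starting note."""
  with model.graph.as_default():
    with tf.Session(graph=model.graph) as sess:
      sess.run(tf.global_variables_initializer())

      batch_size = 1
      # One-hot encoding of a single starting note (class index 0 here).
      melody = np.zeros((batch_size, 1, model.hparams.one_hot_length),
                        dtype=np.float32)
      melody[:, 0, 0] = 1.0
      lengths = np.ones((batch_size,), dtype=np.int32)
      # Assumes the cell's state is a flat vector, matching the
      # [None, cell.state_size] shape of the initial_state placeholder.
      zero_state = np.zeros((batch_size, model.cell.state_size),
                            dtype=np.float32)

      probs, next_state = sess.run(
          [model.softmax, model.state_tensor],
          feed_dict={model.melody_sequence: melody,
                     model.lengths: lengths,
                     model.initial_state: zero_state})
      return probs, next_state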