def build_graph(self):
  """Constructs the portion of the graph that belongs to this model."""

  tf.logging.info('Initializing melody RNN graph for scope %s', self.scope)

  with self.graph.as_default():
    with tf.device(lambda op: ''):
      with tf.variable_scope(self.scope):
        # Make an LSTM cell with the number and size of layers specified in
        # hparams.
        if self.note_rnn_type == 'basic_rnn':
          self.cell = events_rnn_graph.make_rnn_cell(
              self.hparams.rnn_layer_sizes)
        else:
          self.cell = rl_tuner_ops.make_rnn_cell(
              self.hparams.rnn_layer_sizes)

        # Shape of melody_sequence is batch size, melody length, number of
        # output note actions.
        self.melody_sequence = tf.placeholder(
            tf.float32,
            [None, None, self.hparams.one_hot_length],
            name='melody_sequence')
        self.lengths = tf.placeholder(tf.int32, [None], name='lengths')
        self.initial_state = tf.placeholder(
            tf.float32,
            [None, self.cell.state_size],
            name='initial_state')

        if self.training_file_list is not None:
          # Set up a tf queue to read melodies from the training data
          # tfrecord.
          (self.train_sequence,
           self.train_labels,
           self.train_lengths) = sequence_example_lib.get_padded_batch(
               self.training_file_list,
               self.hparams.batch_size,
               self.hparams.one_hot_length)

        # Closure function is used so that this part of the graph can be
        # re-run in multiple places, such as __call__.
        def run_network_on_melody(m_seq,
                                  lens,
                                  initial_state,
                                  swap_memory=True,
                                  parallel_iterations=1):
          """Internal function that defines the RNN network structure.

          Args:
            m_seq: A batch of melody sequences of one-hot notes.
            lens: Lengths of the melody_sequences.
            initial_state: Vector representing the initial state of the RNN.
            swap_memory: Passed to tf.nn.dynamic_rnn; allows tensors needed
              for backprop to be swapped from GPU to host memory, reducing
              GPU memory usage for long sequences.
            parallel_iterations: Argument to tf.nn.dynamic_rnn.
          Returns:
            Output of network (either softmax or logits) and RNN state.
          """
          outputs, final_state = tf.nn.dynamic_rnn(
              self.cell,
              m_seq,
              sequence_length=lens,
              initial_state=initial_state,
              swap_memory=swap_memory,
              parallel_iterations=parallel_iterations)

          outputs_flat = tf.reshape(
              outputs, [-1, self.hparams.rnn_layer_sizes[-1]])
          if self.note_rnn_type == 'basic_rnn':
            linear_layer = tf.contrib.layers.linear
          else:
            linear_layer = tf.contrib.layers.legacy_linear
          logits_flat = linear_layer(
              outputs_flat, self.hparams.one_hot_length)
          return logits_flat, final_state

        (self.logits, self.state_tensor) = run_network_on_melody(
            self.melody_sequence, self.lengths, self.initial_state)
        self.softmax = tf.nn.softmax(self.logits)

        self.run_network_on_melody = run_network_on_melody

      if self.training_file_list is not None:
        # Does not recreate the model architecture but rather uses it to feed
        # data from the training queue through the model.
        with tf.variable_scope(self.scope, reuse=True):
          zero_state = self.cell.zero_state(
              batch_size=self.hparams.batch_size, dtype=tf.float32)

          (self.train_logits, self.train_state) = run_network_on_melody(
              self.train_sequence, self.train_lengths, zero_state)
          self.train_softmax = tf.nn.softmax(self.train_logits)
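# --- Usage sketch (illustrative; not part of the original source) ---
# A minimal example of how the tensors created by build_graph() above might
# be evaluated once the variables are initialized or restored. `model` stands
# for an instance of the class owning build_graph(); `one_hot_batch` (shape
# [batch, time, one_hot_length]) and `batch_lengths` are hypothetical numpy
# inputs, and `np` is assumed to be numpy imported at module level. Assumes
# the flat (non-tuple) cell state implied by the initial_state placeholder.
def example_run_softmax(model, one_hot_batch, batch_lengths):
  """Feeds one batch of melodies through the net and returns the softmax."""
  with model.graph.as_default():
    with tf.Session() as sess:
      # In practice the Note RNN weights would be restored from a checkpoint;
      # initializing from scratch here is only for illustration.
      sess.run(tf.initialize_all_variables())
      zero_state = np.zeros(
          (one_hot_batch.shape[0], model.cell.state_size), dtype=np.float32)
      softmax, next_state = sess.run(
          [model.softmax, model.state_tensor],
          feed_dict={model.melody_sequence: one_hot_batch,
                     model.lengths: batch_lengths,
                     model.initial_state: zero_state})
  return softmax, next_state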
def build_graph(mode, hparams, encoder_decoder, sequence_example_file=None):
  """Builds the TensorFlow graph.

  Args:
    mode: 'train', 'eval', or 'generate'. Only mode related ops are added to
        the graph.
    hparams: A tf_lib.HParams object containing the hyperparameters to use.
    encoder_decoder: The MelodyEncoderDecoder being used by the model.
    sequence_example_file: A string path to a TFRecord file containing
        tf.train.SequenceExamples. Only needed for training and evaluation.

  Returns:
    A tf.Graph instance which contains the TF ops.

  Raises:
    ValueError: If mode is not 'train', 'eval', or 'generate', or if
        sequence_example_file does not match a file when mode is 'train' or
        'eval'.
  """
  if mode not in ('train', 'eval', 'generate'):
    raise ValueError('The mode parameter must be \'train\', \'eval\', '
                     'or \'generate\'. The mode parameter was: %s' % mode)

  tf.logging.info('hparams = %s', hparams.values())

  input_size = encoder_decoder.input_size
  num_classes = encoder_decoder.num_classes
  no_event_label = encoder_decoder.no_event_label

  with tf.Graph().as_default() as graph:
    inputs, labels, lengths = None, None, None
    state_is_tuple = True

    if mode == 'train' or mode == 'eval':
      inputs, labels, lengths = sequence_example_lib.get_padded_batch(
          [sequence_example_file], hparams.batch_size, input_size)

    elif mode == 'generate':
      inputs = tf.placeholder(tf.float32,
                              [hparams.batch_size, None, input_size])
      # If state_is_tuple is True, the output RNN cell state will be a tuple
      # instead of a tensor. During training and evaluation this improves
      # performance. However, during generation, the RNN cell state is fed
      # back into the graph with a feed dict. Feed dicts require passed in
      # values to be tensors and not tuples, so state_is_tuple is set to
      # False.
      state_is_tuple = False

    cells = []
    for num_units in hparams.rnn_layer_sizes:
      cell = tf.nn.rnn_cell.BasicLSTMCell(num_units,
                                          state_is_tuple=state_is_tuple)
      cell = tf.nn.rnn_cell.DropoutWrapper(
          cell, output_keep_prob=hparams.dropout_keep_prob)
      cells.append(cell)

    cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=state_is_tuple)

    if hparams.attn_length:
      cell = tf.contrib.rnn.AttentionCellWrapper(
          cell, hparams.attn_length, state_is_tuple=state_is_tuple)

    initial_state = cell.zero_state(hparams.batch_size, tf.float32)

    outputs, final_state = tf.nn.dynamic_rnn(
        cell, inputs, lengths, initial_state, parallel_iterations=1,
        swap_memory=True)

    outputs_flat = tf.reshape(outputs, [-1, hparams.rnn_layer_sizes[-1]])
    logits_flat = tf.contrib.layers.linear(outputs_flat, num_classes)

    if mode == 'train' or mode == 'eval':
      if hparams.skip_first_n_losses:
        logits = tf.reshape(logits_flat,
                            [hparams.batch_size, -1, num_classes])
        logits = logits[:, hparams.skip_first_n_losses:, :]
        logits_flat = tf.reshape(logits, [-1, num_classes])
        labels = labels[:, hparams.skip_first_n_losses:]

      labels_flat = tf.reshape(labels, [-1])
      loss = tf.reduce_mean(
          tf.nn.sparse_softmax_cross_entropy_with_logits(
              logits_flat, labels_flat))
      perplexity = tf.exp(loss)

      correct_predictions = tf.to_float(
          tf.nn.in_top_k(logits_flat, labels_flat, 1))
      accuracy = tf.reduce_mean(correct_predictions) * 100

      event_positions = tf.to_float(
          tf.not_equal(labels_flat, no_event_label))
      event_accuracy = tf.truediv(
          tf.reduce_sum(tf.mul(correct_predictions, event_positions)),
          tf.reduce_sum(event_positions)) * 100

      no_event_positions = tf.to_float(
          tf.equal(labels_flat, no_event_label))
      no_event_accuracy = tf.truediv(
          tf.reduce_sum(tf.mul(correct_predictions, no_event_positions)),
          tf.reduce_sum(no_event_positions)) * 100

      global_step = tf.Variable(0, trainable=False, name='global_step')

      tf.add_to_collection('loss', loss)
      tf.add_to_collection('perplexity', perplexity)
      tf.add_to_collection('accuracy', accuracy)
      tf.add_to_collection('global_step', global_step)

      summaries = [
          tf.scalar_summary('loss', loss),
          tf.scalar_summary('perplexity', perplexity),
          tf.scalar_summary('accuracy', accuracy),
          tf.scalar_summary('event_accuracy', event_accuracy),
          tf.scalar_summary('no_event_accuracy', no_event_accuracy),
      ]

      if mode == 'train':
        learning_rate = tf.train.exponential_decay(
            hparams.initial_learning_rate, global_step, hparams.decay_steps,
            hparams.decay_rate, staircase=True, name='learning_rate')

        opt = tf.train.AdamOptimizer(learning_rate)
        params = tf.trainable_variables()
        gradients = tf.gradients(loss, params)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients,
                                                      hparams.clip_norm)
        train_op = opt.apply_gradients(zip(clipped_gradients, params),
                                       global_step)
        tf.add_to_collection('learning_rate', learning_rate)
        tf.add_to_collection('train_op', train_op)

        summaries.append(tf.scalar_summary('learning_rate', learning_rate))

      if mode == 'eval':
        summary_op = tf.merge_summary(summaries)
        tf.add_to_collection('summary_op', summary_op)

    elif mode == 'generate':
      if hparams.temperature != 1.0:
        logits_flat /= hparams.temperature

      softmax_flat = tf.nn.softmax(logits_flat)
      softmax = tf.reshape(softmax_flat,
                           [hparams.batch_size, -1, num_classes])

      tf.add_to_collection('inputs', inputs)
      tf.add_to_collection('initial_state', initial_state)
      tf.add_to_collection('final_state', final_state)
      tf.add_to_collection('softmax', softmax)

  return graph
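# --- Usage sketch (illustrative; not part of the original source) ---
# A rough example of how build_graph() above might be driven in 'generate'
# mode: the graph is built once, the tensors stored in collections are looked
# up, and a primer batch is pushed through to get a softmax over the next
# event. `hparams` and `encoder_decoder` stand for objects configured
# elsewhere; `primer_inputs` is a hypothetical numpy array of shape
# [hparams.batch_size, time, input_size].
def example_generate_step(hparams, encoder_decoder, primer_inputs):
  """Runs one forward pass of the generation graph on a primer batch."""
  graph = build_graph('generate', hparams, encoder_decoder)
  with graph.as_default():
    inputs = tf.get_collection('inputs')[0]
    initial_state = tf.get_collection('initial_state')[0]
    final_state = tf.get_collection('final_state')[0]
    softmax = tf.get_collection('softmax')[0]
    with tf.Session(graph=graph) as sess:
      # A real generator would restore trained weights with tf.train.Saver;
      # initializing from scratch here keeps the sketch self-contained.
      sess.run(tf.initialize_all_variables())
      # initial_state is cell.zero_state(...); evaluate it once, then feed it
      # back in (feeding a non-placeholder tensor is allowed in graph mode).
      zero_state = sess.run(initial_state)
      softmax_out, next_state = sess.run(
          [softmax, final_state],
          feed_dict={inputs: primer_inputs, initial_state: zero_state})
  return softmax_out, next_state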