Example #1
    def build_graph(self):
        """Constructs the portion of the graph that belongs to this model."""

        tf.logging.info('Initializing melody RNN graph for scope %s',
                        self.scope)

        with self.graph.as_default():
            with tf.device(lambda op: ''):
                with tf.variable_scope(self.scope):
                    # Make an LSTM cell with the number and size of layers specified in
                    # hparams.
                    if self.note_rnn_type == 'basic_rnn':
                        self.cell = events_rnn_graph.make_rnn_cell(
                            self.hparams.rnn_layer_sizes)
                    else:
                        self.cell = rl_tuner_ops.make_rnn_cell(
                            self.hparams.rnn_layer_sizes)
                    # Shape of melody_sequence is batch size, melody length, number of
                    # output note actions.
                    self.melody_sequence = tf.placeholder(
                        tf.float32, [None, None, self.hparams.one_hot_length],
                        name='melody_sequence')
                    self.lengths = tf.placeholder(tf.int32, [None],
                                                  name='lengths')
                    self.initial_state = tf.placeholder(
                        tf.float32, [None, self.cell.state_size],
                        name='initial_state')

                    if self.training_file_list is not None:
                        # Set up a tf queue to read melodies from the training data tfrecord
                        (self.train_sequence, self.train_labels,
                         self.train_lengths
                         ) = sequence_example_lib.get_padded_batch(
                             self.training_file_list, self.hparams.batch_size,
                             self.hparams.one_hot_length)

                    # Closure function is used so that this part of the graph can be
                    # re-run in multiple places, such as __call__.
                    def run_network_on_melody(m_seq,
                                              lens,
                                              initial_state,
                                              swap_memory=True,
                                              parallel_iterations=1):
                        """Internal function that defines the RNN network structure.

            Args:
              m_seq: A batch of melody sequences of one-hot notes.
              lens: Lengths of the melody_sequences.
              initial_state: Vector representing the initial state of the RNN.
              swap_memory: Uses more memory and is faster.
              parallel_iterations: Argument to tf.nn.dynamic_rnn.
            Returns:
              Output of network (either softmax or logits) and RNN state.
            """
                        outputs, final_state = tf.nn.dynamic_rnn(
                            self.cell,
                            m_seq,
                            sequence_length=lens,
                            initial_state=initial_state,
                            swap_memory=swap_memory,
                            parallel_iterations=parallel_iterations)

                        outputs_flat = tf.reshape(
                            outputs, [-1, self.hparams.rnn_layer_sizes[-1]])
                        if self.note_rnn_type == 'basic_rnn':
                            linear_layer = tf.contrib.layers.linear
                        else:
                            linear_layer = tf.contrib.layers.legacy_linear
                        logits_flat = linear_layer(outputs_flat,
                                                   self.hparams.one_hot_length)
                        return logits_flat, final_state

                    (self.logits, self.state_tensor) = run_network_on_melody(
                        self.melody_sequence, self.lengths, self.initial_state)
                    self.softmax = tf.nn.softmax(self.logits)

                    self.run_network_on_melody = run_network_on_melody

                if self.training_file_list is not None:
                    # Does not recreate the model architecture but rather uses it to feed
                    # data from the training queue through the model.
                    with tf.variable_scope(self.scope, reuse=True):
                        zero_state = self.cell.zero_state(
                            batch_size=self.hparams.batch_size,
                            dtype=tf.float32)

                        (self.train_logits,
                         self.train_state) = run_network_on_melody(
                             self.train_sequence, self.train_lengths,
                             zero_state)
                        self.train_softmax = tf.nn.softmax(self.train_logits)
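
The `tf.variable_scope(self.scope, reuse=True)` block at the end is the load-bearing trick: it feeds the training queue through the same weights the placeholder path created, instead of building a second copy of the model. A minimal TF1-style sketch of that reuse pattern, assuming the tensorflow.compat.v1 shim (all names below are illustrative):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

def linear(x, units):
    # Under reuse=True, tf.get_variable returns the already-created
    # 'w' for this scope instead of allocating a new one.
    w = tf.get_variable('w', [4, units])  # 4 = placeholder feature size
    return tf.matmul(x, w)

x = tf.placeholder(tf.float32, [None, 4])
with tf.variable_scope('model'):
    y_data_path = linear(x, 8)            # creates model/w
with tf.variable_scope('model', reuse=True):
    y_queue_path = linear(x, 8)           # shares model/w; no new variables
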
Example #2
def build_graph(mode, config, sequence_example_file_paths=None):
  """Builds the TensorFlow graph.

  Args:
    mode: 'train', 'eval', or 'generate'. Only mode-related ops are added to
        the graph.
    config: An EventSequenceRnnConfig containing the encoder/decoder and HParams
        to use.
    sequence_example_file_paths: A list of paths to TFRecord files containing
        tf.train.SequenceExample protos. Only needed for training and
        evaluation. May be a sharded file.

  Returns:
    A tf.Graph instance which contains the TF ops.

  Raises:
    ValueError: If mode is not 'train', 'eval', or 'generate'.
  """
  if mode not in ('train', 'eval', 'generate'):
    raise ValueError("The mode parameter must be 'train', 'eval', "
                     "or 'generate'. The mode parameter was: %s" % mode)

  hparams = config.hparams
  encoder_decoder = config.encoder_decoder

  tf.logging.info('hparams = %s', hparams.values())

  input_size = encoder_decoder.input_size

  with tf.Graph().as_default() as graph:
    inputs, lengths = None, None

    if mode == 'train' or mode == 'eval':
      inputs, _, lengths = magenta.common.get_padded_batch(
          sequence_example_file_paths, hparams.batch_size, input_size,
          shuffle=mode == 'train')

    elif mode == 'generate':
      inputs = tf.placeholder(tf.float32,
                              [hparams.batch_size, None, input_size])

    cell = events_rnn_graph.make_rnn_cell(
        hparams.rnn_layer_sizes,
        dropout_keep_prob=hparams.dropout_keep_prob if mode == 'train' else 1.0,
        attn_length=(
            hparams.attn_length if hasattr(hparams, 'attn_length') else 0))

    rnn_nade = RnnNade(
        cell,
        num_dims=input_size,
        num_hidden=hparams.nade_hidden_units)

    if mode == 'train' or mode == 'eval':
      log_probs, cond_probs = rnn_nade.log_prob(inputs, lengths)

      inputs_flat = tf.to_float(
          magenta.common.flatten_maybe_padded_sequences(inputs, lengths))
      predictions_flat = tf.to_float(tf.greater_equal(cond_probs, .5))

      if mode == 'train':
        loss = tf.reduce_mean(-log_probs)
        perplexity = tf.reduce_mean(tf.exp(log_probs))
        correct_predictions = tf.to_float(
            tf.equal(inputs_flat, predictions_flat))
        accuracy = tf.reduce_mean(correct_predictions)
        precision = (tf.reduce_sum(inputs_flat * predictions_flat) /
                     tf.reduce_sum(predictions_flat))
        recall = (tf.reduce_sum(inputs_flat * predictions_flat) /
                  tf.reduce_sum(inputs_flat))

        optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)

        train_op = tf.contrib.slim.learning.create_train_op(
            loss, optimizer, clip_gradient_norm=hparams.clip_norm)
        tf.add_to_collection('train_op', train_op)

        vars_to_summarize = {
            'loss': loss,
            'metrics/perplexity': perplexity,
            'metrics/accuracy': accuracy,
            'metrics/precision': precision,
            'metrics/recall': recall,
        }
      elif mode == 'eval':
        vars_to_summarize, update_ops = tf.contrib.metrics.aggregate_metric_map(
            {
                'loss': tf.metrics.mean(-log_probs),
                'metrics/perplexity': tf.metrics.mean(tf.exp(log_probs)),
                'metrics/accuracy': tf.metrics.accuracy(
                    inputs_flat, predictions_flat),
                'metrics/precision': tf.metrics.precision(
                    inputs_flat, predictions_flat),
                'metrics/recall': tf.metrics.recall(
                    inputs_flat, predictions_flat),
            })
        for updates_op in update_ops.values():
          tf.add_to_collection('eval_ops', updates_op)

      precision = vars_to_summarize['metrics/precision']
      recall = vars_to_summarize['metrics/recall']
      f1_score = tf.where(
          tf.greater(precision + recall, 0), 2 * (
              (precision * recall) / (precision + recall)), 0)
      vars_to_summarize['metrics/f1_score'] = f1_score
      for var_name, var_value in vars_to_summarize.items():
        tf.summary.scalar(var_name, var_value)
        tf.add_to_collection(var_name, var_value)

    elif mode == 'generate':
      initial_state = rnn_nade.zero_state(hparams.batch_size)

      final_state = rnn_nade.steps(inputs, initial_state)
      samples, log_prob = rnn_nade.sample_single(initial_state)

      tf.add_to_collection('inputs', inputs)
      tf.add_to_collection('sample', samples)
      tf.add_to_collection('log_prob', log_prob)

      # Flatten state tuples for metagraph compatibility.
      for state in tf_nest.flatten(initial_state):
        tf.add_to_collection('initial_state', state)
      for state in tf_nest.flatten(final_state):
        tf.add_to_collection('final_state', state)

  return graph
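
Everything a consumer needs from build_graph is registered in graph collections, so a hypothetical training driver only has to fetch ops back out. The names below (my_config, the TFRecord path) are placeholders, not defined by the example:

graph = build_graph('train', my_config,
                    sequence_example_file_paths=['/tmp/train.tfrecord'])
with graph.as_default():
    train_op = tf.get_collection('train_op')[0]
    loss = tf.get_collection('loss')[0]
    # A tf.train.Supervisor or MonitoredSession would then run train_op
    # repeatedly, with queue runners feeding the padded batches.
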
Example #3
    def build():
        """Builds the Tensorflow graph."""
        inputs, lengths = None, None

        if mode in ('train', 'eval'):
            inputs, _, lengths = magenta.common.get_padded_batch(
                sequence_example_file_paths,
                hparams.batch_size,
                input_size,
                shuffle=mode == 'train')

        elif mode == 'generate':
            inputs = tf.placeholder(tf.float32,
                                    [hparams.batch_size, None, input_size])

        cell = events_rnn_graph.make_rnn_cell(
            hparams.rnn_layer_sizes,
            dropout_keep_prob=hparams.dropout_keep_prob
            if mode == 'train' else 1.0,
            attn_length=hparams.attn_length,
            residual_connections=hparams.residual_connections)

        rnn_nade = RnnNade(cell,
                           num_dims=input_size,
                           num_hidden=hparams.nade_hidden_units)

        if mode in ('train', 'eval'):
            log_probs, cond_probs = rnn_nade.log_prob(inputs, lengths)

            inputs_flat = tf.to_float(
                magenta.common.flatten_maybe_padded_sequences(inputs, lengths))
            predictions_flat = tf.to_float(tf.greater_equal(cond_probs, .5))

            if mode == 'train':
                loss = tf.reduce_mean(-log_probs)
                perplexity = tf.reduce_mean(tf.exp(log_probs))
                correct_predictions = tf.to_float(
                    tf.equal(inputs_flat, predictions_flat))
                accuracy = tf.reduce_mean(correct_predictions)
                precision = (tf.reduce_sum(inputs_flat * predictions_flat) /
                             tf.reduce_sum(predictions_flat))
                recall = (tf.reduce_sum(inputs_flat * predictions_flat) /
                          tf.reduce_sum(inputs_flat))

                optimizer = tf.train.AdamOptimizer(
                    learning_rate=hparams.learning_rate)

                train_op = tf.contrib.slim.learning.create_train_op(
                    loss, optimizer, clip_gradient_norm=hparams.clip_norm)
                tf.add_to_collection('train_op', train_op)

                vars_to_summarize = {
                    'loss': loss,
                    'metrics/perplexity': perplexity,
                    'metrics/accuracy': accuracy,
                    'metrics/precision': precision,
                    'metrics/recall': recall,
                }
            elif mode == 'eval':
                vars_to_summarize, update_ops = tf.contrib.metrics.aggregate_metric_map(
                    {
                        'loss':
                        tf.metrics.mean(-log_probs),
                        'metrics/perplexity':
                        tf.metrics.mean(tf.exp(log_probs)),
                        'metrics/accuracy':
                        tf.metrics.accuracy(inputs_flat, predictions_flat),
                        'metrics/precision':
                        tf.metrics.precision(inputs_flat, predictions_flat),
                        'metrics/recall':
                        tf.metrics.recall(inputs_flat, predictions_flat),
                    })
                for updates_op in update_ops.values():
                    tf.add_to_collection('eval_ops', updates_op)

            precision = vars_to_summarize['metrics/precision']
            recall = vars_to_summarize['metrics/recall']
            f1_score = tf.where(
                tf.greater(precision + recall, 0),
                2 * ((precision * recall) / (precision + recall)), 0)
            vars_to_summarize['metrics/f1_score'] = f1_score
            for var_name, var_value in vars_to_summarize.items():
                tf.summary.scalar(var_name, var_value)
                tf.add_to_collection(var_name, var_value)

        elif mode == 'generate':
            initial_state = rnn_nade.zero_state(hparams.batch_size)

            final_state = rnn_nade.steps(inputs, initial_state)
            samples, log_prob = rnn_nade.sample_single(initial_state)

            tf.add_to_collection('inputs', inputs)
            tf.add_to_collection('sample', samples)
            tf.add_to_collection('log_prob', log_prob)

            # Flatten state tuples for metagraph compatibility.
            for state in tf_nest.flatten(initial_state):
                tf.add_to_collection('initial_state', state)
            for state in tf_nest.flatten(final_state):
                tf.add_to_collection('final_state', state)
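
The tf.where guard above avoids a 0/0 when precision + recall is zero; otherwise f1_score is the usual harmonic mean of precision and recall. A tiny plain-Python check with made-up values:

targets     = [1., 0., 1., 1.]
predictions = [1., 1., 1., 0.]
true_pos  = sum(t * p for t, p in zip(targets, predictions))   # 2.0
precision = true_pos / sum(predictions)                        # 2/3
recall    = true_pos / sum(targets)                            # 2/3
f1 = (2 * precision * recall / (precision + recall)
      if precision + recall > 0 else 0.0)                      # 2/3
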
Example #4
def build_graph(mode, config, sequence_example_file_paths=None):
    """Builds the TensorFlow graph.

  Args:
    mode: 'train', 'eval', or 'generate'. Only mode-related ops are added to
        the graph.
    config: An EventSequenceRnnConfig containing the encoder/decoder and HParams
        to use.
    sequence_example_file_paths: A list of paths to TFRecord files containing
        tf.train.SequenceExample protos. Only needed for training and
        evaluation. May be a sharded file.

  Returns:
    A tf.Graph instance which contains the TF ops.

  Raises:
    ValueError: If mode is not 'train', 'eval', or 'generate'.
  """
    if mode not in ('train', 'eval', 'generate'):
        raise ValueError("The mode parameter must be 'train', 'eval', "
                         "or 'generate'. The mode parameter was: %s" % mode)

    hparams = config.hparams
    encoder_decoder = config.encoder_decoder

    tf.logging.info('hparams = %s', hparams.values())

    input_size = encoder_decoder.input_size

    with tf.Graph().as_default() as graph:
        inputs, lengths = None, None

        if mode == 'train' or mode == 'eval':
            inputs, _, lengths = magenta.common.get_padded_batch(
                sequence_example_file_paths, hparams.batch_size, input_size)

        elif mode == 'generate':
            inputs = tf.placeholder(tf.float32,
                                    [hparams.batch_size, None, input_size])

        cell = events_rnn_graph.make_rnn_cell(
            hparams.rnn_layer_sizes,
            dropout_keep_prob=hparams.dropout_keep_prob
            if mode == 'train' else 1.0,
            attn_length=hparams.attn_length)

        rnn_nade = RnnNade(cell,
                           num_dims=input_size,
                           num_hidden=hparams.nade_hidden_units)

        if mode == 'train' or mode == 'eval':
            log_probs, cond_probs = rnn_nade.log_prob(inputs, lengths)

            inputs_flat = tf.to_float(
                magenta.common.flatten_maybe_padded_sequences(inputs, lengths))
            predictions_flat = tf.to_float(tf.greater_equal(cond_probs, .5))

            if mode == 'train':
                loss = tf.reduce_mean(-log_probs)
                perplexity = tf.reduce_mean(tf.exp(log_probs))
                correct_predictions = tf.to_float(
                    tf.equal(inputs_flat, predictions_flat))
                accuracy = tf.reduce_mean(correct_predictions)
                precision = (tf.reduce_sum(inputs_flat * predictions_flat) /
                             tf.reduce_sum(predictions_flat))
                recall = (tf.reduce_sum(inputs_flat * predictions_flat) /
                          tf.reduce_sum(inputs_flat))

                optimizer = tf.train.AdamOptimizer(
                    learning_rate=hparams.learning_rate)

                train_op = tf.contrib.slim.learning.create_train_op(
                    loss, optimizer, clip_gradient_norm=hparams.clip_norm)
                tf.add_to_collection('train_op', train_op)

                vars_to_summarize = {
                    'loss': loss,
                    'metrics/perplexity': perplexity,
                    'metrics/accuracy': accuracy,
                    'metrics/precision': precision,
                    'metrics/recall': recall,
                }
            elif mode == 'eval':
                vars_to_summarize, update_ops = tf.contrib.metrics.aggregate_metric_map(
                    {
                        'loss':
                        tf.metrics.mean(-log_probs),
                        'metrics/perplexity':
                        tf.metrics.mean(tf.exp(log_probs)),
                        'metrics/accuracy':
                        tf.metrics.accuracy(inputs_flat, predictions_flat),
                        'metrics/precision':
                        tf.metrics.precision(inputs_flat, predictions_flat),
                        'metrics/recall':
                        tf.metrics.recall(inputs_flat, predictions_flat),
                    })
                for updates_op in update_ops.values():
                    tf.add_to_collection('eval_ops', updates_op)

            precision = vars_to_summarize['metrics/precision']
            recall = vars_to_summarize['metrics/recall']
            f1_score = tf.where(
                tf.greater(precision + recall, 0),
                2 * ((precision * recall) / (precision + recall)), 0)
            vars_to_summarize['metrics/f1_score'] = f1_score
            for var_name, var_value in vars_to_summarize.items():
                tf.summary.scalar(var_name, var_value)
                tf.add_to_collection(var_name, var_value)

        elif mode == 'generate':
            initial_state = rnn_nade.zero_state(hparams.batch_size)

            final_state = rnn_nade.steps(inputs, initial_state)
            samples, log_prob = rnn_nade.sample_single(initial_state)

            tf.add_to_collection('inputs', inputs)
            tf.add_to_collection('sample', samples)
            tf.add_to_collection('log_prob', log_prob)

            # Flatten state tuples for metagraph compatibility.
            for state in tf_nest.flatten(initial_state):
                tf.add_to_collection('initial_state', state)
            for state in tf_nest.flatten(final_state):
                tf.add_to_collection('final_state', state)

    return graph
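
In eval mode the metrics above are streaming: each op in the 'eval_ops' collection updates an accumulator, and the summarized value tensors only become meaningful after those updates have run. A hedged sketch of the consuming loop (the session scaffolding and num_batches are illustrative, not part of the example):

with tf.Session(graph=graph) as sess:
    sess.run(tf.local_variables_initializer())  # streaming-metric state
    eval_ops = tf.get_collection('eval_ops')
    for _ in range(num_batches):                # num_batches: assumed
        sess.run(eval_ops)                      # accumulate over batches
    # The summary scalars now reflect metrics aggregated over all batches.
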
Example #5
  def build():
    """Builds the Tensorflow graph."""
    inputs, lengths = None, None

    if mode == 'train' or mode == 'eval':
      inputs, _, lengths = magenta.common.get_padded_batch(
          sequence_example_file_paths, hparams.batch_size, input_size,
          shuffle=mode == 'train')

    elif mode == 'generate':
      inputs = tf.placeholder(tf.float32,
                              [hparams.batch_size, None, input_size])

    cell = events_rnn_graph.make_rnn_cell(
        hparams.rnn_layer_sizes,
        dropout_keep_prob=hparams.dropout_keep_prob if mode == 'train' else 1.0,
        attn_length=(
            hparams.attn_length if hasattr(hparams, 'attn_length') else 0))

    rnn_nade = RnnNade(
        cell,
        num_dims=input_size,
        num_hidden=hparams.nade_hidden_units)

    if mode == 'train' or mode == 'eval':
      log_probs, cond_probs = rnn_nade.log_prob(inputs, lengths)

      inputs_flat = tf.to_float(
          magenta.common.flatten_maybe_padded_sequences(inputs, lengths))
      predictions_flat = tf.to_float(tf.greater_equal(cond_probs, .5))

      if mode == 'train':
        loss = tf.reduce_mean(-log_probs)
        perplexity = tf.reduce_mean(tf.exp(log_probs))
        correct_predictions = tf.to_float(
            tf.equal(inputs_flat, predictions_flat))
        accuracy = tf.reduce_mean(correct_predictions)
        precision = (tf.reduce_sum(inputs_flat * predictions_flat) /
                     tf.reduce_sum(predictions_flat))
        recall = (tf.reduce_sum(inputs_flat * predictions_flat) /
                  tf.reduce_sum(inputs_flat))

        optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)

        train_op = tf.contrib.slim.learning.create_train_op(
            loss, optimizer, clip_gradient_norm=hparams.clip_norm)
        tf.add_to_collection('train_op', train_op)

        vars_to_summarize = {
            'loss': loss,
            'metrics/perplexity': perplexity,
            'metrics/accuracy': accuracy,
            'metrics/precision': precision,
            'metrics/recall': recall,
        }
      elif mode == 'eval':
        vars_to_summarize, update_ops = tf.contrib.metrics.aggregate_metric_map(
            {
                'loss': tf.metrics.mean(-log_probs),
                'metrics/perplexity': tf.metrics.mean(tf.exp(log_probs)),
                'metrics/accuracy': tf.metrics.accuracy(
                    inputs_flat, predictions_flat),
                'metrics/precision': tf.metrics.precision(
                    inputs_flat, predictions_flat),
                'metrics/recall': tf.metrics.recall(
                    inputs_flat, predictions_flat),
            })
        for updates_op in update_ops.values():
          tf.add_to_collection('eval_ops', updates_op)

      precision = vars_to_summarize['metrics/precision']
      recall = vars_to_summarize['metrics/recall']
      f1_score = tf.where(
          tf.greater(precision + recall, 0), 2 * (
              (precision * recall) / (precision + recall)), 0)
      vars_to_summarize['metrics/f1_score'] = f1_score
      for var_name, var_value in vars_to_summarize.items():
        tf.summary.scalar(var_name, var_value)
        tf.add_to_collection(var_name, var_value)

    elif mode == 'generate':
      initial_state = rnn_nade.zero_state(hparams.batch_size)

      final_state = rnn_nade.steps(inputs, initial_state)
      samples, log_prob = rnn_nade.sample_single(initial_state)

      tf.add_to_collection('inputs', inputs)
      tf.add_to_collection('sample', samples)
      tf.add_to_collection('log_prob', log_prob)

      # Flatten state tuples for metagraph compatibility.
      for state in tf_nest.flatten(initial_state):
        tf.add_to_collection('initial_state', state)
      for state in tf_nest.flatten(final_state):
        tf.add_to_collection('final_state', state)
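
The flatten loops at the end exist because a graph collection can only hold a flat list of tensors, not a nested state tuple; whoever loads the metagraph re-packs the flat list against a structure template. A minimal sketch of the round trip (the exact module providing tf_nest varies across TF versions):

flat_states = tf_nest.flatten(initial_state)   # save side: plain tensors
# ...collections / metagraph boundary...
restored = tf_nest.pack_sequence_as(initial_state, flat_states)  # load side
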
Example #6
def build_graph(mode, config, sequence_example_file_paths=None):
  """Builds the TensorFlow graph.

  Args:
    mode: 'train', 'eval', or 'generate'. Only mode-related ops are added to
        the graph.
    config: An EventSequenceRnnConfig containing the encoder/decoder and HParams
        to use.
    sequence_example_file_paths: A list of paths to TFRecord files containing
        tf.train.SequenceExample protos. Only needed for training and
        evaluation. May be a sharded file.

  Returns:
    A tf.Graph instance which contains the TF ops.

  Raises:
    ValueError: If mode is not 'train', 'eval', or 'generate'.
  """
  if mode not in ('train', 'eval', 'generate'):
    raise ValueError("The mode parameter must be 'train', 'eval', "
                     "or 'generate'. The mode parameter was: %s" % mode)

  hparams = config.hparams
  encoder_decoder = config.encoder_decoder

  tf.logging.info('hparams = %s', hparams.values())

  input_size = encoder_decoder.input_size

  with tf.Graph().as_default() as graph:
    inputs, lengths = None, None

    if mode == 'train' or mode == 'eval':
      inputs, _, lengths = magenta.common.get_padded_batch(
          sequence_example_file_paths, hparams.batch_size, input_size)

    elif mode == 'generate':
      inputs = tf.placeholder(tf.float32,
                              [hparams.batch_size, None, input_size])

    cell = events_rnn_graph.make_rnn_cell(
        hparams.rnn_layer_sizes,
        dropout_keep_prob=hparams.dropout_keep_prob if mode == 'train' else 1.0,
        attn_length=hparams.attn_length)

    rnn_nade = RnnNade(
        cell,
        num_dims=input_size,
        num_hidden=hparams.nade_hidden_units)

    if mode == 'train' or mode == 'eval':
      log_probs, cond_probs = rnn_nade.log_prob(inputs, lengths)

      inputs_flat = flatten_maybe_padded_sequences(inputs, lengths)
      predictions_flat = tf.cast(tf.greater_equal(cond_probs, .5), tf.float32)

      loss = tf.reduce_mean(-log_probs)

      perplexity = tf.reduce_mean(tf.exp(log_probs)) * 100
      accuracy = tf.reduce_mean(
          tf.to_float(tf.equal(inputs_flat, predictions_flat))) * 100
      accuracy_without_true_negatives = (
          tf.reduce_sum(inputs_flat *
                        tf.to_float(tf.equal(predictions_flat, inputs_flat))) /
          tf.reduce_sum(inputs_flat) * 100)

      global_step = tf.contrib.framework.get_or_create_global_step()

      tf.add_to_collection('loss', loss)
      tf.add_to_collection('perplexity', perplexity)
      tf.add_to_collection('accuracy', accuracy)
      tf.add_to_collection('accuracy_without_true_negatives',
                           accuracy_without_true_negatives)

      summaries = [
          tf.summary.scalar('loss', loss),
          tf.summary.scalar('perplexity', perplexity),
          tf.summary.scalar('accuracy', accuracy),
          tf.summary.scalar('accuracy_without_true_negatives',
                            accuracy_without_true_negatives),
      ]

      if mode == 'train':
        learning_rate = tf.train.exponential_decay(
            hparams.initial_learning_rate, global_step, hparams.decay_steps,
            hparams.decay_rate, staircase=True, name='learning_rate')

        opt = tf.train.AdamOptimizer(learning_rate)
        params = tf.trainable_variables()
        gradients = tf.gradients(loss, params)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients,
                                                      hparams.clip_norm)
        train_op = opt.apply_gradients(zip(clipped_gradients, params),
                                       global_step)
        tf.add_to_collection('learning_rate', learning_rate)
        tf.add_to_collection('train_op', train_op)

        summaries.append(tf.summary.scalar(
            'learning_rate', learning_rate))

      if mode == 'eval':
        summary_op = tf.summary.merge(summaries)
        tf.add_to_collection('summary_op', summary_op)

    elif mode == 'generate':
      initial_state = rnn_nade.zero_state(hparams.batch_size)

      final_state = rnn_nade.steps(inputs, initial_state)
      samples, log_prob = rnn_nade.sample_single(initial_state)

      tf.add_to_collection('inputs', inputs)
      tf.add_to_collection('sample', samples)
      tf.add_to_collection('log_prob', log_prob)

      # Flatten state tuples for metagraph compatibility.
      for state in tf_nest.flatten(initial_state):
        tf.add_to_collection('initial_state', state)
      for state in tf_nest.flatten(final_state):
        tf.add_to_collection('final_state', state)

  return graph
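
With staircase=True, tf.train.exponential_decay drops the learning rate in discrete steps rather than continuously. What it computes, written out in plain Python with illustrative hparam values:

initial_learning_rate, decay_steps, decay_rate = 1e-3, 1000, 0.95

def decayed_lr(global_step):
    return initial_learning_rate * decay_rate ** (global_step // decay_steps)

assert decayed_lr(999) == 1e-3                 # still in the first step
assert decayed_lr(2500) == 1e-3 * 0.95 ** 2    # two full decay periods
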
Example #7
def build_graph(mode, config, sequence_example_file_paths=None):
    """Builds the TensorFlow graph.

  Args:
    mode: 'train', 'eval', or 'generate'. Only mode-related ops are added to
        the graph.
    config: An EventSequenceRnnConfig containing the encoder/decoder and HParams
        to use.
    sequence_example_file_paths: A list of paths to TFRecord files containing
        tf.train.SequenceExample protos. Only needed for training and
        evaluation. May be a sharded file.

  Returns:
    A tf.Graph instance which contains the TF ops.

  Raises:
    ValueError: If mode is not 'train', 'eval', or 'generate'.
  """
    if mode not in ('train', 'eval', 'generate'):
        raise ValueError("The mode parameter must be 'train', 'eval', "
                         "or 'generate'. The mode parameter was: %s" % mode)

    hparams = config.hparams
    encoder_decoder = config.encoder_decoder

    tf.logging.info('hparams = %s', hparams.values())

    input_size = encoder_decoder.input_size

    with tf.Graph().as_default() as graph:
        inputs, lengths = None, None

        if mode == 'train' or mode == 'eval':
            inputs, _, lengths = magenta.common.get_padded_batch(
                sequence_example_file_paths, hparams.batch_size, input_size)

        elif mode == 'generate':
            inputs = tf.placeholder(tf.float32,
                                    [hparams.batch_size, None, input_size])

        cell = events_rnn_graph.make_rnn_cell(
            hparams.rnn_layer_sizes,
            dropout_keep_prob=hparams.dropout_keep_prob
            if mode == 'train' else 1.0,
            attn_length=hparams.attn_length)

        rnn_nade = RnnNade(cell,
                           num_dims=input_size,
                           num_hidden=hparams.nade_hidden_units)

        if mode == 'train' or mode == 'eval':
            log_probs, cond_probs = rnn_nade.log_prob(inputs, lengths)

            inputs_flat = flatten_maybe_padded_sequences(inputs, lengths)
            predictions_flat = tf.cast(tf.greater_equal(cond_probs, .5),
                                       tf.float32)

            loss = tf.reduce_mean(-log_probs)

            perplexity = tf.reduce_mean(tf.exp(log_probs)) * 100
            accuracy = tf.reduce_mean(
                tf.to_float(tf.equal(inputs_flat, predictions_flat))) * 100
            accuracy_without_true_negatives = (tf.reduce_sum(
                inputs_flat *
                tf.to_float(tf.equal(predictions_flat, inputs_flat))) /
                                               tf.reduce_sum(inputs_flat) *
                                               100)

            global_step = tf.contrib.framework.get_or_create_global_step()

            tf.add_to_collection('loss', loss)
            tf.add_to_collection('perplexity', perplexity)
            tf.add_to_collection('accuracy', accuracy)
            tf.add_to_collection('accuracy_without_true_negatives',
                                 accuracy_without_true_negatives)

            summaries = [
                tf.summary.scalar('loss', loss),
                tf.summary.scalar('perplexity', perplexity),
                tf.summary.scalar('accuracy', accuracy),
                tf.summary.scalar('accuracy_without_true_negatives',
                                  accuracy_without_true_negatives),
            ]

            if mode == 'train':
                learning_rate = tf.train.exponential_decay(
                    hparams.initial_learning_rate,
                    global_step,
                    hparams.decay_steps,
                    hparams.decay_rate,
                    staircase=True,
                    name='learning_rate')

                opt = tf.train.AdamOptimizer(learning_rate)
                params = tf.trainable_variables()
                gradients = tf.gradients(loss, params)
                clipped_gradients, _ = tf.clip_by_global_norm(
                    gradients, hparams.clip_norm)
                train_op = opt.apply_gradients(zip(clipped_gradients, params),
                                               global_step)
                tf.add_to_collection('learning_rate', learning_rate)
                tf.add_to_collection('train_op', train_op)

                summaries.append(
                    tf.summary.scalar('learning_rate', learning_rate))

            if mode == 'eval':
                summary_op = tf.summary.merge(summaries)
                tf.add_to_collection('summary_op', summary_op)

        elif mode == 'generate':
            initial_state = rnn_nade.zero_state(hparams.batch_size)

            final_state = rnn_nade.steps(inputs, initial_state)
            samples, log_prob = rnn_nade.sample_single(initial_state)

            tf.add_to_collection('inputs', inputs)
            tf.add_to_collection('sample', samples)
            tf.add_to_collection('log_prob', log_prob)

            # Flatten state tuples for metagraph compatibility.
            for state in tf_nest.flatten(initial_state):
                tf.add_to_collection('initial_state', state)
            for state in tf_nest.flatten(final_state):
                tf.add_to_collection('final_state', state)

    return graph
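
Unlike per-tensor clipping, tf.clip_by_global_norm rescales all gradients jointly, preserving their direction. In miniature, with illustrative numbers:

import math

grads = [3.0, 4.0]                       # global L2 norm = 5.0
clip_norm = 1.0
global_norm = math.sqrt(sum(g * g for g in grads))
scale = clip_norm / max(global_norm, clip_norm)
clipped = [g * scale for g in grads]     # [0.6, 0.8], norm = 1.0
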
Example #8
  def build_graph(self):
    """Constructs the portion of the graph that belongs to this model."""

    tf.logging.info('Initializing melody RNN graph for scope %s', self.scope)

    with self.graph.as_default():
      with tf.device(lambda op: ''):
        with tf.variable_scope(self.scope):
          # Make an LSTM cell with the number and size of layers specified in
          # hparams.
          if self.note_rnn_type == 'basic_rnn':
            self.cell = events_rnn_graph.make_rnn_cell(
                self.hparams.rnn_layer_sizes)
          else:
            self.cell = rl_tuner_ops.make_rnn_cell(self.hparams.rnn_layer_sizes)
          # Shape of melody_sequence is batch size, melody length, number of
          # output note actions.
          self.melody_sequence = tf.placeholder(tf.float32,
                                                [None, None,
                                                 self.hparams.one_hot_length],
                                                name='melody_sequence')
          self.lengths = tf.placeholder(tf.int32, [None], name='lengths')
          self.initial_state = tf.placeholder(tf.float32,
                                              [None, self.cell.state_size],
                                              name='initial_state')

          if self.training_file_list is not None:
            # Set up a tf queue to read melodies from the training data tfrecord
            (self.train_sequence,
             self.train_labels,
             self.train_lengths) = sequence_example_lib.get_padded_batch(
                 self.training_file_list, self.hparams.batch_size,
                 self.hparams.one_hot_length)

          # Closure function is used so that this part of the graph can be
          # re-run in multiple places, such as __call__.
          def run_network_on_melody(m_seq,
                                    lens,
                                    initial_state,
                                    swap_memory=True,
                                    parallel_iterations=1):
            """Internal function that defines the RNN network structure.

            Args:
              m_seq: A batch of melody sequences of one-hot notes.
              lens: Lengths of the melody_sequences.
              initial_state: Vector representing the initial state of the RNN.
              swap_memory: Whether tf.nn.dynamic_rnn may swap tensors
                between GPU and CPU to save GPU memory.
              parallel_iterations: Argument to tf.nn.dynamic_rnn.
            Returns:
              Output of network (either softmax or logits) and RNN state.
            """
            outputs, final_state = tf.nn.dynamic_rnn(
                self.cell,
                m_seq,
                sequence_length=lens,
                initial_state=initial_state,
                swap_memory=swap_memory,
                parallel_iterations=parallel_iterations)

            outputs_flat = tf.reshape(outputs,
                                      [-1, self.hparams.rnn_layer_sizes[-1]])
            if self.note_rnn_type == 'basic_rnn':
              linear_layer = tf.contrib.layers.linear
            else:
              linear_layer = tf.contrib.layers.legacy_linear
            logits_flat = linear_layer(
                outputs_flat, self.hparams.one_hot_length)
            return logits_flat, final_state

          (self.logits, self.state_tensor) = run_network_on_melody(
              self.melody_sequence, self.lengths, self.initial_state)
          self.softmax = tf.nn.softmax(self.logits)

          self.run_network_on_melody = run_network_on_melody

        if self.training_file_list is not None:
          # Does not recreate the model architecture but rather uses it to feed
          # data from the training queue through the model.
          with tf.variable_scope(self.scope, reuse=True):
            zero_state = self.cell.zero_state(
                batch_size=self.hparams.batch_size, dtype=tf.float32)

            (self.train_logits, self.train_state) = run_network_on_melody(
                self.train_sequence, self.train_lengths, zero_state)
            self.train_softmax = tf.nn.softmax(self.train_logits)
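
The [-1, rnn_layer_sizes[-1]] reshape before the linear layer is the standard trick for applying one dense layer to every timestep at once: dynamic_rnn emits [batch, time, size], so flattening batch and time turns per-step logits into a single 2-D matmul. A NumPy sketch with made-up sizes:

import numpy as np

batch, time, size, one_hot_length = 8, 32, 128, 38      # illustrative
outputs = np.zeros([batch, time, size], np.float32)     # dynamic_rnn output
outputs_flat = outputs.reshape(-1, size)                # [256, 128]
weights = np.zeros([size, one_hot_length], np.float32)
logits_flat = outputs_flat @ weights                    # [256, 38]
logits = logits_flat.reshape(batch, time, one_hot_length)
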