Example #1
def model_fn(features, labels, mode, params, config):
    """Builds the acoustic model."""
    del config
    hparams = params

    length = features.length
    spec = features.spec

    is_training = mode == tf_estimator.ModeKeys.TRAIN

    if is_training:
        onset_labels = labels.onsets
        offset_labels = labels.offsets
        velocity_labels = labels.velocities
        frame_labels = labels.labels
        frame_label_weights = labels.label_weights

    if hparams.stop_activation_gradient and not hparams.activation_loss:
        raise ValueError(
            'If stop_activation_gradient is true, activation_loss must be true.'
        )

    losses = {}
    with slim.arg_scope([slim.batch_norm, slim.dropout],
                        is_training=is_training):
        with tf.variable_scope('onsets'):
            onset_outputs = acoustic_model(spec,
                                           hparams,
                                           lstm_units=hparams.onset_lstm_units,
                                           lengths=length)
            onset_probs = slim.fully_connected(onset_outputs,
                                               constants.MIDI_PITCHES,
                                               activation_fn=tf.sigmoid,
                                               scope='onset_probs')

            # onset_probs_flat is used during inference.
            onset_probs_flat = flatten_maybe_padded_sequences(
                onset_probs, length)
            if is_training:
                onset_labels_flat = flatten_maybe_padded_sequences(
                    onset_labels, length)
                onset_losses = tf_utils.log_loss(onset_labels_flat,
                                                 onset_probs_flat)
                tf.losses.add_loss(tf.reduce_mean(onset_losses))
                losses['onset'] = onset_losses
        with tf.variable_scope('offsets'):
            offset_outputs = acoustic_model(
                spec,
                hparams,
                lstm_units=hparams.offset_lstm_units,
                lengths=length)
            offset_probs = slim.fully_connected(offset_outputs,
                                                constants.MIDI_PITCHES,
                                                activation_fn=tf.sigmoid,
                                                scope='offset_probs')

            # offset_probs_flat is used during inference.
            offset_probs_flat = flatten_maybe_padded_sequences(
                offset_probs, length)
            if is_training:
                offset_labels_flat = flatten_maybe_padded_sequences(
                    offset_labels, length)
                offset_losses = tf_utils.log_loss(offset_labels_flat,
                                                  offset_probs_flat)
                tf.losses.add_loss(tf.reduce_mean(offset_losses))
                losses['offset'] = offset_losses
        with tf.variable_scope('velocity'):
            velocity_outputs = acoustic_model(
                spec,
                hparams,
                lstm_units=hparams.velocity_lstm_units,
                lengths=length)
            velocity_values = slim.fully_connected(velocity_outputs,
                                                   constants.MIDI_PITCHES,
                                                   activation_fn=None,
                                                   scope='onset_velocities')

            velocity_values_flat = flatten_maybe_padded_sequences(
                velocity_values, length)
            if is_training:
                velocity_labels_flat = flatten_maybe_padded_sequences(
                    velocity_labels, length)
                velocity_loss = tf.reduce_sum(
                    onset_labels_flat *
                    tf.square(velocity_labels_flat - velocity_values_flat),
                    axis=1)
                tf.losses.add_loss(tf.reduce_mean(velocity_loss))
                losses['velocity'] = velocity_loss

        with tf.variable_scope('frame'):
            if not hparams.share_conv_features:
                # TODO(eriche): this is broken when hparams.frame_lstm_units > 0
                activation_outputs = acoustic_model(
                    spec,
                    hparams,
                    lstm_units=hparams.frame_lstm_units,
                    lengths=length)
                activation_probs = slim.fully_connected(
                    activation_outputs,
                    constants.MIDI_PITCHES,
                    activation_fn=tf.sigmoid,
                    scope='activation_probs')
            else:
                activation_probs = slim.fully_connected(
                    onset_outputs,
                    constants.MIDI_PITCHES,
                    activation_fn=tf.sigmoid,
                    scope='activation_probs')

            probs = []
            if hparams.stop_onset_gradient:
                probs.append(tf.stop_gradient(onset_probs))
            else:
                probs.append(onset_probs)

            if hparams.stop_activation_gradient:
                probs.append(tf.stop_gradient(activation_probs))
            else:
                probs.append(activation_probs)

            if hparams.stop_offset_gradient:
                probs.append(tf.stop_gradient(offset_probs))
            else:
                probs.append(offset_probs)

            combined_probs = tf.concat(probs, 2)

            if hparams.combined_lstm_units > 0:
                outputs = lstm_layer(
                    combined_probs,
                    hparams.combined_lstm_units,
                    lengths=length if hparams.use_lengths else None,
                    stack_size=hparams.combined_rnn_stack_size,
                    use_cudnn=hparams.use_cudnn,
                    bidirectional=hparams.bidirectional)
            else:
                outputs = combined_probs

            frame_probs = slim.fully_connected(outputs,
                                               constants.MIDI_PITCHES,
                                               activation_fn=tf.sigmoid,
                                               scope='frame_probs')

        frame_probs_flat = flatten_maybe_padded_sequences(frame_probs, length)

        if is_training:
            frame_labels_flat = flatten_maybe_padded_sequences(
                frame_labels, length)
            frame_label_weights_flat = flatten_maybe_padded_sequences(
                frame_label_weights, length)
            if hparams.weight_frame_and_activation_loss:
                frame_loss_weights = frame_label_weights_flat
            else:
                frame_loss_weights = None
            frame_losses = tf_utils.log_loss(frame_labels_flat,
                                             frame_probs_flat,
                                             weights=frame_loss_weights)
            tf.losses.add_loss(tf.reduce_mean(frame_losses))
            losses['frame'] = frame_losses

            if hparams.activation_loss:
                if hparams.weight_frame_and_activation_loss:
                    activation_loss_weights = frame_label_weights
                else:
                    activation_loss_weights = None
                activation_losses = tf_utils.log_loss(
                    frame_labels_flat,
                    flatten_maybe_padded_sequences(activation_probs, length),
                    weights=activation_loss_weights)
                tf.losses.add_loss(tf.reduce_mean(activation_losses))
                losses['activation'] = activation_losses

    frame_predictions = frame_probs_flat > hparams.predict_frame_threshold
    onset_predictions = onset_probs_flat > hparams.predict_onset_threshold
    offset_predictions = offset_probs_flat > hparams.predict_offset_threshold

    frame_predictions = tf.expand_dims(frame_predictions, axis=0)
    onset_predictions = tf.expand_dims(onset_predictions, axis=0)
    offset_predictions = tf.expand_dims(offset_predictions, axis=0)
    velocity_values = tf.expand_dims(velocity_values_flat, axis=0)

    metrics_values = metrics.define_metrics(
        frame_probs=frame_probs,
        onset_probs=onset_probs,
        frame_predictions=frame_predictions,
        onset_predictions=onset_predictions,
        offset_predictions=offset_predictions,
        velocity_values=velocity_values,
        length=features.length,
        sequence_label=labels.note_sequence,
        frame_labels=labels.labels,
        sequence_id=features.sequence_id,
        hparams=hparams)

    for label, loss_collection in losses.items():
        loss_label = 'losses/' + label
        metrics_values[loss_label] = loss_collection

    def predict_sequence():
        """Convert frame predictions into a sequence (TF)."""
        def _predict(frame_probs, onset_probs, frame_predictions,
                     onset_predictions, offset_predictions, velocity_values):
            """Convert frame predictions into a sequence (Python)."""
            sequence = infer_util.predict_sequence(
                frame_probs=frame_probs,
                onset_probs=onset_probs,
                frame_predictions=frame_predictions,
                onset_predictions=onset_predictions,
                offset_predictions=offset_predictions,
                velocity_values=velocity_values,
                hparams=hparams,
                min_pitch=constants.MIN_MIDI_PITCH)
            return sequence.SerializeToString()

        sequence = tf.py_func(_predict,
                              inp=[
                                  frame_probs[0],
                                  onset_probs[0],
                                  frame_predictions[0],
                                  onset_predictions[0],
                                  offset_predictions[0],
                                  velocity_values[0],
                              ],
                              Tout=tf.string,
                              stateful=False)
        sequence.set_shape([])
        return tf.expand_dims(sequence, axis=0)

    predictions = {
        'frame_probs': frame_probs,
        'onset_probs': onset_probs,
        'frame_predictions': frame_predictions,
        'onset_predictions': onset_predictions,
        'offset_predictions': offset_predictions,
        'velocity_values': velocity_values,
        'sequence_predictions': predict_sequence(),
        # Include some features and labels in output because Estimator 'predict'
        # API does not give access to them.
        'sequence_ids': features.sequence_id,
        'sequence_labels': labels.note_sequence,
        'frame_labels': labels.labels,
        'onset_labels': labels.onsets,
    }
    for k, v in metrics_values.items():
        predictions[k] = tf.stack(v)

    metric_ops = {k: tf.metrics.mean(v) for k, v in metrics_values.items()}

    train_op = None
    loss = None
    if is_training:
        # Create pianoroll images: labels in red, probs in green
        # ([batch, time, 88, 3]).
        images = {}
        onset_pianorolls = tf.concat([
            onset_labels[:, :, :, tf.newaxis],
            onset_probs[:, :, :, tf.newaxis],
            tf.zeros(tf.shape(onset_labels))[:, :, :, tf.newaxis]
        ], axis=3)
        images['OnsetPianorolls'] = onset_pianorolls
        offset_pianorolls = tf.concat([
            offset_labels[:, :, :, tf.newaxis],
            offset_probs[:, :, :, tf.newaxis],
            tf.zeros(tf.shape(offset_labels))[:, :, :, tf.newaxis]
        ], axis=3)
        images['OffsetPianorolls'] = offset_pianorolls
        activation_pianorolls = tf.concat([
            frame_labels[:, :, :, tf.newaxis],
            frame_probs[:, :, :, tf.newaxis],
            tf.zeros(tf.shape(frame_labels))[:, :, :, tf.newaxis]
        ], axis=3)
        images['ActivationPianorolls'] = activation_pianorolls
        for name, image in images.items():
            tf.summary.image(name, image)

        loss = tf.losses.get_total_loss()
        tf.summary.scalar('loss', loss)
        for label, loss_collection in losses.items():
            loss_label = 'losses/' + label
            tf.summary.scalar(loss_label, tf.reduce_mean(loss_collection))

        train_op = slim.optimize_loss(
            name='training',
            loss=loss,
            global_step=tf.train.get_or_create_global_step(),
            learning_rate=hparams.learning_rate,
            learning_rate_decay_fn=functools.partial(
                tf.train.exponential_decay,
                decay_steps=hparams.decay_steps,
                decay_rate=hparams.decay_rate,
                staircase=True),
            clip_gradients=hparams.clip_norm,
            optimizer='Adam')

    return tf_estimator.EstimatorSpec(mode=mode,
                                      predictions=predictions,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=metric_ops)
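
A minimal sketch of how this model_fn could be wired up; the model directory, input functions, and step count are placeholders, and hparams is assumed to be an HParams-style object carrying the fields the function reads (tf_estimator is assumed to be TensorFlow 1.x's estimator module):

from tensorflow.compat.v1 import estimator as tf_estimator

estimator = tf_estimator.Estimator(
    model_fn=model_fn,
    model_dir='/tmp/onsets_frames',  # placeholder directory
    params=hparams)
estimator.train(input_fn=train_input_fn, steps=10000)  # placeholder input_fn

# predict() unbatches along the first dimension, so each yielded dict
# carries one serialized NoteSequence under 'sequence_predictions'.
for prediction in estimator.predict(input_fn=predict_input_fn):
    serialized_sequence = prediction['sequence_predictions']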
Example #2
def get_model(transcription_data, hparams, is_training=True):
  """Builds the acoustic model."""
  onset_labels = transcription_data.onsets
  velocity_labels = transcription_data.velocities
  frame_labels = transcription_data.labels
  frame_label_weights = transcription_data.label_weights
  lengths = transcription_data.lengths
  spec = transcription_data.spec

  if hparams.stop_activation_gradient and not hparams.activation_loss:
    raise ValueError(
        'If stop_activation_gradient is true, activation_loss must be true.')

  losses = {}
  with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training):
    with tf.variable_scope('onsets'):
      onset_outputs = acoustic_model(
          spec, hparams, lstm_units=hparams.onset_lstm_units, lengths=lengths)
      onset_probs = slim.fully_connected(
          onset_outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='onset_probs')

      # onset_probs_flat is used during inference.
      onset_probs_flat = flatten_maybe_padded_sequences(onset_probs, lengths)
      onset_labels_flat = flatten_maybe_padded_sequences(onset_labels, lengths)
      tf.identity(onset_probs_flat, name='onset_probs_flat')
      tf.identity(onset_labels_flat, name='onset_labels_flat')
      tf.identity(
          tf.cast(tf.greater_equal(onset_probs_flat, .5), tf.float32),
          name='onset_predictions_flat')

      onset_losses = tf_utils.log_loss(onset_labels_flat, onset_probs_flat)
      tf.losses.add_loss(tf.reduce_mean(onset_losses))
      losses['onset'] = onset_losses

    with tf.variable_scope('velocity'):
      # TODO(eriche): this is broken when hparams.velocity_lstm_units > 0
      velocity_outputs = acoustic_model(
          spec,
          hparams,
          lstm_units=hparams.velocity_lstm_units,
          lengths=lengths)
      velocity_values = slim.fully_connected(
          velocity_outputs,
          constants.MIDI_PITCHES,
          activation_fn=None,
          scope='onset_velocities')

      velocity_values_flat = flatten_maybe_padded_sequences(
          velocity_values, lengths)
      tf.identity(velocity_values_flat, name='velocity_values_flat')
      velocity_labels_flat = flatten_maybe_padded_sequences(
          velocity_labels, lengths)
      velocity_loss = tf.reduce_sum(
          onset_labels_flat *
          tf.square(velocity_labels_flat - velocity_values_flat),
          axis=1)
      tf.losses.add_loss(tf.reduce_mean(velocity_loss))
      losses['velocity'] = velocity_loss

    with tf.variable_scope('frame'):
      if not hparams.share_conv_features:
        # TODO(eriche): this is broken when hparams.frame_lstm_units > 0
        activation_outputs = acoustic_model(
            spec, hparams, lstm_units=hparams.frame_lstm_units, lengths=lengths)
        activation_probs = slim.fully_connected(
            activation_outputs,
            constants.MIDI_PITCHES,
            activation_fn=tf.sigmoid,
            scope='activation_probs')
      else:
        activation_probs = slim.fully_connected(
            onset_outputs,
            constants.MIDI_PITCHES,
            activation_fn=tf.sigmoid,
            scope='activation_probs')

      combined_probs = tf.concat([
          tf.stop_gradient(onset_probs)
          if hparams.stop_onset_gradient else onset_probs,
          tf.stop_gradient(activation_probs)
          if hparams.stop_activation_gradient else activation_probs
      ], 2)

      if hparams.combined_lstm_units > 0:
        rnn_cell_fw = tf.contrib.rnn.LSTMBlockCell(hparams.combined_lstm_units)
        if hparams.frame_bidirectional:
          rnn_cell_bw = tf.contrib.rnn.LSTMBlockCell(
              hparams.combined_lstm_units)
          outputs, unused_output_states = tf.nn.bidirectional_dynamic_rnn(
              rnn_cell_fw, rnn_cell_bw, inputs=combined_probs, dtype=tf.float32)
          combined_outputs = tf.concat(outputs, 2)
        else:
          combined_outputs, unused_output_states = tf.nn.dynamic_rnn(
              rnn_cell_fw, inputs=combined_probs, dtype=tf.float32)
      else:
        combined_outputs = combined_probs

      frame_probs = slim.fully_connected(
          combined_outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='frame_probs')

    frame_labels_flat = flatten_maybe_padded_sequences(frame_labels, lengths)
    frame_probs_flat = flatten_maybe_padded_sequences(frame_probs, lengths)
    tf.identity(frame_probs_flat, name='frame_probs_flat')
    frame_label_weights_flat = flatten_maybe_padded_sequences(
        frame_label_weights, lengths)
    frame_losses = tf_utils.log_loss(
        frame_labels_flat,
        frame_probs_flat,
        weights=frame_label_weights_flat
        if hparams.weight_frame_and_activation_loss else None)
    tf.losses.add_loss(tf.reduce_mean(frame_losses))
    losses['frame'] = frame_losses

    if hparams.activation_loss:
      activation_losses = tf_utils.log_loss(
          frame_labels_flat,
          flatten_maybe_padded_sequences(activation_probs, lengths),
          weights=frame_label_weights_flat
          if hparams.weight_frame_and_activation_loss else None)
      tf.losses.add_loss(tf.reduce_mean(activation_losses))
      losses['activation'] = activation_losses

  predictions_flat = tf.cast(tf.greater_equal(frame_probs_flat, .5), tf.float32)

  # Create pianoroll images: labels in red, probs in green ([batch, time, 88, 3]).
  images = {}
  onset_pianorolls = tf.concat(
      [
          onset_labels[:, :, :, tf.newaxis], onset_probs[:, :, :, tf.newaxis],
          tf.zeros(tf.shape(onset_labels))[:, :, :, tf.newaxis]
      ],
      axis=3)
  images['OnsetPianorolls'] = onset_pianorolls
  activation_pianorolls = tf.concat(
      [
          frame_labels[:, :, :, tf.newaxis], frame_probs[:, :, :, tf.newaxis],
          tf.zeros(tf.shape(frame_labels))[:, :, :, tf.newaxis]
      ],
      axis=3)
  images['ActivationPianorolls'] = activation_pianorolls

  return (tf.losses.get_total_loss(), losses, frame_labels_flat,
          predictions_flat, images)
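
Labels and probabilities both pass through flatten_maybe_padded_sequences before any loss is computed, so padded frames never contribute. The helper comes from Magenta's utilities; below is a rough sketch of the behavior assumed here, not the library implementation:

import tensorflow.compat.v1 as tf

def flatten_maybe_padded_sequences(sequences, lengths=None):
  """Collapses [batch, time, ...] to [sum(lengths), ...], dropping padding."""
  if lengths is None:
    # No padding: simply merge the batch and time dimensions.
    shape = tf.shape(sequences)
    return tf.reshape(
        sequences, tf.concat([[shape[0] * shape[1]], shape[2:]], axis=0))
  # Keep only the first lengths[i] frames of each sequence.
  mask = tf.sequence_mask(lengths, maxlen=tf.shape(sequences)[1])
  return tf.boolean_mask(sequences, mask)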
Example #3
def get_model(transcription_data, hparams, is_training=True):
    """Builds the acoustic model."""
    onset_labels = transcription_data.onsets
    velocity_labels = transcription_data.velocities
    frame_labels = transcription_data.labels
    frame_label_weights = transcription_data.label_weights
    lengths = transcription_data.lengths
    spec = transcription_data.spec

    if hparams.stop_activation_gradient and not hparams.activation_loss:
        raise ValueError(
            'If stop_activation_gradient is true, activation_loss must be true.'
        )

    losses = {}
    with slim.arg_scope([slim.batch_norm, slim.dropout],
                        is_training=is_training):
        with tf.variable_scope('onsets'):
            onset_outputs = acoustic_model(spec,
                                           hparams,
                                           lstm_units=hparams.onset_lstm_units,
                                           lengths=lengths)
            onset_probs = slim.fully_connected(onset_outputs,
                                               constants.MIDI_PITCHES,
                                               activation_fn=tf.sigmoid,
                                               scope='onset_probs')

            # onset_probs_flat is used during inference.
            onset_probs_flat = flatten_maybe_padded_sequences(
                onset_probs, lengths)
            onset_labels_flat = flatten_maybe_padded_sequences(
                onset_labels, lengths)
            tf.identity(onset_probs_flat, name='onset_probs_flat')
            tf.identity(onset_labels_flat, name='onset_labels_flat')
            tf.identity(tf.cast(tf.greater_equal(onset_probs_flat, .5),
                                tf.float32),
                        name='onset_predictions_flat')

            onset_losses = tf_utils.log_loss(onset_labels_flat,
                                             onset_probs_flat)
            tf.losses.add_loss(tf.reduce_mean(onset_losses))
            losses['onset'] = onset_losses

        with tf.variable_scope('velocity'):
            # TODO(eriche): this is broken when hparams.velocity_lstm_units > 0
            velocity_outputs = acoustic_model(
                spec,
                hparams,
                lstm_units=hparams.velocity_lstm_units,
                lengths=lengths)
            velocity_values = slim.fully_connected(velocity_outputs,
                                                   constants.MIDI_PITCHES,
                                                   activation_fn=None,
                                                   scope='onset_velocities')

            velocity_values_flat = flatten_maybe_padded_sequences(
                velocity_values, lengths)
            tf.identity(velocity_values_flat, name='velocity_values_flat')
            velocity_labels_flat = flatten_maybe_padded_sequences(
                velocity_labels, lengths)
            velocity_loss = tf.reduce_sum(
                onset_labels_flat *
                tf.square(velocity_labels_flat - velocity_values_flat),
                axis=1)
            tf.losses.add_loss(tf.reduce_mean(velocity_loss))
            losses['velocity'] = velocity_loss

        with tf.variable_scope('frame'):
            if not hparams.share_conv_features:
                # TODO(eriche): this is broken when hparams.frame_lstm_units > 0
                activation_outputs = acoustic_model(
                    spec,
                    hparams,
                    lstm_units=hparams.frame_lstm_units,
                    lengths=lengths)
                activation_probs = slim.fully_connected(
                    activation_outputs,
                    constants.MIDI_PITCHES,
                    activation_fn=tf.sigmoid,
                    scope='activation_probs')
            else:
                activation_probs = slim.fully_connected(
                    onset_outputs,
                    constants.MIDI_PITCHES,
                    activation_fn=tf.sigmoid,
                    scope='activation_probs')

            combined_probs = tf.concat([
                tf.stop_gradient(onset_probs)
                if hparams.stop_onset_gradient else onset_probs,
                tf.stop_gradient(activation_probs)
                if hparams.stop_activation_gradient else activation_probs
            ], 2)

            if hparams.combined_lstm_units > 0:
                rnn_cell_fw = tf.contrib.rnn.LSTMBlockCell(
                    hparams.combined_lstm_units)
                if hparams.frame_bidirectional:
                    rnn_cell_bw = tf.contrib.rnn.LSTMBlockCell(
                        hparams.combined_lstm_units)
                    outputs, unused_output_states = tf.nn.bidirectional_dynamic_rnn(
                        rnn_cell_fw,
                        rnn_cell_bw,
                        inputs=combined_probs,
                        dtype=tf.float32)
                    combined_outputs = tf.concat(outputs, 2)
                else:
                    combined_outputs, unused_output_states = tf.nn.dynamic_rnn(
                        rnn_cell_fw, inputs=combined_probs, dtype=tf.float32)
            else:
                combined_outputs = combined_probs

            frame_probs = slim.fully_connected(combined_outputs,
                                               constants.MIDI_PITCHES,
                                               activation_fn=tf.sigmoid,
                                               scope='frame_probs')

        frame_labels_flat = flatten_maybe_padded_sequences(
            frame_labels, lengths)
        frame_probs_flat = flatten_maybe_padded_sequences(frame_probs, lengths)
        tf.identity(frame_probs_flat, name='frame_probs_flat')
        frame_label_weights_flat = flatten_maybe_padded_sequences(
            frame_label_weights, lengths)
        frame_losses = tf_utils.log_loss(
            frame_labels_flat,
            frame_probs_flat,
            weights=frame_label_weights_flat
            if hparams.weight_frame_and_activation_loss else None)
        tf.losses.add_loss(tf.reduce_mean(frame_losses))
        losses['frame'] = frame_losses

        if hparams.activation_loss:
            activation_losses = tf_utils.log_loss(
                frame_labels_flat,
                flatten_maybe_padded_sequences(activation_probs, lengths),
                weights=frame_label_weights_flat
                if hparams.weight_frame_and_activation_loss else None)
            tf.losses.add_loss(tf.reduce_mean(activation_losses))
            losses['activation'] = activation_losses

    predictions_flat = tf.cast(tf.greater_equal(frame_probs_flat, .5),
                               tf.float32)

    # Create pianoroll images: labels in red, probs in green
    # ([batch, time, 88, 3]).
    images = {}
    onset_pianorolls = tf.concat([
        onset_labels[:, :, :, tf.newaxis],
        onset_probs[:, :, :, tf.newaxis],
        tf.zeros(tf.shape(onset_labels))[:, :, :, tf.newaxis]
    ], axis=3)
    images['OnsetPianorolls'] = onset_pianorolls
    activation_pianorolls = tf.concat([
        frame_labels[:, :, :, tf.newaxis],
        frame_probs[:, :, :, tf.newaxis],
        tf.zeros(tf.shape(frame_labels))[:, :, :, tf.newaxis]
    ], axis=3)
    images['ActivationPianorolls'] = activation_pianorolls

    return (tf.losses.get_total_loss(), losses, frame_labels_flat,
            predictions_flat, images)
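
Every head scores its sigmoid outputs with tf_utils.log_loss, an element-wise binary cross entropy. A sketch of the assumed definition; the epsilon guard and its value are illustrative:

import tensorflow.compat.v1 as tf

def log_loss(labels, predictions, weights=None, epsilon=1e-7):
    """Element-wise cross entropy: -y*log(p) - (1-y)*log(1-p)."""
    predictions = tf.clip_by_value(predictions, epsilon, 1.0 - epsilon)
    losses = -(labels * tf.log(predictions) +
               (1.0 - labels) * tf.log(1.0 - predictions))
    if weights is not None:
        losses *= weights
    return losses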
Example #4
def model_fn(features, labels, mode, params, config):
    """Builds the acoustic model."""
    del config
    hparams = params

    length = features.length
    spec = features.spec

    is_training = mode == tf.estimator.ModeKeys.TRAIN

    if is_training:
        onset_labels = labels.onsets
        offset_labels = labels.offsets
        velocity_labels = labels.velocities
        frame_labels = labels.labels
        frame_label_weights = labels.label_weights

    if hparams.stop_activation_gradient and not hparams.activation_loss:
        raise ValueError(
            'If stop_activation_gradient is true, activation_loss must be true.'
        )

    losses = {}
    with slim.arg_scope([slim.batch_norm, slim.dropout],
                        is_training=is_training):
        with tf.variable_scope('onsets'):
            onset_outputs = acoustic_model(spec,
                                           hparams,
                                           lstm_units=hparams.onset_lstm_units,
                                           lengths=length,
                                           is_training=is_training)
            onset_probs = slim.fully_connected(onset_outputs,
                                               constants.MIDI_PITCHES,
                                               activation_fn=tf.sigmoid,
                                               scope='onset_probs')

            # onset_probs_flat is used during inference.
            onset_probs_flat = flatten_maybe_padded_sequences(
                onset_probs, length)
            if is_training:
                onset_labels_flat = flatten_maybe_padded_sequences(
                    onset_labels, length)
                onset_losses = tf_utils.log_loss(onset_labels_flat,
                                                 onset_probs_flat)
                tf.losses.add_loss(tf.reduce_mean(onset_losses))
                losses['onset'] = onset_losses
        with tf.variable_scope('offsets'):
            offset_outputs = acoustic_model(
                spec,
                hparams,
                lstm_units=hparams.offset_lstm_units,
                lengths=length,
                is_training=is_training)
            offset_probs = slim.fully_connected(offset_outputs,
                                                constants.MIDI_PITCHES,
                                                activation_fn=tf.sigmoid,
                                                scope='offset_probs')

            # offset_probs_flat is used during inference.
            offset_probs_flat = flatten_maybe_padded_sequences(
                offset_probs, length)
            if is_training:
                offset_labels_flat = flatten_maybe_padded_sequences(
                    offset_labels, length)
                offset_losses = tf_utils.log_loss(offset_labels_flat,
                                                  offset_probs_flat)
                tf.losses.add_loss(tf.reduce_mean(offset_losses))
                losses['offset'] = offset_losses
        with tf.variable_scope('velocity'):
            velocity_outputs = acoustic_model(
                spec,
                hparams,
                lstm_units=hparams.velocity_lstm_units,
                lengths=length,
                is_training=is_training)
            velocity_values = slim.fully_connected(velocity_outputs,
                                                   constants.MIDI_PITCHES,
                                                   activation_fn=None,
                                                   scope='onset_velocities')

            velocity_values_flat = flatten_maybe_padded_sequences(
                velocity_values, length)
            if is_training:
                velocity_labels_flat = flatten_maybe_padded_sequences(
                    velocity_labels, length)
                velocity_loss = tf.reduce_sum(
                    onset_labels_flat *
                    tf.square(velocity_labels_flat - velocity_values_flat),
                    axis=1)
                tf.losses.add_loss(tf.reduce_mean(velocity_loss))
                losses['velocity'] = velocity_loss

        with tf.variable_scope('frame'):
            if not hparams.share_conv_features:
                # TODO(eriche): this is broken when hparams.frame_lstm_units > 0
                activation_outputs = acoustic_model(
                    spec,
                    hparams,
                    lstm_units=hparams.frame_lstm_units,
                    lengths=length,
                    is_training=is_training)
                activation_probs = slim.fully_connected(
                    activation_outputs,
                    constants.MIDI_PITCHES,
                    activation_fn=tf.sigmoid,
                    scope='activation_probs')
            else:
                activation_probs = slim.fully_connected(
                    onset_outputs,
                    constants.MIDI_PITCHES,
                    activation_fn=tf.sigmoid,
                    scope='activation_probs')

            probs = []
            if hparams.stop_onset_gradient:
                probs.append(tf.stop_gradient(onset_probs))
            else:
                probs.append(onset_probs)

            if hparams.stop_activation_gradient:
                probs.append(tf.stop_gradient(activation_probs))
            else:
                probs.append(activation_probs)

            if hparams.stop_offset_gradient:
                probs.append(tf.stop_gradient(offset_probs))
            else:
                probs.append(offset_probs)

            combined_probs = tf.concat(probs, 2)

            if hparams.combined_lstm_units > 0:
                outputs = lstm_layer(
                    combined_probs,
                    hparams.batch_size,
                    hparams.combined_lstm_units,
                    lengths=length if hparams.use_lengths else None,
                    stack_size=hparams.combined_rnn_stack_size,
                    use_cudnn=hparams.use_cudnn,
                    is_training=is_training,
                    bidirectional=hparams.bidirectional)
            else:
                outputs = combined_probs

            frame_probs = slim.fully_connected(outputs,
                                               constants.MIDI_PITCHES,
                                               activation_fn=tf.sigmoid,
                                               scope='frame_probs')

        frame_probs_flat = flatten_maybe_padded_sequences(frame_probs, length)

        if is_training:
            frame_labels_flat = flatten_maybe_padded_sequences(
                frame_labels, length)
            frame_label_weights_flat = flatten_maybe_padded_sequences(
                frame_label_weights, length)
            if hparams.weight_frame_and_activation_loss:
                frame_loss_weights = frame_label_weights_flat
            else:
                frame_loss_weights = None
            frame_losses = tf_utils.log_loss(frame_labels_flat,
                                             frame_probs_flat,
                                             weights=frame_loss_weights)
            tf.losses.add_loss(tf.reduce_mean(frame_losses))
            losses['frame'] = frame_losses

            if hparams.activation_loss:
                if hparams.weight_frame_and_activation_loss:
                    activation_loss_weights = frame_label_weights
                else:
                    activation_loss_weights = None
                activation_losses = tf_utils.log_loss(
                    frame_labels_flat,
                    flatten_maybe_padded_sequences(activation_probs, length),
                    weights=activation_loss_weights)
                tf.losses.add_loss(tf.reduce_mean(activation_losses))
                losses['activation'] = activation_losses

    predictions = {
        'frame_probs_flat': frame_probs_flat,
        'onset_probs_flat': onset_probs_flat,
        'offset_probs_flat': offset_probs_flat,
        'velocity_values_flat': velocity_values_flat,
    }

    train_op = None
    loss = None
    if is_training:
        # Create pianoroll images: labels in red, probs in green
        # ([batch, time, 88, 3]).
        images = {}
        onset_pianorolls = tf.concat([
            onset_labels[:, :, :, tf.newaxis],
            onset_probs[:, :, :, tf.newaxis],
            tf.zeros(tf.shape(onset_labels))[:, :, :, tf.newaxis]
        ], axis=3)
        images['OnsetPianorolls'] = onset_pianorolls
        offset_pianorolls = tf.concat([
            offset_labels[:, :, :, tf.newaxis],
            offset_probs[:, :, :, tf.newaxis],
            tf.zeros(tf.shape(offset_labels))[:, :, :, tf.newaxis]
        ], axis=3)
        images['OffsetPianorolls'] = offset_pianorolls
        activation_pianorolls = tf.concat([
            frame_labels[:, :, :, tf.newaxis],
            frame_probs[:, :, :, tf.newaxis],
            tf.zeros(tf.shape(frame_labels))[:, :, :, tf.newaxis]
        ], axis=3)
        images['ActivationPianorolls'] = activation_pianorolls
        for name, image in images.items():
            tf.summary.image(name, image)

        loss = tf.losses.get_total_loss()
        tf.summary.scalar('loss', loss)
        for label, loss_collection in losses.items():
            loss_label = 'losses/' + label
            tf.summary.scalar(loss_label, tf.reduce_mean(loss_collection))

        train_op = tf.contrib.layers.optimize_loss(
            name='training',
            loss=loss,
            global_step=tf.train.get_or_create_global_step(),
            learning_rate=hparams.learning_rate,
            learning_rate_decay_fn=functools.partial(
                tf.train.exponential_decay,
                decay_steps=hparams.decay_steps,
                decay_rate=hparams.decay_rate,
                staircase=True),
            clip_gradients=hparams.clip_norm,
            optimizer='Adam')

    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=predictions,
                                      loss=loss,
                                      train_op=train_op)
Example #5
      onset_outputs = acoustic_model(
          spec,
          hparams,
          lstm_units=hparams.onset_lstm_units,
          lengths=length,
          is_training=is_training)
      onset_probs = slim.fully_connected(
          onset_outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='onset_probs')

      # onset_probs_flat is used during inference.
      onset_probs_flat = flatten_maybe_padded_sequences(onset_probs, length)
      if is_training:
        onset_labels_flat = flatten_maybe_padded_sequences(onset_labels, length)
        onset_losses = tf_utils.log_loss(onset_labels_flat, onset_probs_flat)
        tf.losses.add_loss(tf.reduce_mean(onset_losses))
        losses['onset'] = onset_losses
    with tf.variable_scope('offsets'):
      offset_outputs = acoustic_model(
          spec,
          hparams,
          lstm_units=hparams.offset_lstm_units,
          lengths=length,
          is_training=is_training)
      offset_probs = slim.fully_connected(
          offset_outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='offset_probs')
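
The snippet above shows the per-head pattern: each head runs acoustic_model over the same spectrogram with its own LSTM width. Below is a placeholder for the assumed interface only; the real network is a convolutional stack that this stand-in does not reproduce:

import tensorflow.compat.v1 as tf
import tf_slim as slim

def acoustic_model(spec, hparams, lstm_units, lengths, is_training=True):
  """Maps a [batch, time, freq] spectrogram to [batch, time, units]."""
  del hparams, is_training  # unused in this sketch
  net = slim.fully_connected(spec, 128, scope='features')  # stand-in layer
  if lstm_units:
    cell = tf.nn.rnn_cell.LSTMCell(lstm_units)
    net, _ = tf.nn.dynamic_rnn(
        cell, net, sequence_length=lengths, dtype=tf.float32)
  return net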
Example #6
def model_fn(features, labels, mode, params, config):
  """Builds the acoustic model."""
  del config
  hparams = params

  length = features.length
  spec = features.spec

  is_training = mode == tf.estimator.ModeKeys.TRAIN

  if is_training:
    onset_labels = labels.onsets
    offset_labels = labels.offsets
    velocity_labels = labels.velocities
    frame_labels = labels.labels
    frame_label_weights = labels.label_weights

  if hparams.stop_activation_gradient and not hparams.activation_loss:
    raise ValueError(
        'If stop_activation_gradient is true, activation_loss must be true.')

  losses = {}
  with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training):
    with tf.variable_scope('onsets'):
      onset_outputs = acoustic_model(
          spec,
          hparams,
          lstm_units=hparams.onset_lstm_units,
          lengths=length,
          is_training=is_training)
      onset_probs = slim.fully_connected(
          onset_outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='onset_probs')

      # onset_probs_flat is used during inference.
      onset_probs_flat = flatten_maybe_padded_sequences(onset_probs, length)
      if is_training:
        onset_labels_flat = flatten_maybe_padded_sequences(onset_labels, length)
        onset_losses = tf_utils.log_loss(onset_labels_flat, onset_probs_flat)
        tf.losses.add_loss(tf.reduce_mean(onset_losses))
        losses['onset'] = onset_losses
    with tf.variable_scope('offsets'):
      offset_outputs = acoustic_model(
          spec,
          hparams,
          lstm_units=hparams.offset_lstm_units,
          lengths=length,
          is_training=is_training)
      offset_probs = slim.fully_connected(
          offset_outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='offset_probs')

      # offset_probs_flat is used during inference.
      offset_probs_flat = flatten_maybe_padded_sequences(offset_probs, length)
      if is_training:
        offset_labels_flat = flatten_maybe_padded_sequences(
            offset_labels, length)
        offset_losses = tf_utils.log_loss(offset_labels_flat, offset_probs_flat)
        tf.losses.add_loss(tf.reduce_mean(offset_losses))
        losses['offset'] = offset_losses
    with tf.variable_scope('velocity'):
      velocity_outputs = acoustic_model(
          spec,
          hparams,
          lstm_units=hparams.velocity_lstm_units,
          lengths=length,
          is_training=is_training)
      velocity_values = slim.fully_connected(
          velocity_outputs,
          constants.MIDI_PITCHES,
          activation_fn=None,
          scope='onset_velocities')

      velocity_values_flat = flatten_maybe_padded_sequences(
          velocity_values, length)
      if is_training:
        velocity_labels_flat = flatten_maybe_padded_sequences(
            velocity_labels, length)
        velocity_loss = tf.reduce_sum(
            onset_labels_flat *
            tf.square(velocity_labels_flat - velocity_values_flat),
            axis=1)
        tf.losses.add_loss(tf.reduce_mean(velocity_loss))
        losses['velocity'] = velocity_loss

    with tf.variable_scope('frame'):
      if not hparams.share_conv_features:
        # TODO(eriche): this is broken when hparams.frame_lstm_units > 0
        activation_outputs = acoustic_model(
            spec,
            hparams,
            lstm_units=hparams.frame_lstm_units,
            lengths=length,
            is_training=is_training)
        activation_probs = slim.fully_connected(
            activation_outputs,
            constants.MIDI_PITCHES,
            activation_fn=tf.sigmoid,
            scope='activation_probs')
      else:
        activation_probs = slim.fully_connected(
            onset_outputs,
            constants.MIDI_PITCHES,
            activation_fn=tf.sigmoid,
            scope='activation_probs')

      probs = []
      if hparams.stop_onset_gradient:
        probs.append(tf.stop_gradient(onset_probs))
      else:
        probs.append(onset_probs)

      if hparams.stop_activation_gradient:
        probs.append(tf.stop_gradient(activation_probs))
      else:
        probs.append(activation_probs)

      if hparams.stop_offset_gradient:
        probs.append(tf.stop_gradient(offset_probs))
      else:
        probs.append(offset_probs)

      combined_probs = tf.concat(probs, 2)

      if hparams.combined_lstm_units > 0:
        outputs = lstm_layer(
            combined_probs,
            hparams.batch_size,
            hparams.combined_lstm_units,
            lengths=length if hparams.use_lengths else None,
            stack_size=hparams.combined_rnn_stack_size,
            use_cudnn=hparams.use_cudnn,
            is_training=is_training,
            bidirectional=hparams.bidirectional)
      else:
        outputs = combined_probs

      frame_probs = slim.fully_connected(
          outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='frame_probs')

    frame_probs_flat = flatten_maybe_padded_sequences(frame_probs, length)

    if is_training:
      frame_labels_flat = flatten_maybe_padded_sequences(frame_labels, length)
      frame_label_weights_flat = flatten_maybe_padded_sequences(
          frame_label_weights, length)
      if hparams.weight_frame_and_activation_loss:
        frame_loss_weights = frame_label_weights_flat
      else:
        frame_loss_weights = None
      frame_losses = tf_utils.log_loss(
          frame_labels_flat, frame_probs_flat, weights=frame_loss_weights)
      tf.losses.add_loss(tf.reduce_mean(frame_losses))
      losses['frame'] = frame_losses

      if hparams.activation_loss:
        if hparams.weight_frame_and_activation_loss:
          activation_loss_weights = frame_label_weights
        else:
          activation_loss_weights = None
        activation_losses = tf_utils.log_loss(
            frame_labels_flat,
            flatten_maybe_padded_sequences(activation_probs, length),
            weights=activation_loss_weights)
        tf.losses.add_loss(tf.reduce_mean(activation_losses))
        losses['activation'] = activation_losses

  predictions = {
      'frame_probs_flat': frame_probs_flat,
      'onset_probs_flat': onset_probs_flat,
      'offset_probs_flat': offset_probs_flat,
      'velocity_values_flat': velocity_values_flat,
  }

  train_op = None
  loss = None
  if is_training:
    # Create pianoroll images: labels in red, probs in green
    # ([batch, time, 88, 3]).
    images = {}
    onset_pianorolls = tf.concat([
        onset_labels[:, :, :, tf.newaxis],
        onset_probs[:, :, :, tf.newaxis],
        tf.zeros(tf.shape(onset_labels))[:, :, :, tf.newaxis]
    ], axis=3)
    images['OnsetPianorolls'] = onset_pianorolls
    offset_pianorolls = tf.concat([
        offset_labels[:, :, :, tf.newaxis],
        offset_probs[:, :, :, tf.newaxis],
        tf.zeros(tf.shape(offset_labels))[:, :, :, tf.newaxis]
    ], axis=3)
    images['OffsetPianorolls'] = offset_pianorolls
    activation_pianorolls = tf.concat([
        frame_labels[:, :, :, tf.newaxis],
        frame_probs[:, :, :, tf.newaxis],
        tf.zeros(tf.shape(frame_labels))[:, :, :, tf.newaxis]
    ], axis=3)
    images['ActivationPianorolls'] = activation_pianorolls
    for name, image in images.items():
      tf.summary.image(name, image)

    loss = tf.losses.get_total_loss()
    tf.summary.scalar('loss', loss)
    for label, loss_collection in losses.items():
      loss_label = 'losses/' + label
      tf.summary.scalar(loss_label, tf.reduce_mean(loss_collection))

    train_op = tf.contrib.layers.optimize_loss(
        name='training',
        loss=loss,
        global_step=tf.train.get_or_create_global_step(),
        learning_rate=hparams.learning_rate,
        learning_rate_decay_fn=functools.partial(
            tf.train.exponential_decay,
            decay_steps=hparams.decay_steps,
            decay_rate=hparams.decay_rate,
            staircase=True),
        clip_gradients=hparams.clip_norm,
        optimizer='Adam')

  return tf.estimator.EstimatorSpec(
      mode=mode, predictions=predictions, loss=loss, train_op=train_op)
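
This variant exposes only flat probability tensors, so thresholding happens on the caller's side. A usage sketch; predict_input_fn and the 0.5 cutoff are placeholders, and note that predict() unbatches along the first dimension, so each output here is a single flattened frame:

estimator = tf.estimator.Estimator(model_fn=model_fn, params=hparams)
for output in estimator.predict(input_fn=predict_input_fn):
  frame_predictions = output['frame_probs_flat'] > 0.5
  onset_predictions = output['onset_probs_flat'] > 0.5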