def get_metrics(features, labels, frame_probs, onset_probs, frame_predictions,
                onset_predictions, offset_predictions, velocity_values,
                hparams):
  """Return metrics values ops.

  Drum-only models (`hparams.drums_only`) are scored by
  `_drums_only_metric_ops`; every other model goes straight to
  `metrics.define_metrics`.

  Args:
    features: Input features; `.length` and `.sequence_id` are read here.
    labels: Input labels; `.note_sequence` and `.labels` are read here.
    frame_probs: Frame probability tensor.
    onset_probs: Onset probability tensor.
    frame_predictions: Thresholded frame predictions.
    onset_predictions: Thresholded onset predictions.
    offset_predictions: Thresholded offset predictions.
    velocity_values: Predicted velocity values.
    hparams: Hyperparameters; `.drums_only` selects the metric helper.

  Returns:
    Whatever the selected metrics helper returns.
  """
  # Model outputs are forwarded identically to either helper.
  model_outputs = dict(
      frame_probs=frame_probs,
      onset_probs=onset_probs,
      frame_predictions=frame_predictions,
      onset_predictions=onset_predictions,
      offset_predictions=offset_predictions,
      velocity_values=velocity_values)

  if not hparams.drums_only:
    return metrics.define_metrics(
        length=features.length,
        sequence_label=labels.note_sequence,
        frame_labels=labels.labels,
        sequence_id=features.sequence_id,
        hparams=hparams,
        **model_outputs)

  return _drums_only_metric_ops(
      features=features, labels=labels, hparams=hparams, **model_outputs)
def _drums_only_metric_ops(features, labels, frame_predictions,
                           onset_predictions, offset_predictions,
                           velocity_values, hparams, frame_probs=None,
                           onset_probs=None):
  """Generate drum metrics: offsets/frames are ignored.

  Onset predictions stand in for the frame and offset predictions, and
  `metrics.define_metrics` is called with `onsets_only=True`, the `'drums/'`
  metric prefix, the MIN_MIDI_PITCH..MAX_MIDI_PITCH range and
  `drum_mappings.GROOVE_PITCH_NAMES` as the pitch map.

  Args:
    features: Input features; `.length` and `.sequence_id` are read.
    labels: Input labels; `.note_sequence` and `.labels` are read.
    frame_predictions: Ignored (replaced by `onset_predictions`).
    onset_predictions: Thresholded onset predictions.
    offset_predictions: Ignored (replaced by `onset_predictions`).
    velocity_values: Predicted velocity values.
    hparams: Hyperparameters forwarded to `metrics.define_metrics`.
    frame_probs: Optional frame probability tensor. `get_metrics` passes it by
      keyword; it is forwarded to `metrics.define_metrics` when not None.
    onset_probs: Optional onset probability tensor; handled like `frame_probs`.

  Returns:
    The result of `metrics.define_metrics`.
  """
  del frame_predictions, offset_predictions  # Unused: onsets stand in.

  # `get_metrics` always supplies the probability tensors by keyword; the old
  # signature rejected them with a TypeError. Forward them only when given so
  # callers using the older positional signature get the original call.
  prob_kwargs = {}
  if frame_probs is not None:
    prob_kwargs['frame_probs'] = frame_probs
  if onset_probs is not None:
    prob_kwargs['onset_probs'] = onset_probs

  metric_ops = metrics.define_metrics(
      frame_predictions=onset_predictions,
      onset_predictions=onset_predictions,
      offset_predictions=onset_predictions,
      velocity_values=velocity_values,
      length=features.length,
      sequence_label=labels.note_sequence,
      frame_labels=labels.labels,
      sequence_id=features.sequence_id,
      hparams=hparams,
      min_pitch=constants.MIN_MIDI_PITCH,
      max_pitch=constants.MAX_MIDI_PITCH,
      prefix='drums/',
      onsets_only=True,
      pitch_map=drum_mappings.GROOVE_PITCH_NAMES,
      **prob_kwargs)
  return metric_ops
def _pianoroll_image(label_rolls, prob_rolls):
  """Stack labels and probabilities into an RGB image for tf.summary.image.

  Labels go in the red channel, probabilities in the green channel and the
  blue channel is all zeros.

  Args:
    label_rolls: Rank-3 label tensor (presumably [batch, time, pitch] --
      confirm against the input pipeline).
    prob_rolls: Probability tensor with the same shape and dtype as
      `label_rolls`.

  Returns:
    A rank-4 tensor whose last axis has size 3.
  """
  return tf.concat([
      label_rolls[:, :, :, tf.newaxis],
      prob_rolls[:, :, :, tf.newaxis],
      tf.zeros(tf.shape(label_rolls))[:, :, :, tf.newaxis],
  ], axis=3)


def model_fn(features, labels, mode, params, config):
  """Builds the acoustic model.

  Four heads (onsets, offsets, velocity, frames) are built on `features.spec`.
  The frame head consumes the (optionally gradient-stopped) onset, activation
  and offset probabilities. In TRAIN mode the per-head losses are added to the
  `tf.losses` collection and an Adam train op is created.

  Args:
    features: Input features; `.length`, `.spec` and `.sequence_id` are read.
    labels: Input labels; `.note_sequence` and `.labels` are always read, and
      `.onsets`, `.offsets`, `.velocities` and `.label_weights` are read in
      TRAIN mode.
    mode: A `tf.estimator.ModeKeys` value.
    params: Hyperparameters object.
    config: Unused.

  Returns:
    A `tf.estimator.EstimatorSpec` with predictions, eval metric ops and, in
    TRAIN mode, the loss and train op.

  Raises:
    ValueError: If `stop_activation_gradient` is set without
      `activation_loss`.
  """
  del config  # Unused.
  hparams = params

  length = features.length
  spec = features.spec

  is_training = mode == tf.estimator.ModeKeys.TRAIN

  if is_training:
    onset_labels = labels.onsets
    offset_labels = labels.offsets
    velocity_labels = labels.velocities
    frame_labels = labels.labels
    frame_label_weights = labels.label_weights

  if hparams.stop_activation_gradient and not hparams.activation_loss:
    raise ValueError(
        'If stop_activation_gradient is true, activation_loss must be true.')

  losses = {}
  with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training):
    with tf.variable_scope('onsets'):
      onset_outputs = acoustic_model(
          spec,
          hparams,
          lstm_units=hparams.onset_lstm_units,
          lengths=length,
          is_training=is_training)
      onset_probs = slim.fully_connected(
          onset_outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='onset_probs')

      # onset_probs_flat is used during inference.
      onset_probs_flat = flatten_maybe_padded_sequences(onset_probs, length)
      if is_training:
        onset_labels_flat = flatten_maybe_padded_sequences(onset_labels, length)
        onset_losses = tf_utils.log_loss(onset_labels_flat, onset_probs_flat)
        tf.losses.add_loss(tf.reduce_mean(onset_losses))
        losses['onset'] = onset_losses

    with tf.variable_scope('offsets'):
      offset_outputs = acoustic_model(
          spec,
          hparams,
          lstm_units=hparams.offset_lstm_units,
          lengths=length,
          is_training=is_training)
      offset_probs = slim.fully_connected(
          offset_outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='offset_probs')

      # offset_probs_flat is used during inference.
      offset_probs_flat = flatten_maybe_padded_sequences(offset_probs, length)
      if is_training:
        offset_labels_flat = flatten_maybe_padded_sequences(
            offset_labels, length)
        offset_losses = tf_utils.log_loss(offset_labels_flat, offset_probs_flat)
        tf.losses.add_loss(tf.reduce_mean(offset_losses))
        losses['offset'] = offset_losses

    with tf.variable_scope('velocity'):
      velocity_outputs = acoustic_model(
          spec,
          hparams,
          lstm_units=hparams.velocity_lstm_units,
          lengths=length,
          is_training=is_training)
      velocity_values = slim.fully_connected(
          velocity_outputs,
          constants.MIDI_PITCHES,
          activation_fn=None,
          scope='onset_velocities')

      velocity_values_flat = flatten_maybe_padded_sequences(
          velocity_values, length)
      if is_training:
        velocity_labels_flat = flatten_maybe_padded_sequences(
            velocity_labels, length)
        # Squared velocity error, masked by the onset labels so that only
        # positions with a labelled onset contribute.
        velocity_loss = tf.reduce_sum(
            onset_labels_flat *
            tf.square(velocity_labels_flat - velocity_values_flat),
            axis=1)
        tf.losses.add_loss(tf.reduce_mean(velocity_loss))
        losses['velocity'] = velocity_loss

    with tf.variable_scope('frame'):
      if not hparams.share_conv_features:
        # TODO(eriche): this is broken when hparams.frame_lstm_units > 0
        activation_outputs = acoustic_model(
            spec,
            hparams,
            lstm_units=hparams.frame_lstm_units,
            lengths=length,
            is_training=is_training)
        activation_probs = slim.fully_connected(
            activation_outputs,
            constants.MIDI_PITCHES,
            activation_fn=tf.sigmoid,
            scope='activation_probs')
      else:
        activation_probs = slim.fully_connected(
            onset_outputs,
            constants.MIDI_PITCHES,
            activation_fn=tf.sigmoid,
            scope='activation_probs')

      # Order matters: onset, activation, offset probabilities are
      # concatenated on axis 2 in this order.
      probs = [
          tf.stop_gradient(head_probs) if stop_gradient else head_probs
          for head_probs, stop_gradient in (
              (onset_probs, hparams.stop_onset_gradient),
              (activation_probs, hparams.stop_activation_gradient),
              (offset_probs, hparams.stop_offset_gradient),
          )
      ]
      combined_probs = tf.concat(probs, 2)

      if hparams.combined_lstm_units > 0:
        outputs = lstm_layer(
            combined_probs,
            hparams.batch_size,
            hparams.combined_lstm_units,
            lengths=length if hparams.use_lengths else None,
            stack_size=hparams.combined_rnn_stack_size,
            use_cudnn=hparams.use_cudnn,
            is_training=is_training,
            bidirectional=hparams.bidirectional)
      else:
        outputs = combined_probs

      frame_probs = slim.fully_connected(
          outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='frame_probs')

    frame_probs_flat = flatten_maybe_padded_sequences(frame_probs, length)

    if is_training:
      frame_labels_flat = flatten_maybe_padded_sequences(frame_labels, length)
      frame_label_weights_flat = flatten_maybe_padded_sequences(
          frame_label_weights, length)
      if hparams.weight_frame_and_activation_loss:
        frame_loss_weights = frame_label_weights_flat
      else:
        frame_loss_weights = None
      frame_losses = tf_utils.log_loss(
          frame_labels_flat, frame_probs_flat, weights=frame_loss_weights)
      tf.losses.add_loss(tf.reduce_mean(frame_losses))
      losses['frame'] = frame_losses

      if hparams.activation_loss:
        # The activation loss compares *flattened* labels and probabilities,
        # so its weights must be the flattened ones too. The previous code
        # passed the unflattened `frame_label_weights`, which does not match
        # the flattened label shape.
        activation_loss_weights = frame_loss_weights
        activation_losses = tf_utils.log_loss(
            frame_labels_flat,
            flatten_maybe_padded_sequences(activation_probs, length),
            weights=activation_loss_weights)
        tf.losses.add_loss(tf.reduce_mean(activation_losses))
        losses['activation'] = activation_losses

  frame_predictions = frame_probs_flat > hparams.predict_frame_threshold
  onset_predictions = onset_probs_flat > hparams.predict_onset_threshold
  offset_predictions = offset_probs_flat > hparams.predict_offset_threshold

  # Add a leading dimension of size 1; predict_sequence() below reads
  # element [0] of each of these tensors.
  frame_predictions = tf.expand_dims(frame_predictions, axis=0)
  onset_predictions = tf.expand_dims(onset_predictions, axis=0)
  offset_predictions = tf.expand_dims(offset_predictions, axis=0)
  velocity_values = tf.expand_dims(velocity_values_flat, axis=0)

  metrics_values = metrics.define_metrics(
      frame_predictions=frame_predictions,
      onset_predictions=onset_predictions,
      offset_predictions=offset_predictions,
      velocity_values=velocity_values,
      length=features.length,
      sequence_label=labels.note_sequence,
      frame_labels=labels.labels,
      sequence_id=features.sequence_id,
      hparams=hparams)

  for label, loss_collection in losses.items():
    loss_label = 'losses/' + label
    metrics_values[loss_label] = loss_collection

  def predict_sequence():
    """Convert frame predictions into a sequence."""

    def _predict(frame_predictions, onset_predictions, offset_predictions,
                 velocity_values):
      sequence = infer_util.predict_sequence(
          frame_predictions=frame_predictions,
          onset_predictions=onset_predictions,
          offset_predictions=offset_predictions,
          velocity_values=velocity_values,
          hparams=hparams,
          min_pitch=constants.MIN_MIDI_PITCH)
      return sequence.SerializeToString()

    sequence = tf.py_func(
        _predict,
        inp=[
            frame_predictions[0],
            onset_predictions[0],
            offset_predictions[0],
            velocity_values[0],
        ],
        Tout=tf.string,
        stateful=False)
    sequence.set_shape([])
    return tf.expand_dims(sequence, axis=0)

  predictions = {
      'frame_probs': tf.expand_dims(frame_probs_flat, axis=0),
      'frame_predictions': frame_predictions,
      'onset_predictions': onset_predictions,
      'offset_predictions': offset_predictions,
      'velocity_values': velocity_values,
      'sequence_predictions': predict_sequence(),
      # Include some features and labels in output because Estimator 'predict'
      # API does not give access to them.
      'sequence_ids': features.sequence_id,
      'sequence_labels': labels.note_sequence,
      'frame_labels': labels.labels,
  }
  for k, v in metrics_values.items():
    predictions[k] = tf.stack(v)

  metric_ops = {k: tf.metrics.mean(v) for k, v in metrics_values.items()}

  train_op = None
  loss = None
  if is_training:
    # Pianoroll summaries: labels in red, probabilities in green.
    # NOTE(review): 'ActivationPianorolls' shows frame_probs, not
    # activation_probs -- kept as-is to preserve existing summaries.
    images = {
        'OnsetPianorolls': _pianoroll_image(onset_labels, onset_probs),
        'OffsetPianorolls': _pianoroll_image(offset_labels, offset_probs),
        'ActivationPianorolls': _pianoroll_image(frame_labels, frame_probs),
    }
    for name, image in images.items():
      tf.summary.image(name, image)

    loss = tf.losses.get_total_loss()
    tf.summary.scalar('loss', loss)
    for label, loss_collection in losses.items():
      loss_label = 'losses/' + label
      tf.summary.scalar(loss_label, tf.reduce_mean(loss_collection))

    train_op = contrib_layers.optimize_loss(
        name='training',
        loss=loss,
        global_step=tf.train.get_or_create_global_step(),
        learning_rate=hparams.learning_rate,
        learning_rate_decay_fn=functools.partial(
            tf.train.exponential_decay,
            decay_steps=hparams.decay_steps,
            decay_rate=hparams.decay_rate,
            staircase=True),
        clip_gradients=hparams.clip_norm,
        optimizer='Adam')

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=metric_ops)
losses['activation'] = activation_losses frame_predictions = frame_probs_flat > hparams.predict_frame_threshold onset_predictions = onset_probs_flat > hparams.predict_onset_threshold offset_predictions = offset_probs_flat > hparams.predict_offset_threshold frame_predictions = tf.expand_dims(frame_predictions, axis=0) onset_predictions = tf.expand_dims(onset_predictions, axis=0) offset_predictions = tf.expand_dims(offset_predictions, axis=0) velocity_values = tf.expand_dims(velocity_values_flat, axis=0) metrics_values = metrics.define_metrics( frame_predictions=frame_predictions, onset_predictions=onset_predictions, offset_predictions=offset_predictions, velocity_values=velocity_values, length=features.length, sequence_label=labels.note_sequence, frame_labels=labels.labels, sequence_id=features.sequence_id, hparams=hparams) for label, loss_collection in losses.items(): loss_label = 'losses/' + label metrics_values[loss_label] = loss_collection def predict_sequence(): """Convert frame predictions into a sequence.""" def _predict(frame_predictions, onset_predictions, offset_predictions, velocity_values): sequence = infer_util.predict_sequence( frame_predictions=frame_predictions,