def model_fn(features, labels, mode, params, config):
  """Builds the acoustic model as an Estimator `model_fn`.

  Four heads share the input spectrogram: onsets, offsets, velocities and
  frames. The frame head consumes the onset, activation and offset
  probabilities. In training mode each head adds a loss to the `tf.losses`
  collection, image and loss summaries are written, and a train op is built.

  Args:
    features: Input features. Reads `length`, `spec` and `sequence_id`.
    labels: Labels. `onsets`, `offsets`, `velocities`, `labels` and
      `label_weights` are read in training mode; `note_sequence`, `labels`
      and `onsets` are read in every mode for metrics and predictions.
    mode: A `tf_estimator.ModeKeys` value.
    params: Model hyperparameters.
    config: Unused; accepted for the Estimator `model_fn` signature.

  Returns:
    A `tf_estimator.EstimatorSpec` with predictions and eval metric ops.
    `loss` and `train_op` are None unless `mode` is TRAIN.

  Raises:
    ValueError: If `stop_activation_gradient` is set without
      `activation_loss`.
  """
  del config
  hparams = params

  length = features.length
  spec = features.spec

  is_training = mode == tf_estimator.ModeKeys.TRAIN

  if is_training:
    onset_labels = labels.onsets
    offset_labels = labels.offsets
    velocity_labels = labels.velocities
    frame_labels = labels.labels
    frame_label_weights = labels.label_weights

  if hparams.stop_activation_gradient and not hparams.activation_loss:
    raise ValueError(
        'If stop_activation_gradient is true, activation_loss must be true.')

  losses = {}
  with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training):
    with tf.variable_scope('onsets'):
      onset_outputs = acoustic_model(
          spec, hparams, lstm_units=hparams.onset_lstm_units, lengths=length)
      onset_probs = slim.fully_connected(
          onset_outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='onset_probs')

      # onset_probs_flat is used during inference.
      onset_probs_flat = flatten_maybe_padded_sequences(onset_probs, length)
      if is_training:
        onset_labels_flat = flatten_maybe_padded_sequences(
            onset_labels, length)
        onset_losses = tf_utils.log_loss(onset_labels_flat, onset_probs_flat)
        tf.losses.add_loss(tf.reduce_mean(onset_losses))
        losses['onset'] = onset_losses

    with tf.variable_scope('offsets'):
      offset_outputs = acoustic_model(
          spec, hparams, lstm_units=hparams.offset_lstm_units, lengths=length)
      offset_probs = slim.fully_connected(
          offset_outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='offset_probs')

      # offset_probs_flat is used during inference.
      offset_probs_flat = flatten_maybe_padded_sequences(offset_probs, length)
      if is_training:
        offset_labels_flat = flatten_maybe_padded_sequences(
            offset_labels, length)
        offset_losses = tf_utils.log_loss(offset_labels_flat,
                                          offset_probs_flat)
        tf.losses.add_loss(tf.reduce_mean(offset_losses))
        losses['offset'] = offset_losses

    with tf.variable_scope('velocity'):
      velocity_outputs = acoustic_model(
          spec, hparams, lstm_units=hparams.velocity_lstm_units,
          lengths=length)
      velocity_values = slim.fully_connected(
          velocity_outputs,
          constants.MIDI_PITCHES,
          activation_fn=None,
          scope='onset_velocities')

      velocity_values_flat = flatten_maybe_padded_sequences(
          velocity_values, length)
      if is_training:
        velocity_labels_flat = flatten_maybe_padded_sequences(
            velocity_labels, length)
        # Squared velocity error is multiplied by the onset labels, so only
        # frames with a non-zero onset label contribute.
        velocity_loss = tf.reduce_sum(
            onset_labels_flat *
            tf.square(velocity_labels_flat - velocity_values_flat),
            axis=1)
        tf.losses.add_loss(tf.reduce_mean(velocity_loss))
        losses['velocity'] = velocity_loss

    with tf.variable_scope('frame'):
      if not hparams.share_conv_features:
        # TODO(eriche): this is broken when hparams.frame_lstm_units > 0
        activation_outputs = acoustic_model(
            spec, hparams, lstm_units=hparams.frame_lstm_units,
            lengths=length)
        activation_probs = slim.fully_connected(
            activation_outputs,
            constants.MIDI_PITCHES,
            activation_fn=tf.sigmoid,
            scope='activation_probs')
      else:
        activation_probs = slim.fully_connected(
            onset_outputs,
            constants.MIDI_PITCHES,
            activation_fn=tf.sigmoid,
            scope='activation_probs')

      # Each head's probabilities may be detached so the frame loss does not
      # backpropagate into that head.
      probs = []
      if hparams.stop_onset_gradient:
        probs.append(tf.stop_gradient(onset_probs))
      else:
        probs.append(onset_probs)

      if hparams.stop_activation_gradient:
        probs.append(tf.stop_gradient(activation_probs))
      else:
        probs.append(activation_probs)

      if hparams.stop_offset_gradient:
        probs.append(tf.stop_gradient(offset_probs))
      else:
        probs.append(offset_probs)

      combined_probs = tf.concat(probs, 2)

      if hparams.combined_lstm_units > 0:
        outputs = lstm_layer(
            combined_probs,
            hparams.combined_lstm_units,
            lengths=length if hparams.use_lengths else None,
            stack_size=hparams.combined_rnn_stack_size,
            use_cudnn=hparams.use_cudnn,
            bidirectional=hparams.bidirectional)
      else:
        outputs = combined_probs

      frame_probs = slim.fully_connected(
          outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='frame_probs')

    frame_probs_flat = flatten_maybe_padded_sequences(frame_probs, length)

    if is_training:
      frame_labels_flat = flatten_maybe_padded_sequences(frame_labels, length)
      frame_label_weights_flat = flatten_maybe_padded_sequences(
          frame_label_weights, length)
      if hparams.weight_frame_and_activation_loss:
        frame_loss_weights = frame_label_weights_flat
      else:
        frame_loss_weights = None
      frame_losses = tf_utils.log_loss(
          frame_labels_flat, frame_probs_flat, weights=frame_loss_weights)
      tf.losses.add_loss(tf.reduce_mean(frame_losses))
      losses['frame'] = frame_losses

      if hparams.activation_loss:
        # The activation loss compares flattened labels with flattened
        # probabilities, so its weights must be flattened the same way.
        # Passing the unflattened `frame_label_weights` misaligned them.
        activation_losses = tf_utils.log_loss(
            frame_labels_flat,
            flatten_maybe_padded_sequences(activation_probs, length),
            weights=frame_loss_weights)
        tf.losses.add_loss(tf.reduce_mean(activation_losses))
        losses['activation'] = activation_losses

  frame_predictions = frame_probs_flat > hparams.predict_frame_threshold
  onset_predictions = onset_probs_flat > hparams.predict_onset_threshold
  offset_predictions = offset_probs_flat > hparams.predict_offset_threshold

  # Add a leading unit dimension to the flattened tensors.
  frame_predictions = tf.expand_dims(frame_predictions, axis=0)
  onset_predictions = tf.expand_dims(onset_predictions, axis=0)
  offset_predictions = tf.expand_dims(offset_predictions, axis=0)
  velocity_values = tf.expand_dims(velocity_values_flat, axis=0)

  metrics_values = metrics.define_metrics(
      frame_probs=frame_probs,
      onset_probs=onset_probs,
      frame_predictions=frame_predictions,
      onset_predictions=onset_predictions,
      offset_predictions=offset_predictions,
      velocity_values=velocity_values,
      length=features.length,
      sequence_label=labels.note_sequence,
      frame_labels=labels.labels,
      sequence_id=features.sequence_id,
      hparams=hparams)

  for label, loss_collection in losses.items():
    loss_label = 'losses/' + label
    metrics_values[loss_label] = loss_collection

  def predict_sequence():
    """Convert frame predictions into a sequence (TF)."""

    def _predict(frame_probs, onset_probs, frame_predictions,
                 onset_predictions, offset_predictions, velocity_values):
      """Convert frame predictions into a sequence (Python)."""
      sequence = infer_util.predict_sequence(
          frame_probs=frame_probs,
          onset_probs=onset_probs,
          frame_predictions=frame_predictions,
          onset_predictions=onset_predictions,
          offset_predictions=offset_predictions,
          velocity_values=velocity_values,
          hparams=hparams,
          min_pitch=constants.MIN_MIDI_PITCH)
      return sequence.SerializeToString()

    sequence = tf.py_func(
        _predict,
        inp=[
            frame_probs[0],
            onset_probs[0],
            frame_predictions[0],
            onset_predictions[0],
            offset_predictions[0],
            velocity_values[0],
        ],
        Tout=tf.string,
        stateful=False)
    sequence.set_shape([])
    return tf.expand_dims(sequence, axis=0)

  predictions = {
      'frame_probs': frame_probs,
      'onset_probs': onset_probs,
      'frame_predictions': frame_predictions,
      'onset_predictions': onset_predictions,
      'offset_predictions': offset_predictions,
      'velocity_values': velocity_values,
      'sequence_predictions': predict_sequence(),
      # Include some features and labels in output because Estimator 'predict'
      # API does not give access to them.
      'sequence_ids': features.sequence_id,
      'sequence_labels': labels.note_sequence,
      'frame_labels': labels.labels,
      'onset_labels': labels.onsets,
  }
  for k, v in metrics_values.items():
    predictions[k] = tf.stack(v)

  metric_ops = {k: tf.metrics.mean(v) for k, v in metrics_values.items()}

  train_op = None
  loss = None
  if is_training:
    # Pianoroll images: labels in the red channel, probabilities in green.
    images = {}
    onset_pianorolls = tf.concat([
        onset_labels[:, :, :, tf.newaxis], onset_probs[:, :, :, tf.newaxis],
        tf.zeros(tf.shape(onset_labels))[:, :, :, tf.newaxis]
    ], axis=3)
    images['OnsetPianorolls'] = onset_pianorolls
    offset_pianorolls = tf.concat([
        offset_labels[:, :, :, tf.newaxis], offset_probs[:, :, :, tf.newaxis],
        tf.zeros(tf.shape(offset_labels))[:, :, :, tf.newaxis]
    ], axis=3)
    images['OffsetPianorolls'] = offset_pianorolls
    activation_pianorolls = tf.concat([
        frame_labels[:, :, :, tf.newaxis], frame_probs[:, :, :, tf.newaxis],
        tf.zeros(tf.shape(frame_labels))[:, :, :, tf.newaxis]
    ], axis=3)
    images['ActivationPianorolls'] = activation_pianorolls
    for name, image in images.items():
      tf.summary.image(name, image)

    loss = tf.losses.get_total_loss()
    tf.summary.scalar('loss', loss)
    for label, loss_collection in losses.items():
      loss_label = 'losses/' + label
      tf.summary.scalar(loss_label, tf.reduce_mean(loss_collection))

    train_op = slim.optimize_loss(
        name='training',
        loss=loss,
        global_step=tf.train.get_or_create_global_step(),
        learning_rate=hparams.learning_rate,
        learning_rate_decay_fn=functools.partial(
            tf.train.exponential_decay,
            decay_steps=hparams.decay_steps,
            decay_rate=hparams.decay_rate,
            staircase=True),
        clip_gradients=hparams.clip_norm,
        optimizer='Adam')

  return tf_estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=metric_ops)
def get_model(transcription_data, hparams, is_training=True):
  """Builds the onsets-and-frames graph and its training losses.

  Three heads are built from the spectrogram: onsets, velocities and frames.
  The frame head consumes the onset and activation probabilities. Each head
  adds its mean loss to the `tf.losses` collection.

  Args:
    transcription_data: Batch whose `onsets`, `velocities`, `labels`,
      `label_weights`, `lengths` and `spec` attributes are read.
    hparams: Model hyperparameters.
    is_training: Forwarded to `slim.batch_norm` and `slim.dropout` through
      `slim.arg_scope`.

  Returns:
    A tuple `(total_loss, losses, frame_labels_flat, predictions_flat,
    images)`. `total_loss` is `tf.losses.get_total_loss()`. `losses` maps
    head names to their loss tensors. `predictions_flat` is the flattened
    frame probabilities thresholded at 0.5 and cast to float32. `images` maps
    summary names to pianoroll image tensors.

  Raises:
    ValueError: If `stop_activation_gradient` is set without
      `activation_loss`.
  """
  onset_labels = transcription_data.onsets
  velocity_labels = transcription_data.velocities
  frame_labels = transcription_data.labels
  frame_label_weights = transcription_data.label_weights
  seq_lengths = transcription_data.lengths
  spectrogram = transcription_data.spec

  if hparams.stop_activation_gradient and not hparams.activation_loss:
    raise ValueError(
        'If stop_activation_gradient is true, activation_loss must be true.')

  def flatten(tensor):
    """Applies `flatten_maybe_padded_sequences` with the batch lengths."""
    return flatten_maybe_padded_sequences(tensor, seq_lengths)

  def pitch_head(inputs, activation, scope_name):
    """Adds a fully connected layer with one output per MIDI pitch."""
    return slim.fully_connected(
        inputs,
        constants.MIDI_PITCHES,
        activation_fn=activation,
        scope=scope_name)

  def maybe_stop_gradient(tensor, stop):
    """Returns `tensor`, detached from backprop when `stop` is truthy."""
    return tf.stop_gradient(tensor) if stop else tensor

  def pianoroll(truth, probs):
    """Stacks labels, probabilities and zeros as three image channels."""
    return tf.concat(
        [
            truth[:, :, :, tf.newaxis],
            probs[:, :, :, tf.newaxis],
            tf.zeros(tf.shape(truth))[:, :, :, tf.newaxis],
        ],
        axis=3)

  losses = {}
  with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training):
    with tf.variable_scope('onsets'):
      onset_outputs = acoustic_model(
          spectrogram,
          hparams,
          lstm_units=hparams.onset_lstm_units,
          lengths=seq_lengths)
      onset_probs = pitch_head(onset_outputs, tf.sigmoid, 'onset_probs')

      # onset_probs_flat is used during inference.
      onset_probs_flat = flatten(onset_probs)
      onset_labels_flat = flatten(onset_labels)
      tf.identity(onset_probs_flat, name='onset_probs_flat')
      tf.identity(onset_labels_flat, name='onset_labels_flat')
      tf.identity(
          tf.cast(tf.greater_equal(onset_probs_flat, .5), tf.float32),
          name='onset_predictions_flat')

      onset_losses = tf_utils.log_loss(onset_labels_flat, onset_probs_flat)
      tf.losses.add_loss(tf.reduce_mean(onset_losses))
      losses['onset'] = onset_losses

    with tf.variable_scope('velocity'):
      # TODO(eriche): this is broken when hparams.velocity_lstm_units > 0
      velocity_outputs = acoustic_model(
          spectrogram,
          hparams,
          lstm_units=hparams.velocity_lstm_units,
          lengths=seq_lengths)
      velocity_values = pitch_head(velocity_outputs, None, 'onset_velocities')

      velocity_values_flat = flatten(velocity_values)
      tf.identity(velocity_values_flat, name='velocity_values_flat')
      velocity_labels_flat = flatten(velocity_labels)
      # Squared error is multiplied by the onset labels, so only frames with a
      # non-zero onset label contribute to the velocity loss.
      squared_error = tf.square(velocity_labels_flat - velocity_values_flat)
      velocity_loss = tf.reduce_sum(onset_labels_flat * squared_error, axis=1)
      tf.losses.add_loss(tf.reduce_mean(velocity_loss))
      losses['velocity'] = velocity_loss

    with tf.variable_scope('frame'):
      if hparams.share_conv_features:
        activation_probs = pitch_head(
            onset_outputs, tf.sigmoid, 'activation_probs')
      else:
        # TODO(eriche): this is broken when hparams.frame_lstm_units > 0
        activation_outputs = acoustic_model(
            spectrogram,
            hparams,
            lstm_units=hparams.frame_lstm_units,
            lengths=seq_lengths)
        activation_probs = pitch_head(
            activation_outputs, tf.sigmoid, 'activation_probs')

      combined_probs = tf.concat([
          maybe_stop_gradient(onset_probs, hparams.stop_onset_gradient),
          maybe_stop_gradient(activation_probs,
                              hparams.stop_activation_gradient),
      ], 2)

      if hparams.combined_lstm_units <= 0:
        combined_outputs = combined_probs
      else:
        cell_fw = tf.contrib.rnn.LSTMBlockCell(hparams.combined_lstm_units)
        if hparams.frame_bidirectional:
          cell_bw = tf.contrib.rnn.LSTMBlockCell(hparams.combined_lstm_units)
          rnn_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
              cell_fw, cell_bw, inputs=combined_probs, dtype=tf.float32)
          # Forward and backward outputs are joined along the feature axis.
          combined_outputs = tf.concat(rnn_outputs, 2)
        else:
          combined_outputs, _ = tf.nn.dynamic_rnn(
              cell_fw, inputs=combined_probs, dtype=tf.float32)

      frame_probs = pitch_head(combined_outputs, tf.sigmoid, 'frame_probs')

    # NOTE(review): these ops sit outside the 'frame' variable scope, so the
    # identity below gets no 'frame/' name prefix; confirm against any code
    # that fetches it by name.
    frame_labels_flat = flatten(frame_labels)
    frame_probs_flat = flatten(frame_probs)
    tf.identity(frame_probs_flat, name='frame_probs_flat')
    frame_label_weights_flat = flatten(frame_label_weights)
    loss_weights = (frame_label_weights_flat
                    if hparams.weight_frame_and_activation_loss else None)

    frame_losses = tf_utils.log_loss(
        frame_labels_flat, frame_probs_flat, weights=loss_weights)
    tf.losses.add_loss(tf.reduce_mean(frame_losses))
    losses['frame'] = frame_losses

    if hparams.activation_loss:
      activation_losses = tf_utils.log_loss(
          frame_labels_flat, flatten(activation_probs), weights=loss_weights)
      tf.losses.add_loss(tf.reduce_mean(activation_losses))
      losses['activation'] = activation_losses

  predictions_flat = tf.cast(tf.greater_equal(frame_probs_flat, .5), tf.float32)

  # Pianoroll images: labels in the red channel, probabilities in green.
  images = {
      'OnsetPianorolls': pianoroll(onset_labels, onset_probs),
      'ActivationPianorolls': pianoroll(frame_labels, frame_probs),
  }

  return (tf.losses.get_total_loss(), losses, frame_labels_flat,
          predictions_flat, images)
def get_model(transcription_data, hparams, is_training=True):
  """Builds the onsets-and-frames graph and its training losses.

  Three heads are built from the spectrogram: onsets, velocities and frames.
  The frame head consumes the onset and activation probabilities. Each head
  adds its mean loss to the `tf.losses` collection.

  Args:
    transcription_data: Batch whose `onsets`, `velocities`, `labels`,
      `label_weights`, `lengths` and `spec` attributes are read.
    hparams: Model hyperparameters.
    is_training: Forwarded to `slim.batch_norm` and `slim.dropout` through
      `slim.arg_scope`.

  Returns:
    A tuple `(total_loss, losses, frame_labels_flat, predictions_flat,
    images)`. `total_loss` is `tf.losses.get_total_loss()`. `losses` maps
    head names to their loss tensors. `predictions_flat` is the flattened
    frame probabilities thresholded at 0.5 and cast to float32. `images` maps
    summary names to pianoroll image tensors.

  Raises:
    ValueError: If `stop_activation_gradient` is set without
      `activation_loss`.
  """
  onset_labels = transcription_data.onsets
  velocity_labels = transcription_data.velocities
  frame_labels = transcription_data.labels
  frame_label_weights = transcription_data.label_weights
  seq_lengths = transcription_data.lengths
  spectrogram = transcription_data.spec

  if hparams.stop_activation_gradient and not hparams.activation_loss:
    raise ValueError(
        'If stop_activation_gradient is true, activation_loss must be true.')

  def flatten(tensor):
    """Applies `flatten_maybe_padded_sequences` with the batch lengths."""
    return flatten_maybe_padded_sequences(tensor, seq_lengths)

  def pitch_head(inputs, activation, scope_name):
    """Adds a fully connected layer with one output per MIDI pitch."""
    return slim.fully_connected(
        inputs,
        constants.MIDI_PITCHES,
        activation_fn=activation,
        scope=scope_name)

  def maybe_stop_gradient(tensor, stop):
    """Returns `tensor`, detached from backprop when `stop` is truthy."""
    return tf.stop_gradient(tensor) if stop else tensor

  def pianoroll(truth, probs):
    """Stacks labels, probabilities and zeros as three image channels."""
    return tf.concat(
        [
            truth[:, :, :, tf.newaxis],
            probs[:, :, :, tf.newaxis],
            tf.zeros(tf.shape(truth))[:, :, :, tf.newaxis],
        ],
        axis=3)

  losses = {}
  with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training):
    with tf.variable_scope('onsets'):
      onset_outputs = acoustic_model(
          spectrogram,
          hparams,
          lstm_units=hparams.onset_lstm_units,
          lengths=seq_lengths)
      onset_probs = pitch_head(onset_outputs, tf.sigmoid, 'onset_probs')

      # onset_probs_flat is used during inference.
      onset_probs_flat = flatten(onset_probs)
      onset_labels_flat = flatten(onset_labels)
      tf.identity(onset_probs_flat, name='onset_probs_flat')
      tf.identity(onset_labels_flat, name='onset_labels_flat')
      tf.identity(
          tf.cast(tf.greater_equal(onset_probs_flat, .5), tf.float32),
          name='onset_predictions_flat')

      onset_losses = tf_utils.log_loss(onset_labels_flat, onset_probs_flat)
      tf.losses.add_loss(tf.reduce_mean(onset_losses))
      losses['onset'] = onset_losses

    with tf.variable_scope('velocity'):
      # TODO(eriche): this is broken when hparams.velocity_lstm_units > 0
      velocity_outputs = acoustic_model(
          spectrogram,
          hparams,
          lstm_units=hparams.velocity_lstm_units,
          lengths=seq_lengths)
      velocity_values = pitch_head(velocity_outputs, None, 'onset_velocities')

      velocity_values_flat = flatten(velocity_values)
      tf.identity(velocity_values_flat, name='velocity_values_flat')
      velocity_labels_flat = flatten(velocity_labels)
      # Squared error is multiplied by the onset labels, so only frames with a
      # non-zero onset label contribute to the velocity loss.
      squared_error = tf.square(velocity_labels_flat - velocity_values_flat)
      velocity_loss = tf.reduce_sum(onset_labels_flat * squared_error, axis=1)
      tf.losses.add_loss(tf.reduce_mean(velocity_loss))
      losses['velocity'] = velocity_loss

    with tf.variable_scope('frame'):
      if hparams.share_conv_features:
        activation_probs = pitch_head(
            onset_outputs, tf.sigmoid, 'activation_probs')
      else:
        # TODO(eriche): this is broken when hparams.frame_lstm_units > 0
        activation_outputs = acoustic_model(
            spectrogram,
            hparams,
            lstm_units=hparams.frame_lstm_units,
            lengths=seq_lengths)
        activation_probs = pitch_head(
            activation_outputs, tf.sigmoid, 'activation_probs')

      combined_probs = tf.concat([
          maybe_stop_gradient(onset_probs, hparams.stop_onset_gradient),
          maybe_stop_gradient(activation_probs,
                              hparams.stop_activation_gradient),
      ], 2)

      if hparams.combined_lstm_units <= 0:
        combined_outputs = combined_probs
      else:
        cell_fw = tf.contrib.rnn.LSTMBlockCell(hparams.combined_lstm_units)
        if hparams.frame_bidirectional:
          cell_bw = tf.contrib.rnn.LSTMBlockCell(hparams.combined_lstm_units)
          rnn_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
              cell_fw, cell_bw, inputs=combined_probs, dtype=tf.float32)
          # Forward and backward outputs are joined along the feature axis.
          combined_outputs = tf.concat(rnn_outputs, 2)
        else:
          combined_outputs, _ = tf.nn.dynamic_rnn(
              cell_fw, inputs=combined_probs, dtype=tf.float32)

      frame_probs = pitch_head(combined_outputs, tf.sigmoid, 'frame_probs')

    # NOTE(review): these ops sit outside the 'frame' variable scope, so the
    # identity below gets no 'frame/' name prefix; confirm against any code
    # that fetches it by name.
    frame_labels_flat = flatten(frame_labels)
    frame_probs_flat = flatten(frame_probs)
    tf.identity(frame_probs_flat, name='frame_probs_flat')
    frame_label_weights_flat = flatten(frame_label_weights)
    loss_weights = (frame_label_weights_flat
                    if hparams.weight_frame_and_activation_loss else None)

    frame_losses = tf_utils.log_loss(
        frame_labels_flat, frame_probs_flat, weights=loss_weights)
    tf.losses.add_loss(tf.reduce_mean(frame_losses))
    losses['frame'] = frame_losses

    if hparams.activation_loss:
      activation_losses = tf_utils.log_loss(
          frame_labels_flat, flatten(activation_probs), weights=loss_weights)
      tf.losses.add_loss(tf.reduce_mean(activation_losses))
      losses['activation'] = activation_losses

  predictions_flat = tf.cast(tf.greater_equal(frame_probs_flat, .5), tf.float32)

  # Pianoroll images: labels in the red channel, probabilities in green.
  images = {
      'OnsetPianorolls': pianoroll(onset_labels, onset_probs),
      'ActivationPianorolls': pianoroll(frame_labels, frame_probs),
  }

  return (tf.losses.get_total_loss(), losses, frame_labels_flat,
          predictions_flat, images)
def model_fn(features, labels, mode, params, config):
  """Builds the acoustic model as an Estimator `model_fn`.

  Four heads share the input spectrogram: onsets, offsets, velocities and
  frames. The frame head consumes the onset, activation and offset
  probabilities. In training mode each head adds a loss to the `tf.losses`
  collection, image and loss summaries are written, and a train op is built.

  Args:
    features: Input features. Reads `length` and `spec`.
    labels: Labels, read only in training mode (`onsets`, `offsets`,
      `velocities`, `labels`, `label_weights`).
    mode: A `tf.estimator.ModeKeys` value.
    params: Model hyperparameters.
    config: Unused; accepted for the Estimator `model_fn` signature.

  Returns:
    A `tf.estimator.EstimatorSpec` whose predictions are the flattened
    frame, onset and offset probabilities and the flattened velocity values.
    `loss` and `train_op` are None unless `mode` is TRAIN.

  Raises:
    ValueError: If `stop_activation_gradient` is set without
      `activation_loss`.
  """
  del config
  hparams = params

  length = features.length
  spec = features.spec

  is_training = mode == tf.estimator.ModeKeys.TRAIN

  if is_training:
    onset_labels = labels.onsets
    offset_labels = labels.offsets
    velocity_labels = labels.velocities
    frame_labels = labels.labels
    frame_label_weights = labels.label_weights

  if hparams.stop_activation_gradient and not hparams.activation_loss:
    raise ValueError(
        'If stop_activation_gradient is true, activation_loss must be true.')

  losses = {}
  with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training):
    with tf.variable_scope('onsets'):
      onset_outputs = acoustic_model(
          spec,
          hparams,
          lstm_units=hparams.onset_lstm_units,
          lengths=length,
          is_training=is_training)
      onset_probs = slim.fully_connected(
          onset_outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='onset_probs')

      # onset_probs_flat is used during inference.
      onset_probs_flat = flatten_maybe_padded_sequences(onset_probs, length)
      if is_training:
        onset_labels_flat = flatten_maybe_padded_sequences(
            onset_labels, length)
        onset_losses = tf_utils.log_loss(onset_labels_flat, onset_probs_flat)
        tf.losses.add_loss(tf.reduce_mean(onset_losses))
        losses['onset'] = onset_losses

    with tf.variable_scope('offsets'):
      offset_outputs = acoustic_model(
          spec,
          hparams,
          lstm_units=hparams.offset_lstm_units,
          lengths=length,
          is_training=is_training)
      offset_probs = slim.fully_connected(
          offset_outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='offset_probs')

      # offset_probs_flat is used during inference.
      offset_probs_flat = flatten_maybe_padded_sequences(offset_probs, length)
      if is_training:
        offset_labels_flat = flatten_maybe_padded_sequences(
            offset_labels, length)
        offset_losses = tf_utils.log_loss(offset_labels_flat,
                                          offset_probs_flat)
        tf.losses.add_loss(tf.reduce_mean(offset_losses))
        losses['offset'] = offset_losses

    with tf.variable_scope('velocity'):
      velocity_outputs = acoustic_model(
          spec,
          hparams,
          lstm_units=hparams.velocity_lstm_units,
          lengths=length,
          is_training=is_training)
      velocity_values = slim.fully_connected(
          velocity_outputs,
          constants.MIDI_PITCHES,
          activation_fn=None,
          scope='onset_velocities')

      velocity_values_flat = flatten_maybe_padded_sequences(
          velocity_values, length)
      if is_training:
        velocity_labels_flat = flatten_maybe_padded_sequences(
            velocity_labels, length)
        # Squared velocity error is multiplied by the onset labels, so only
        # frames with a non-zero onset label contribute.
        velocity_loss = tf.reduce_sum(
            onset_labels_flat *
            tf.square(velocity_labels_flat - velocity_values_flat),
            axis=1)
        tf.losses.add_loss(tf.reduce_mean(velocity_loss))
        losses['velocity'] = velocity_loss

    with tf.variable_scope('frame'):
      if not hparams.share_conv_features:
        # TODO(eriche): this is broken when hparams.frame_lstm_units > 0
        activation_outputs = acoustic_model(
            spec,
            hparams,
            lstm_units=hparams.frame_lstm_units,
            lengths=length,
            is_training=is_training)
        activation_probs = slim.fully_connected(
            activation_outputs,
            constants.MIDI_PITCHES,
            activation_fn=tf.sigmoid,
            scope='activation_probs')
      else:
        activation_probs = slim.fully_connected(
            onset_outputs,
            constants.MIDI_PITCHES,
            activation_fn=tf.sigmoid,
            scope='activation_probs')

      # Each head's probabilities may be detached so the frame loss does not
      # backpropagate into that head.
      probs = []
      if hparams.stop_onset_gradient:
        probs.append(tf.stop_gradient(onset_probs))
      else:
        probs.append(onset_probs)

      if hparams.stop_activation_gradient:
        probs.append(tf.stop_gradient(activation_probs))
      else:
        probs.append(activation_probs)

      if hparams.stop_offset_gradient:
        probs.append(tf.stop_gradient(offset_probs))
      else:
        probs.append(offset_probs)

      combined_probs = tf.concat(probs, 2)

      if hparams.combined_lstm_units > 0:
        outputs = lstm_layer(
            combined_probs,
            hparams.batch_size,
            hparams.combined_lstm_units,
            lengths=length if hparams.use_lengths else None,
            stack_size=hparams.combined_rnn_stack_size,
            use_cudnn=hparams.use_cudnn,
            is_training=is_training,
            bidirectional=hparams.bidirectional)
      else:
        outputs = combined_probs

      frame_probs = slim.fully_connected(
          outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='frame_probs')

    frame_probs_flat = flatten_maybe_padded_sequences(frame_probs, length)

    if is_training:
      frame_labels_flat = flatten_maybe_padded_sequences(frame_labels, length)
      frame_label_weights_flat = flatten_maybe_padded_sequences(
          frame_label_weights, length)
      if hparams.weight_frame_and_activation_loss:
        frame_loss_weights = frame_label_weights_flat
      else:
        frame_loss_weights = None
      frame_losses = tf_utils.log_loss(
          frame_labels_flat, frame_probs_flat, weights=frame_loss_weights)
      tf.losses.add_loss(tf.reduce_mean(frame_losses))
      losses['frame'] = frame_losses

      if hparams.activation_loss:
        # The activation loss compares flattened labels with flattened
        # probabilities, so its weights must be flattened the same way.
        # Passing the unflattened `frame_label_weights` misaligned them.
        activation_losses = tf_utils.log_loss(
            frame_labels_flat,
            flatten_maybe_padded_sequences(activation_probs, length),
            weights=frame_loss_weights)
        tf.losses.add_loss(tf.reduce_mean(activation_losses))
        losses['activation'] = activation_losses

  predictions = {
      'frame_probs_flat': frame_probs_flat,
      'onset_probs_flat': onset_probs_flat,
      'offset_probs_flat': offset_probs_flat,
      'velocity_values_flat': velocity_values_flat,
  }

  train_op = None
  loss = None
  if is_training:
    # Pianoroll images: labels in the red channel, probabilities in green.
    images = {}
    onset_pianorolls = tf.concat([
        onset_labels[:, :, :, tf.newaxis], onset_probs[:, :, :, tf.newaxis],
        tf.zeros(tf.shape(onset_labels))[:, :, :, tf.newaxis]
    ], axis=3)
    images['OnsetPianorolls'] = onset_pianorolls
    offset_pianorolls = tf.concat([
        offset_labels[:, :, :, tf.newaxis], offset_probs[:, :, :, tf.newaxis],
        tf.zeros(tf.shape(offset_labels))[:, :, :, tf.newaxis]
    ], axis=3)
    images['OffsetPianorolls'] = offset_pianorolls
    activation_pianorolls = tf.concat([
        frame_labels[:, :, :, tf.newaxis], frame_probs[:, :, :, tf.newaxis],
        tf.zeros(tf.shape(frame_labels))[:, :, :, tf.newaxis]
    ], axis=3)
    images['ActivationPianorolls'] = activation_pianorolls
    for name, image in images.items():
      tf.summary.image(name, image)

    loss = tf.losses.get_total_loss()
    tf.summary.scalar('loss', loss)
    for label, loss_collection in losses.items():
      loss_label = 'losses/' + label
      tf.summary.scalar(loss_label, tf.reduce_mean(loss_collection))

    train_op = tf.contrib.layers.optimize_loss(
        name='training',
        loss=loss,
        global_step=tf.train.get_or_create_global_step(),
        learning_rate=hparams.learning_rate,
        learning_rate_decay_fn=functools.partial(
            tf.train.exponential_decay,
            decay_steps=hparams.decay_steps,
            decay_rate=hparams.decay_rate,
            staircase=True),
        clip_gradients=hparams.clip_norm,
        optimizer='Adam')

  return tf.estimator.EstimatorSpec(
      mode=mode, predictions=predictions, loss=loss, train_op=train_op)
spec, hparams, lstm_units=hparams.onset_lstm_units, lengths=length, is_training=is_training) onset_probs = slim.fully_connected( onset_outputs, constants.MIDI_PITCHES, activation_fn=tf.sigmoid, scope='onset_probs') # onset_probs_flat is used during inference. onset_probs_flat = flatten_maybe_padded_sequences(onset_probs, length) if is_training: onset_labels_flat = flatten_maybe_padded_sequences(onset_labels, length) onset_losses = tf_utils.log_loss(onset_labels_flat, onset_probs_flat) tf.losses.add_loss(tf.reduce_mean(onset_losses)) losses['onset'] = onset_losses with tf.variable_scope('offsets'): offset_outputs = acoustic_model( spec, hparams, lstm_units=hparams.offset_lstm_units, lengths=length, is_training=is_training) offset_probs = slim.fully_connected( offset_outputs, constants.MIDI_PITCHES, activation_fn=tf.sigmoid, scope='offset_probs')
def model_fn(features, labels, mode, params, config):
  """Builds the acoustic model as an Estimator `model_fn`.

  Four heads share the input spectrogram: onsets, offsets, velocities and
  frames. The frame head consumes the onset, activation and offset
  probabilities. In training mode each head adds a loss to the `tf.losses`
  collection, image and loss summaries are written, and a train op is built.

  Args:
    features: Input features. Reads `length` and `spec`.
    labels: Labels, read only in training mode (`onsets`, `offsets`,
      `velocities`, `labels`, `label_weights`).
    mode: A `tf.estimator.ModeKeys` value.
    params: Model hyperparameters.
    config: Unused; accepted for the Estimator `model_fn` signature.

  Returns:
    A `tf.estimator.EstimatorSpec` whose predictions are the flattened
    frame, onset and offset probabilities and the flattened velocity values.
    `loss` and `train_op` are None unless `mode` is TRAIN.

  Raises:
    ValueError: If `stop_activation_gradient` is set without
      `activation_loss`.
  """
  del config
  hparams = params

  length = features.length
  spec = features.spec

  is_training = mode == tf.estimator.ModeKeys.TRAIN

  if is_training:
    onset_labels = labels.onsets
    offset_labels = labels.offsets
    velocity_labels = labels.velocities
    frame_labels = labels.labels
    frame_label_weights = labels.label_weights

  if hparams.stop_activation_gradient and not hparams.activation_loss:
    raise ValueError(
        'If stop_activation_gradient is true, activation_loss must be true.')

  losses = {}
  with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training):
    with tf.variable_scope('onsets'):
      onset_outputs = acoustic_model(
          spec,
          hparams,
          lstm_units=hparams.onset_lstm_units,
          lengths=length,
          is_training=is_training)
      onset_probs = slim.fully_connected(
          onset_outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='onset_probs')

      # onset_probs_flat is used during inference.
      onset_probs_flat = flatten_maybe_padded_sequences(onset_probs, length)
      if is_training:
        onset_labels_flat = flatten_maybe_padded_sequences(
            onset_labels, length)
        onset_losses = tf_utils.log_loss(onset_labels_flat, onset_probs_flat)
        tf.losses.add_loss(tf.reduce_mean(onset_losses))
        losses['onset'] = onset_losses

    with tf.variable_scope('offsets'):
      offset_outputs = acoustic_model(
          spec,
          hparams,
          lstm_units=hparams.offset_lstm_units,
          lengths=length,
          is_training=is_training)
      offset_probs = slim.fully_connected(
          offset_outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='offset_probs')

      # offset_probs_flat is used during inference.
      offset_probs_flat = flatten_maybe_padded_sequences(offset_probs, length)
      if is_training:
        offset_labels_flat = flatten_maybe_padded_sequences(
            offset_labels, length)
        offset_losses = tf_utils.log_loss(offset_labels_flat,
                                          offset_probs_flat)
        tf.losses.add_loss(tf.reduce_mean(offset_losses))
        losses['offset'] = offset_losses

    with tf.variable_scope('velocity'):
      velocity_outputs = acoustic_model(
          spec,
          hparams,
          lstm_units=hparams.velocity_lstm_units,
          lengths=length,
          is_training=is_training)
      velocity_values = slim.fully_connected(
          velocity_outputs,
          constants.MIDI_PITCHES,
          activation_fn=None,
          scope='onset_velocities')

      velocity_values_flat = flatten_maybe_padded_sequences(
          velocity_values, length)
      if is_training:
        velocity_labels_flat = flatten_maybe_padded_sequences(
            velocity_labels, length)
        # Squared velocity error is multiplied by the onset labels, so only
        # frames with a non-zero onset label contribute.
        velocity_loss = tf.reduce_sum(
            onset_labels_flat *
            tf.square(velocity_labels_flat - velocity_values_flat),
            axis=1)
        tf.losses.add_loss(tf.reduce_mean(velocity_loss))
        losses['velocity'] = velocity_loss

    with tf.variable_scope('frame'):
      if not hparams.share_conv_features:
        # TODO(eriche): this is broken when hparams.frame_lstm_units > 0
        activation_outputs = acoustic_model(
            spec,
            hparams,
            lstm_units=hparams.frame_lstm_units,
            lengths=length,
            is_training=is_training)
        activation_probs = slim.fully_connected(
            activation_outputs,
            constants.MIDI_PITCHES,
            activation_fn=tf.sigmoid,
            scope='activation_probs')
      else:
        activation_probs = slim.fully_connected(
            onset_outputs,
            constants.MIDI_PITCHES,
            activation_fn=tf.sigmoid,
            scope='activation_probs')

      # Each head's probabilities may be detached so the frame loss does not
      # backpropagate into that head.
      probs = []
      if hparams.stop_onset_gradient:
        probs.append(tf.stop_gradient(onset_probs))
      else:
        probs.append(onset_probs)

      if hparams.stop_activation_gradient:
        probs.append(tf.stop_gradient(activation_probs))
      else:
        probs.append(activation_probs)

      if hparams.stop_offset_gradient:
        probs.append(tf.stop_gradient(offset_probs))
      else:
        probs.append(offset_probs)

      combined_probs = tf.concat(probs, 2)

      if hparams.combined_lstm_units > 0:
        outputs = lstm_layer(
            combined_probs,
            hparams.batch_size,
            hparams.combined_lstm_units,
            lengths=length if hparams.use_lengths else None,
            stack_size=hparams.combined_rnn_stack_size,
            use_cudnn=hparams.use_cudnn,
            is_training=is_training,
            bidirectional=hparams.bidirectional)
      else:
        outputs = combined_probs

      frame_probs = slim.fully_connected(
          outputs,
          constants.MIDI_PITCHES,
          activation_fn=tf.sigmoid,
          scope='frame_probs')

    frame_probs_flat = flatten_maybe_padded_sequences(frame_probs, length)

    if is_training:
      frame_labels_flat = flatten_maybe_padded_sequences(frame_labels, length)
      frame_label_weights_flat = flatten_maybe_padded_sequences(
          frame_label_weights, length)
      if hparams.weight_frame_and_activation_loss:
        frame_loss_weights = frame_label_weights_flat
      else:
        frame_loss_weights = None
      frame_losses = tf_utils.log_loss(
          frame_labels_flat, frame_probs_flat, weights=frame_loss_weights)
      tf.losses.add_loss(tf.reduce_mean(frame_losses))
      losses['frame'] = frame_losses

      if hparams.activation_loss:
        # The activation loss compares flattened labels with flattened
        # probabilities, so its weights must be flattened the same way.
        # Passing the unflattened `frame_label_weights` misaligned them.
        activation_losses = tf_utils.log_loss(
            frame_labels_flat,
            flatten_maybe_padded_sequences(activation_probs, length),
            weights=frame_loss_weights)
        tf.losses.add_loss(tf.reduce_mean(activation_losses))
        losses['activation'] = activation_losses

  predictions = {
      'frame_probs_flat': frame_probs_flat,
      'onset_probs_flat': onset_probs_flat,
      'offset_probs_flat': offset_probs_flat,
      'velocity_values_flat': velocity_values_flat,
  }

  train_op = None
  loss = None
  if is_training:
    # Pianoroll images: labels in the red channel, probabilities in green.
    images = {}
    onset_pianorolls = tf.concat([
        onset_labels[:, :, :, tf.newaxis], onset_probs[:, :, :, tf.newaxis],
        tf.zeros(tf.shape(onset_labels))[:, :, :, tf.newaxis]
    ], axis=3)
    images['OnsetPianorolls'] = onset_pianorolls
    offset_pianorolls = tf.concat([
        offset_labels[:, :, :, tf.newaxis], offset_probs[:, :, :, tf.newaxis],
        tf.zeros(tf.shape(offset_labels))[:, :, :, tf.newaxis]
    ], axis=3)
    images['OffsetPianorolls'] = offset_pianorolls
    activation_pianorolls = tf.concat([
        frame_labels[:, :, :, tf.newaxis], frame_probs[:, :, :, tf.newaxis],
        tf.zeros(tf.shape(frame_labels))[:, :, :, tf.newaxis]
    ], axis=3)
    images['ActivationPianorolls'] = activation_pianorolls
    for name, image in images.items():
      tf.summary.image(name, image)

    loss = tf.losses.get_total_loss()
    tf.summary.scalar('loss', loss)
    for label, loss_collection in losses.items():
      loss_label = 'losses/' + label
      tf.summary.scalar(loss_label, tf.reduce_mean(loss_collection))

    train_op = tf.contrib.layers.optimize_loss(
        name='training',
        loss=loss,
        global_step=tf.train.get_or_create_global_step(),
        learning_rate=hparams.learning_rate,
        learning_rate_decay_fn=functools.partial(
            tf.train.exponential_decay,
            decay_steps=hparams.decay_steps,
            decay_rate=hparams.decay_rate,
            staircase=True),
        clip_gradients=hparams.clip_norm,
        optimizer='Adam')

  return tf.estimator.EstimatorSpec(
      mode=mode, predictions=predictions, loss=loss, train_op=train_op)