コード例 #1
0
def eval_op(batch, hparams, config_name):
    """Define an evaluation op.

    Builds the autoencoder graph in inference mode (``is_training=False``),
    a streaming mean-loss metric, and spectrogram summaries for the inputs,
    reconstructions, a latent-space interpolation and pitch-shifted decodings.

    Args:
      batch: Batch produced by NSynthReader. Read via the "pitch" key and
        either "audio" (when ``hparams.raw_audio``) or "spectrogram". The
        caller's ``batch`` is not modified.
      hparams: Hyperparameters. ``mag_only`` and ``raw_audio`` are read here;
        the object is also forwarded to the config's encode/decode, the loss
        and the summary helpers.
      config_name: Name of the config module under
        ``baseline.models.ae_configs``.

    Returns:
      A list of metric update ops. Running them accumulates the streaming
      "Loss" metric, whose value is emitted as a scalar summary.
    """
    # Request phase summaries only for spectrogram inputs that are not
    # magnitude-only.
    phase = not (hparams.mag_only or hparams.raw_audio)

    config = utils.get_module("baseline.models.ae_configs.%s" % config_name)
    if hparams.raw_audio:
        x = batch["audio"]
        # Add height and channel dims
        x = tf.expand_dims(tf.expand_dims(x, 1), -1)
    else:
        x = batch["spectrogram"]

    # Define the model
    with tf.name_scope("Model"):
        z = config.encode(x, hparams, is_training=False)
        xhat = config.decode(z, batch, hparams, is_training=False)

    # Expose tensors via graph collections for later interpolation tooling.
    tf.add_to_collection("x", x)
    tf.add_to_collection("pitch", batch["pitch"])
    tf.add_to_collection("z", z)
    tf.add_to_collection("xhat", xhat)

    total_loss = compute_mse_loss(x, xhat, hparams)

    # Define the metrics:
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        "Loss": slim.metrics.mean(total_loss),
    })

    # Define the summaries
    for name, value in names_to_values.items():
        slim.summaries.add_scalar_summary(value, name, print_summary=True)

    # Interpolate
    with tf.name_scope("Interpolation"):
        # Re-decode with shared (reused) decoder variables.
        xhat = config.decode(z, batch, hparams, reuse=True, is_training=False)

        # Linear interpolation: pair each example with the next one in the
        # batch (wrapping around) and decode the midpoint of their latents.
        z_shift_one_example = tf.concat([z[1:], z[:1]], 0)
        z_linear_half = (z + z_shift_one_example) / 2.0
        xhat_linear_half = config.decode(z_linear_half,
                                         batch,
                                         hparams,
                                         reuse=True,
                                         is_training=False)

        # Pitch shift: decode the same latents conditioned on pitches shifted
        # by +/-2, clipped to [0, 127] (presumably the MIDI pitch range).
        pitch_plus_2 = tf.clip_by_value(batch["pitch"] + 2, 0, 127)
        pitch_minus_2 = tf.clip_by_value(batch["pitch"] - 2, 0, 127)

        # Use shallow copies so the caller's batch keeps its original "pitch"
        # tensor; previously the input dict was overwritten in place.
        batch_pitch_minus_2 = dict(batch)
        batch_pitch_minus_2["pitch"] = pitch_minus_2
        xhat_pitch_minus_2 = config.decode(z,
                                           batch_pitch_minus_2,
                                           hparams,
                                           reuse=True,
                                           is_training=False)

        batch_pitch_plus_2 = dict(batch)
        batch_pitch_plus_2["pitch"] = pitch_plus_2
        xhat_pitch_plus_2 = config.decode(z,
                                          batch_pitch_plus_2,
                                          hparams,
                                          reuse=True,
                                          is_training=False)

    utils.specgram_summaries(x, "Training Examples", hparams, phase=phase)
    utils.specgram_summaries(xhat, "Reconstructions", hparams, phase=phase)
    utils.specgram_summaries(x - xhat,
                             "Difference",
                             hparams,
                             audio=False,
                             phase=phase)
    utils.specgram_summaries(xhat_linear_half,
                             "Linear Interp. 0.5",
                             hparams,
                             phase=phase)
    utils.specgram_summaries(xhat_pitch_plus_2,
                             "Pitch +2",
                             hparams,
                             phase=phase)
    utils.specgram_summaries(xhat_pitch_minus_2,
                             "Pitch -2",
                             hparams,
                             phase=phase)

    return list(names_to_updates.values())
コード例 #2
0
ファイル: ae.py プロジェクト: Kaushikpatnaik/magenta
def eval_op(batch, hparams, config_name):
  """Define an evaluation op.

  Builds the autoencoder graph in inference mode (``is_training=False``),
  a streaming mean-loss metric, and spectrogram summaries for the inputs,
  reconstructions, a latent-space interpolation and pitch-shifted decodings.

  Args:
    batch: Batch produced by NSynthReader. Read via the "pitch" key and
      either "audio" (when ``hparams.raw_audio``) or "spectrogram". The
      caller's ``batch`` is not modified.
    hparams: Hyperparameters. ``mag_only`` and ``raw_audio`` are read here;
      the object is also forwarded to the config's encode/decode, the loss
      and the summary helpers.
    config_name: Name of the config module under
      ``baseline.models.ae_configs``.

  Returns:
    A list of metric update ops. Running them accumulates the streaming
    "Loss" metric, whose value is emitted as a scalar summary.
  """
  # Request phase summaries only for spectrogram inputs that are not
  # magnitude-only.
  phase = not (hparams.mag_only or hparams.raw_audio)

  config = utils.get_module("baseline.models.ae_configs.%s" % config_name)
  if hparams.raw_audio:
    x = batch["audio"]
    # Add height and channel dims
    x = tf.expand_dims(tf.expand_dims(x, 1), -1)
  else:
    x = batch["spectrogram"]

  # Define the model
  with tf.name_scope("Model"):
    z = config.encode(x, hparams, is_training=False)
    xhat = config.decode(z, batch, hparams, is_training=False)

  # Expose tensors via graph collections for later interpolation tooling.
  tf.add_to_collection("x", x)
  tf.add_to_collection("pitch", batch["pitch"])
  tf.add_to_collection("z", z)
  tf.add_to_collection("xhat", xhat)

  total_loss = compute_mse_loss(x, xhat, hparams)

  # Define the metrics:
  names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
      "Loss": slim.metrics.mean(total_loss),
  })

  # Define the summaries. items() works on both Python 2 and 3, unlike
  # iteritems(), which does not exist on Python 3 dicts.
  for name, value in names_to_values.items():
    slim.summaries.add_scalar_summary(value, name, print_summary=True)

  # Interpolate
  with tf.name_scope("Interpolation"):
    # Re-decode with shared (reused) decoder variables.
    xhat = config.decode(z, batch, hparams, reuse=True, is_training=False)

    # Linear interpolation: pair each example with the next one in the batch
    # (wrapping around) and decode the midpoint of their latents.
    z_shift_one_example = tf.concat([z[1:], z[:1]], 0)
    z_linear_half = (z + z_shift_one_example) / 2.0
    xhat_linear_half = config.decode(z_linear_half, batch, hparams, reuse=True,
                                     is_training=False)

    # Pitch shift: decode the same latents conditioned on pitches shifted by
    # +/-2, clipped to [0, 127] (presumably the MIDI pitch range).
    pitch_plus_2 = tf.clip_by_value(batch["pitch"] + 2, 0, 127)
    pitch_minus_2 = tf.clip_by_value(batch["pitch"] - 2, 0, 127)

    # Use shallow copies so the caller's batch keeps its original "pitch"
    # tensor; previously the input dict was overwritten in place.
    batch_pitch_minus_2 = dict(batch)
    batch_pitch_minus_2["pitch"] = pitch_minus_2
    xhat_pitch_minus_2 = config.decode(z, batch_pitch_minus_2, hparams,
                                       reuse=True, is_training=False)

    batch_pitch_plus_2 = dict(batch)
    batch_pitch_plus_2["pitch"] = pitch_plus_2
    xhat_pitch_plus_2 = config.decode(z, batch_pitch_plus_2, hparams,
                                      reuse=True, is_training=False)

  utils.specgram_summaries(x, "Training Examples", hparams, phase=phase)
  utils.specgram_summaries(xhat, "Reconstructions", hparams, phase=phase)
  utils.specgram_summaries(
      x - xhat, "Difference", hparams, audio=False, phase=phase)
  utils.specgram_summaries(
      xhat_linear_half, "Linear Interp. 0.5", hparams, phase=phase)
  utils.specgram_summaries(xhat_pitch_plus_2, "Pitch +2", hparams, phase=phase)
  utils.specgram_summaries(xhat_pitch_minus_2, "Pitch -2", hparams, phase=phase)

  # Materialize as a list: on Python 3, dict.values() is a view, not a list.
  return list(names_to_updates.values())