Exemple #1
0
def main(unused_argv):
    """Evaluates "phero_model" checkpoints on a dataset and on gold sequences.

    The evaluation graph is built once and then used in one of two modes:

    * `FLAGS.ckpt_fp` is None: poll `FLAGS.train_dir` forever. Whenever the
      latest checkpoint path changes, restore it, gather every metric, write
      the averaged scalar summaries to `FLAGS.eval_dir` and save a copy of
      the restored variables there. This mode never returns.
    * otherwise: restore `FLAGS.ckpt_fp` once and print the mean of every
      metric to stdout.

    Args:
        unused_argv: Ignored.
    """
    if not tf.gfile.IsDirectory(FLAGS.eval_dir):
        tf.gfile.MakeDirs(FLAGS.eval_dir)

    cfg, _ = get_named_config(FLAGS.model_cfg, FLAGS.model_cfg_overrides)

    # Load data without augmentation. `_eval_all` below stops when a run
    # raises OutOfRangeError; with `repeat=False` that presumably happens
    # after one pass over the dataset -- confirm against `load_noteseqs`.
    with tf.name_scope("loader"):
        feat_dict = load_noteseqs(
            FLAGS.dataset_fp,
            cfg.eval_batch_size,
            cfg.eval_seq_len,
            max_discrete_times=cfg.data_max_discrete_times,
            max_discrete_velocities=cfg.data_max_discrete_velocities,
            augment_stretch_bounds=None,
            augment_transpose_bounds=None,
            randomize_chord_order=cfg.data_randomize_chord_order,
            repeat=False)

    # Build model
    with tf.variable_scope("phero_model"):
        model_dict = build_genie_model(feat_dict,
                                       cfg,
                                       cfg.eval_batch_size,
                                       cfg.eval_seq_len,
                                       is_training=False)
    # Only these variables (plus the global step) are restored and saved.
    genie_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                   scope="phero_model")

    # Build gold model: a second, batch-size-1 copy of the graph that reuses
    # the "phero_model" variables and is fed through placeholders.
    eval_gold = False
    if cfg.stp_emb_vq or cfg.stp_emb_iq:
        eval_gold = True
        with tf.variable_scope("phero_model", reuse=True):
            gold_feat_dict = {
                "midi_pitches": tf.placeholder(tf.int32, [1, None]),
                "velocities": tf.placeholder(tf.int32, [1, None]),
                "delta_times_int": tf.placeholder(tf.int32, [1, None])
            }
            gold_seq_maxlen = gold.gold_longest()
            gold_seq_varlens = tf.placeholder(tf.int32, [1])
            gold_buttons = tf.placeholder(tf.int32, [1, None])
            gold_model_dict = build_genie_model(gold_feat_dict,
                                                cfg,
                                                1,
                                                gold_seq_maxlen,
                                                is_training=False,
                                                seq_varlens=gold_seq_varlens)

        # VQ takes precedence when both quantizers are enabled.
        gold_encodings_key = ("stp_emb_vq_discrete"
                              if cfg.stp_emb_vq else "stp_emb_iq_discrete")
        gold_encodings = gold_model_dict[gold_encodings_key]
        gold_mask = tf.sequence_mask(gold_seq_varlens,
                                     maxlen=gold_seq_maxlen,
                                     dtype=tf.float32)
        gold_diff = tf.cast(gold_buttons, tf.float32) - tf.cast(
            gold_encodings, tf.float32)
        gold_diff_l2 = tf.square(gold_diff)
        gold_diff_l1 = tf.abs(gold_diff)

        def weighted_avg(t, m):
            """Returns the mean of `t` over the positions weighted by `m`."""
            return tf.reduce_sum(t * m) / tf.reduce_sum(m)

        # Average the errors over the unmasked positions only.
        gold_diff_l2 = weighted_avg(gold_diff_l2, gold_mask)
        gold_diff_l1 = weighted_avg(gold_diff_l1, gold_mask)

        gold_diff_l2_placeholder = tf.placeholder(tf.float32, [None])
        gold_diff_l1_placeholder = tf.placeholder(tf.float32, [None])

    # Maps summary name -> tensor evaluated once per dataset batch.
    summary_name_to_batch_tensor = {}

    # Summarize quantized step embeddings
    if cfg.stp_emb_vq:
        summary_name_to_batch_tensor["codebook_perplexity"] = model_dict[
            "stp_emb_vq_codebook_ppl"]
        summary_name_to_batch_tensor["loss_vqvae"] = model_dict[
            "stp_emb_vq_loss"]

    # Summarize integer-quantized step embeddings
    if cfg.stp_emb_iq:
        summary_name_to_batch_tensor["discrete_perplexity"] = model_dict[
            "stp_emb_iq_discrete_ppl"]
        summary_name_to_batch_tensor["iq_valid_p"] = model_dict[
            "stp_emb_iq_valid_p"]
        summary_name_to_batch_tensor["loss_iq_range"] = model_dict[
            "stp_emb_iq_range_penalty"]
        summary_name_to_batch_tensor["loss_iq_contour"] = model_dict[
            "stp_emb_iq_contour_penalty"]
        summary_name_to_batch_tensor["loss_iq_deviate"] = model_dict[
            "stp_emb_iq_deviate_penalty"]

    if cfg.stp_emb_vq or cfg.stp_emb_iq:
        summary_name_to_batch_tensor["contour_violation"] = model_dict[
            "contour_violation"]
        summary_name_to_batch_tensor["deviate_violation"] = model_dict[
            "deviate_violation"]

    # Summarize VAE sequence embeddings
    if cfg.seq_emb_vae:
        summary_name_to_batch_tensor["loss_kl"] = model_dict["seq_emb_vae_kl"]

    # Reconstruction loss
    summary_name_to_batch_tensor["loss_recons"] = model_dict["dec_recons_loss"]
    summary_name_to_batch_tensor["ppl_recons"] = tf.exp(
        model_dict["dec_recons_loss"])
    if cfg.dec_pred_velocity:
        summary_name_to_batch_tensor["loss_recons_velocity"] = model_dict[
            "dec_recons_velocity_loss"]
        summary_name_to_batch_tensor["ppl_recons_velocity"] = tf.exp(
            model_dict["dec_recons_velocity_loss"])

    # Create dataset summaries. Per-batch values are collected in Python and
    # fed back through these placeholders, so each written scalar is the mean
    # over everything gathered for a checkpoint rather than over one batch.
    summaries = []
    summary_name_to_placeholder = {}
    for name in summary_name_to_batch_tensor:
        placeholder = tf.placeholder(tf.float32, [None])
        summary_name_to_placeholder[name] = placeholder
        summaries.append(tf.summary.scalar(name, tf.reduce_mean(placeholder)))
    if eval_gold:
        summary_name_to_placeholder["gold_diff_l2"] = gold_diff_l2_placeholder
        summaries.append(
            tf.summary.scalar("gold_diff_l2",
                              tf.reduce_mean(gold_diff_l2_placeholder)))
        summary_name_to_placeholder["gold_diff_l1"] = gold_diff_l1_placeholder
        summaries.append(
            tf.summary.scalar("gold_diff_l1",
                              tf.reduce_mean(gold_diff_l1_placeholder)))

    summaries = tf.summary.merge(summaries)
    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir)

    # Create saver
    step = tf.train.get_or_create_global_step()
    saver = tf.train.Saver(genie_vars + [step], max_to_keep=None)

    def _eval_all(sess):
        """Gathers all metrics for a ckpt.

        Args:
            sess: Session in which the checkpoint has already been restored.

        Returns:
            A defaultdict mapping each metric name to the list of values
            collected for it (one per gold sequence or per dataset batch).
        """
        # Named `metrics` so it does not shadow the merged `summaries` op.
        metrics = collections.defaultdict(list)

        if eval_gold:
            # NOTE(review): the meaning of [-6, 6] is defined by
            # `gold.gold_iterator` -- presumably a transposition range.
            for midi_notes, buttons, seq_varlen in gold.gold_iterator([-6, 6]):
                # NOTE(review): the "velocities" placeholder is never fed;
                # this only works if the fetched tensors do not depend on it.
                gold_diff_l1_seq, gold_diff_l2_seq = sess.run(
                    [gold_diff_l1, gold_diff_l2], {
                        gold_feat_dict["midi_pitches"]:
                        midi_notes,
                        # Every gold note gets the same constant delta time.
                        gold_feat_dict["delta_times_int"]:
                        np.ones_like(midi_notes) * 8,
                        gold_seq_varlens: [seq_varlen],
                        gold_buttons:
                        buttons
                    })
                metrics["gold_diff_l1"].append(gold_diff_l1_seq)
                metrics["gold_diff_l2"].append(gold_diff_l2_seq)

        # Run the dataset graph until the input pipeline is exhausted.
        while True:
            try:
                batches = sess.run(summary_name_to_batch_tensor)
            except tf.errors.OutOfRangeError:
                break

            for name, scalar in batches.items():
                metrics[name].append(scalar)

        return metrics

    # Eval
    if FLAGS.ckpt_fp is None:
        ckpt_fp = None
        while True:
            latest_ckpt_fp = tf.train.latest_checkpoint(FLAGS.train_dir)

            # Both paths are None until a first checkpoint exists, so nothing
            # is restored before then.
            if latest_ckpt_fp != ckpt_fp:
                print("Eval: {}".format(latest_ckpt_fp))

                with tf.Session() as sess:
                    sess.run(tf.local_variables_initializer())
                    saver.restore(sess, latest_ckpt_fp)

                    ckpt_summaries = _eval_all(sess)
                    ckpt_summaries, ckpt_step = sess.run(
                        [summaries, step],
                        feed_dict={
                            summary_name_to_placeholder[n]: v
                            for n, v in ckpt_summaries.items()
                        })
                    summary_writer.add_summary(ckpt_summaries, ckpt_step)
                    # Flush now so the result is on disk as soon as "Done" is
                    # printed instead of waiting for the writer's timed flush.
                    summary_writer.flush()

                    saver.save(sess,
                               os.path.join(FLAGS.eval_dir, "ckpt"),
                               global_step=ckpt_step)

                print("Done")
                ckpt_fp = latest_ckpt_fp

            time.sleep(1)
    else:
        with tf.Session() as sess:
            sess.run(tf.local_variables_initializer())
            saver.restore(sess, FLAGS.ckpt_fp)

            ckpt_summaries = _eval_all(sess)
            ckpt_step = sess.run(step)

            print("-" * 80)
            print("Ckpt: {}".format(FLAGS.ckpt_fp))
            print("Step: {}".format(ckpt_step))
            # Metric names are unique, so sorting the keys gives the same
            # order as sorting the items by key.
            for name in sorted(ckpt_summaries):
                print("{}: {}".format(name, np.mean(ckpt_summaries[name])))
Exemple #2
0
def main(unused_argv):
  """Builds the "phero_model" training graph and trains it until stopped.

  Writes the config summary to `cfg.txt` in `FLAGS.train_dir`, registers
  image and scalar summaries, sums the enabled loss terms into a single
  objective and minimizes it with Adam inside a `MonitoredTrainingSession`
  that checkpoints to `FLAGS.train_dir` every 600 seconds.

  Args:
    unused_argv: Ignored.
  """
  if not tf.gfile.IsDirectory(FLAGS.train_dir):
    tf.gfile.MakeDirs(FLAGS.train_dir)

  cfg, cfg_summary = get_named_config(FLAGS.model_cfg,
                                      FLAGS.model_cfg_overrides)
  # Record the resolved configuration next to the checkpoints.
  with tf.gfile.Open(os.path.join(FLAGS.train_dir, "cfg.txt"), "w") as f:
    f.write(cfg_summary)

  # Load data
  with tf.name_scope("loader"):
    feat_dict = load_noteseqs(
        FLAGS.dataset_fp,
        cfg.train_batch_size,
        cfg.train_seq_len,
        max_discrete_times=cfg.data_max_discrete_times,
        max_discrete_velocities=cfg.data_max_discrete_velocities,
        augment_stretch_bounds=cfg.train_augment_stretch_bounds,
        augment_transpose_bounds=cfg.train_augment_transpose_bounds,
        randomize_chord_order=cfg.data_randomize_chord_order,
        repeat=True)

  # Summarize data
  tf.summary.image(
      "piano_roll",
      util.discrete_to_piano_roll(util.demidify(feat_dict["midi_pitches"]), 88))

  # Build model
  with tf.variable_scope("phero_model"):
    model_dict = build_genie_model(
        feat_dict,
        cfg,
        cfg.train_batch_size,
        cfg.train_seq_len,
        is_training=True)

  # Summarize quantized step embeddings
  if cfg.stp_emb_vq:
    tf.summary.scalar("codebook_perplexity",
                      model_dict["stp_emb_vq_codebook_ppl"])
    # Dilation is the integer ratio of 88 to the codebook size, at least 1.
    tf.summary.image(
        "genie",
        util.discrete_to_piano_roll(
            model_dict["stp_emb_vq_discrete"],
            cfg.stp_emb_vq_codebook_size,
            dilation=max(1, 88 // cfg.stp_emb_vq_codebook_size)))
    tf.summary.scalar("loss_vqvae", model_dict["stp_emb_vq_loss"])

  # Summarize integer-quantized step embeddings
  if cfg.stp_emb_iq:
    tf.summary.scalar("discrete_perplexity",
                      model_dict["stp_emb_iq_discrete_ppl"])
    tf.summary.scalar("iq_valid_p", model_dict["stp_emb_iq_valid_p"])
    tf.summary.image(
        "genie",
        util.discrete_to_piano_roll(
            model_dict["stp_emb_iq_discrete"],
            cfg.stp_emb_iq_nbins,
            dilation=max(1, 88 // cfg.stp_emb_iq_nbins)))
    tf.summary.scalar("loss_iq_range", model_dict["stp_emb_iq_range_penalty"])
    tf.summary.scalar("loss_iq_contour",
                      model_dict["stp_emb_iq_contour_penalty"])
    tf.summary.scalar("loss_iq_deviate",
                      model_dict["stp_emb_iq_deviate_penalty"])

  if cfg.stp_emb_vq or cfg.stp_emb_iq:
    tf.summary.scalar("contour_violation", model_dict["contour_violation"])
    tf.summary.scalar("deviate_violation", model_dict["deviate_violation"])

  # Summarize VAE sequence embeddings
  if cfg.seq_emb_vae:
    tf.summary.scalar("loss_kl", model_dict["seq_emb_vae_kl"])

  # Summarize output
  tf.summary.image(
      "decoder_scores",
      util.discrete_to_piano_roll(model_dict["dec_recons_scores"], 88))
  tf.summary.image(
      "decoder_preds",
      util.discrete_to_piano_roll(model_dict["dec_recons_preds"], 88))
  if cfg.dec_pred_velocity:
    tf.summary.scalar("loss_recons_velocity",
                      model_dict["dec_recons_velocity_loss"])
    tf.summary.scalar("ppl_recons_velocity",
                      tf.exp(model_dict["dec_recons_velocity_loss"]))

  # Reconstruction loss
  tf.summary.scalar("loss_recons", model_dict["dec_recons_loss"])
  tf.summary.scalar("ppl_recons", tf.exp(model_dict["dec_recons_loss"]))

  # Build hybrid loss. Each auxiliary term is added only when its feature is
  # enabled and its weight is positive; the velocity reconstruction loss is
  # added unweighted whenever velocity prediction is enabled.
  loss = model_dict["dec_recons_loss"]
  if cfg.stp_emb_vq and cfg.train_loss_vq_err_scalar > 0:
    loss += (cfg.train_loss_vq_err_scalar * model_dict["stp_emb_vq_loss"])
  if cfg.stp_emb_iq and cfg.train_loss_iq_range_scalar > 0:
    loss += (
        cfg.train_loss_iq_range_scalar * model_dict["stp_emb_iq_range_penalty"])
  if cfg.stp_emb_iq and cfg.train_loss_iq_contour_scalar > 0:
    loss += (
        cfg.train_loss_iq_contour_scalar *
        model_dict["stp_emb_iq_contour_penalty"])
  if cfg.stp_emb_iq and cfg.train_loss_iq_deviate_scalar > 0:
    loss += (
        cfg.train_loss_iq_deviate_scalar *
        model_dict["stp_emb_iq_deviate_penalty"])
  if cfg.seq_emb_vae and cfg.train_loss_vae_kl_scalar > 0:
    loss += (cfg.train_loss_vae_kl_scalar * model_dict["seq_emb_vae_kl"])
  if cfg.dec_pred_velocity:
    loss += model_dict["dec_recons_velocity_loss"]
  tf.summary.scalar("loss", loss)

  # Construct optimizer
  opt = tf.train.AdamOptimizer(learning_rate=cfg.train_lr)
  train_op = opt.minimize(
      loss, global_step=tf.train.get_or_create_global_step())

  # Train
  with tf.train.MonitoredTrainingSession(
      checkpoint_dir=FLAGS.train_dir,
      save_checkpoint_secs=600,
      save_summaries_secs=FLAGS.summary_every_nsecs) as sess:
    # Loop on should_stop() instead of `while True`: once a stop has been
    # requested, calling run() again on a monitored session raises
    # RuntimeError, whereas this form leaves the loop cleanly.
    while not sess.should_stop():
      sess.run(train_op)
Exemple #3
0
def main(unused_argv):
    """Builds the "phero_model" training graph and trains it until stopped.

    Writes the config summary to `cfg.txt` in `FLAGS.train_dir`, registers
    image and scalar summaries, sums the enabled loss terms into a single
    objective and minimizes it with Adam inside a `MonitoredTrainingSession`
    that checkpoints to `FLAGS.train_dir` every 600 seconds.

    Args:
        unused_argv: Ignored.
    """
    if not tf.gfile.IsDirectory(FLAGS.train_dir):
        tf.gfile.MakeDirs(FLAGS.train_dir)

    cfg, cfg_summary = get_named_config(FLAGS.model_cfg,
                                        FLAGS.model_cfg_overrides)
    # Record the resolved configuration next to the checkpoints.
    with tf.gfile.Open(os.path.join(FLAGS.train_dir, "cfg.txt"), "w") as f:
        f.write(cfg_summary)

    # Load data
    with tf.name_scope("loader"):
        feat_dict = load_noteseqs(
            FLAGS.dataset_fp,
            cfg.train_batch_size,
            cfg.train_seq_len,
            max_discrete_times=cfg.data_max_discrete_times,
            max_discrete_velocities=cfg.data_max_discrete_velocities,
            augment_stretch_bounds=cfg.train_augment_stretch_bounds,
            augment_transpose_bounds=cfg.train_augment_transpose_bounds,
            randomize_chord_order=cfg.data_randomize_chord_order,
            repeat=True)

    # Summarize data
    tf.summary.image(
        "piano_roll",
        util.discrete_to_piano_roll(util.demidify(feat_dict["midi_pitches"]),
                                    88))

    # Build model
    with tf.variable_scope("phero_model"):
        model_dict = build_genie_model(feat_dict,
                                       cfg,
                                       cfg.train_batch_size,
                                       cfg.train_seq_len,
                                       is_training=True)

    # Summarize quantized step embeddings
    if cfg.stp_emb_vq:
        tf.summary.scalar("codebook_perplexity",
                          model_dict["stp_emb_vq_codebook_ppl"])
        # Dilation is the integer ratio of 88 to the codebook size, at
        # least 1.
        tf.summary.image(
            "genie",
            util.discrete_to_piano_roll(
                model_dict["stp_emb_vq_discrete"],
                cfg.stp_emb_vq_codebook_size,
                dilation=max(1, 88 // cfg.stp_emb_vq_codebook_size)))
        tf.summary.scalar("loss_vqvae", model_dict["stp_emb_vq_loss"])

    # Summarize integer-quantized step embeddings
    if cfg.stp_emb_iq:
        tf.summary.scalar("discrete_perplexity",
                          model_dict["stp_emb_iq_discrete_ppl"])
        tf.summary.scalar("iq_valid_p", model_dict["stp_emb_iq_valid_p"])
        tf.summary.image(
            "genie",
            util.discrete_to_piano_roll(model_dict["stp_emb_iq_discrete"],
                                        cfg.stp_emb_iq_nbins,
                                        dilation=max(
                                            1, 88 // cfg.stp_emb_iq_nbins)))
        tf.summary.scalar("loss_iq_range",
                          model_dict["stp_emb_iq_range_penalty"])
        tf.summary.scalar("loss_iq_contour",
                          model_dict["stp_emb_iq_contour_penalty"])
        tf.summary.scalar("loss_iq_deviate",
                          model_dict["stp_emb_iq_deviate_penalty"])

    if cfg.stp_emb_vq or cfg.stp_emb_iq:
        tf.summary.scalar("contour_violation", model_dict["contour_violation"])
        tf.summary.scalar("deviate_violation", model_dict["deviate_violation"])

    # Summarize VAE sequence embeddings
    if cfg.seq_emb_vae:
        tf.summary.scalar("loss_kl", model_dict["seq_emb_vae_kl"])

    # Summarize output
    tf.summary.image(
        "decoder_scores",
        util.discrete_to_piano_roll(model_dict["dec_recons_scores"], 88))
    tf.summary.image(
        "decoder_preds",
        util.discrete_to_piano_roll(model_dict["dec_recons_preds"], 88))
    if cfg.dec_pred_velocity:
        tf.summary.scalar("loss_recons_velocity",
                          model_dict["dec_recons_velocity_loss"])
        tf.summary.scalar("ppl_recons_velocity",
                          tf.exp(model_dict["dec_recons_velocity_loss"]))

    # Reconstruction loss
    tf.summary.scalar("loss_recons", model_dict["dec_recons_loss"])
    tf.summary.scalar("ppl_recons", tf.exp(model_dict["dec_recons_loss"]))

    # Build hybrid loss. Each auxiliary term is added only when its feature
    # is enabled and its weight is positive; the velocity reconstruction loss
    # is added unweighted whenever velocity prediction is enabled.
    loss = model_dict["dec_recons_loss"]
    if cfg.stp_emb_vq and cfg.train_loss_vq_err_scalar > 0:
        loss += (cfg.train_loss_vq_err_scalar * model_dict["stp_emb_vq_loss"])
    if cfg.stp_emb_iq and cfg.train_loss_iq_range_scalar > 0:
        loss += (cfg.train_loss_iq_range_scalar *
                 model_dict["stp_emb_iq_range_penalty"])
    if cfg.stp_emb_iq and cfg.train_loss_iq_contour_scalar > 0:
        loss += (cfg.train_loss_iq_contour_scalar *
                 model_dict["stp_emb_iq_contour_penalty"])
    if cfg.stp_emb_iq and cfg.train_loss_iq_deviate_scalar > 0:
        loss += (cfg.train_loss_iq_deviate_scalar *
                 model_dict["stp_emb_iq_deviate_penalty"])
    if cfg.seq_emb_vae and cfg.train_loss_vae_kl_scalar > 0:
        loss += (cfg.train_loss_vae_kl_scalar * model_dict["seq_emb_vae_kl"])
    if cfg.dec_pred_velocity:
        loss += model_dict["dec_recons_velocity_loss"]
    tf.summary.scalar("loss", loss)

    # Construct optimizer
    opt = tf.train.AdamOptimizer(learning_rate=cfg.train_lr)
    train_op = opt.minimize(loss,
                            global_step=tf.train.get_or_create_global_step())

    # Train
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=FLAGS.train_dir,
            save_checkpoint_secs=600,
            save_summaries_secs=FLAGS.summary_every_nsecs) as sess:
        # Loop on should_stop() instead of `while True`: once a stop has been
        # requested, calling run() again on a monitored session raises
        # RuntimeError, whereas this form leaves the loop cleanly.
        while not sess.should_stop():
            sess.run(train_op)
Exemple #4
0
def main(unused_argv):
  """Evaluates "phero_model" checkpoints on a dataset and on gold sequences.

  The evaluation graph is built once and then used in one of two modes:

  * `FLAGS.ckpt_fp` is None: poll `FLAGS.train_dir` forever. Whenever the
    latest checkpoint path changes, restore it, gather every metric, write
    the averaged scalar summaries to `FLAGS.eval_dir` and save a copy of the
    restored variables there. This mode never returns.
  * otherwise: restore `FLAGS.ckpt_fp` once and print the mean of every
    metric to stdout.

  Args:
    unused_argv: Ignored.
  """
  if not tf.gfile.IsDirectory(FLAGS.eval_dir):
    tf.gfile.MakeDirs(FLAGS.eval_dir)

  cfg, _ = get_named_config(FLAGS.model_cfg, FLAGS.model_cfg_overrides)

  # Load data without augmentation. `_eval_all` below stops when a run raises
  # OutOfRangeError; with `repeat=False` that presumably happens after one
  # pass over the dataset -- confirm against `load_noteseqs`.
  with tf.name_scope("loader"):
    feat_dict = load_noteseqs(
        FLAGS.dataset_fp,
        cfg.eval_batch_size,
        cfg.eval_seq_len,
        max_discrete_times=cfg.data_max_discrete_times,
        max_discrete_velocities=cfg.data_max_discrete_velocities,
        augment_stretch_bounds=None,
        augment_transpose_bounds=None,
        randomize_chord_order=cfg.data_randomize_chord_order,
        repeat=False)

  # Build model
  with tf.variable_scope("phero_model"):
    model_dict = build_genie_model(
        feat_dict,
        cfg,
        cfg.eval_batch_size,
        cfg.eval_seq_len,
        is_training=False)
  # Only these variables (plus the global step) are restored and saved.
  genie_vars = tf.get_collection(
      tf.GraphKeys.GLOBAL_VARIABLES, scope="phero_model")

  # Build gold model: a second, batch-size-1 copy of the graph that reuses
  # the "phero_model" variables and is fed through placeholders.
  eval_gold = False
  if cfg.stp_emb_vq or cfg.stp_emb_iq:
    eval_gold = True
    with tf.variable_scope("phero_model", reuse=True):
      gold_feat_dict = {
          "midi_pitches": tf.placeholder(tf.int32, [1, None]),
          "velocities": tf.placeholder(tf.int32, [1, None]),
          "delta_times_int": tf.placeholder(tf.int32, [1, None])
      }
      gold_seq_maxlen = gold.gold_longest()
      gold_seq_varlens = tf.placeholder(tf.int32, [1])
      gold_buttons = tf.placeholder(tf.int32, [1, None])
      gold_model_dict = build_genie_model(
          gold_feat_dict,
          cfg,
          1,
          gold_seq_maxlen,
          is_training=False,
          seq_varlens=gold_seq_varlens)

    # VQ takes precedence when both quantizers are enabled.
    gold_encodings_key = (
        "stp_emb_vq_discrete" if cfg.stp_emb_vq else "stp_emb_iq_discrete")
    gold_encodings = gold_model_dict[gold_encodings_key]
    gold_mask = tf.sequence_mask(
        gold_seq_varlens, maxlen=gold_seq_maxlen, dtype=tf.float32)
    gold_diff = tf.cast(gold_buttons, tf.float32) - tf.cast(
        gold_encodings, tf.float32)
    gold_diff_l2 = tf.square(gold_diff)
    gold_diff_l1 = tf.abs(gold_diff)

    def weighted_avg(t, m):
      """Returns the mean of `t` over the positions weighted by `m`."""
      return tf.reduce_sum(t * m) / tf.reduce_sum(m)

    # Average the errors over the unmasked positions only.
    gold_diff_l2 = weighted_avg(gold_diff_l2, gold_mask)
    gold_diff_l1 = weighted_avg(gold_diff_l1, gold_mask)

    gold_diff_l2_placeholder = tf.placeholder(tf.float32, [None])
    gold_diff_l1_placeholder = tf.placeholder(tf.float32, [None])

  # Maps summary name -> tensor evaluated once per dataset batch.
  summary_name_to_batch_tensor = {}

  # Summarize quantized step embeddings
  if cfg.stp_emb_vq:
    summary_name_to_batch_tensor["codebook_perplexity"] = model_dict[
        "stp_emb_vq_codebook_ppl"]
    summary_name_to_batch_tensor["loss_vqvae"] = model_dict["stp_emb_vq_loss"]

  # Summarize integer-quantized step embeddings
  if cfg.stp_emb_iq:
    summary_name_to_batch_tensor["discrete_perplexity"] = model_dict[
        "stp_emb_iq_discrete_ppl"]
    summary_name_to_batch_tensor["iq_valid_p"] = model_dict[
        "stp_emb_iq_valid_p"]
    summary_name_to_batch_tensor["loss_iq_range"] = model_dict[
        "stp_emb_iq_range_penalty"]
    summary_name_to_batch_tensor["loss_iq_contour"] = model_dict[
        "stp_emb_iq_contour_penalty"]
    summary_name_to_batch_tensor["loss_iq_deviate"] = model_dict[
        "stp_emb_iq_deviate_penalty"]

  if cfg.stp_emb_vq or cfg.stp_emb_iq:
    summary_name_to_batch_tensor["contour_violation"] = model_dict[
        "contour_violation"]
    summary_name_to_batch_tensor["deviate_violation"] = model_dict[
        "deviate_violation"]

  # Summarize VAE sequence embeddings
  if cfg.seq_emb_vae:
    summary_name_to_batch_tensor["loss_kl"] = model_dict["seq_emb_vae_kl"]

  # Reconstruction loss
  summary_name_to_batch_tensor["loss_recons"] = model_dict["dec_recons_loss"]
  summary_name_to_batch_tensor["ppl_recons"] = tf.exp(
      model_dict["dec_recons_loss"])
  if cfg.dec_pred_velocity:
    summary_name_to_batch_tensor["loss_recons_velocity"] = model_dict[
        "dec_recons_velocity_loss"]
    summary_name_to_batch_tensor["ppl_recons_velocity"] = tf.exp(
        model_dict["dec_recons_velocity_loss"])

  # Create dataset summaries. Per-batch values are collected in Python and
  # fed back through these placeholders, so each written scalar is the mean
  # over everything gathered for a checkpoint rather than over one batch.
  summaries = []
  summary_name_to_placeholder = {}
  for name in summary_name_to_batch_tensor:
    placeholder = tf.placeholder(tf.float32, [None])
    summary_name_to_placeholder[name] = placeholder
    summaries.append(tf.summary.scalar(name, tf.reduce_mean(placeholder)))
  if eval_gold:
    summary_name_to_placeholder["gold_diff_l2"] = gold_diff_l2_placeholder
    summaries.append(
        tf.summary.scalar("gold_diff_l2",
                          tf.reduce_mean(gold_diff_l2_placeholder)))
    summary_name_to_placeholder["gold_diff_l1"] = gold_diff_l1_placeholder
    summaries.append(
        tf.summary.scalar("gold_diff_l1",
                          tf.reduce_mean(gold_diff_l1_placeholder)))

  summaries = tf.summary.merge(summaries)
  summary_writer = tf.summary.FileWriter(FLAGS.eval_dir)

  # Create saver
  step = tf.train.get_or_create_global_step()
  saver = tf.train.Saver(genie_vars + [step], max_to_keep=None)

  def _eval_all(sess):
    """Gathers all metrics for a ckpt.

    Args:
      sess: Session in which the checkpoint has already been restored.

    Returns:
      A defaultdict mapping each metric name to the list of values collected
      for it (one per gold sequence or per dataset batch).
    """
    # Named `metrics` so it does not shadow the merged `summaries` op.
    metrics = collections.defaultdict(list)

    if eval_gold:
      # NOTE(review): the meaning of [-6, 6] is defined by
      # `gold.gold_iterator` -- presumably a transposition range.
      for midi_notes, buttons, seq_varlen in gold.gold_iterator([-6, 6]):
        # NOTE(review): the "velocities" placeholder is never fed; this only
        # works if the fetched tensors do not depend on it.
        gold_diff_l1_seq, gold_diff_l2_seq = sess.run(
            [gold_diff_l1, gold_diff_l2], {
                gold_feat_dict["midi_pitches"]:
                    midi_notes,
                # Every gold note gets the same constant delta time.
                gold_feat_dict["delta_times_int"]:
                    np.ones_like(midi_notes) * 8,
                gold_seq_varlens: [seq_varlen],
                gold_buttons: buttons
            })
        metrics["gold_diff_l1"].append(gold_diff_l1_seq)
        metrics["gold_diff_l2"].append(gold_diff_l2_seq)

    # Run the dataset graph until the input pipeline is exhausted.
    while True:
      try:
        batches = sess.run(summary_name_to_batch_tensor)
      except tf.errors.OutOfRangeError:
        break

      for name, scalar in batches.items():
        metrics[name].append(scalar)

    return metrics

  # Eval
  if FLAGS.ckpt_fp is None:
    ckpt_fp = None
    while True:
      latest_ckpt_fp = tf.train.latest_checkpoint(FLAGS.train_dir)

      # Both paths are None until a first checkpoint exists, so nothing is
      # restored before then.
      if latest_ckpt_fp != ckpt_fp:
        print("Eval: {}".format(latest_ckpt_fp))

        with tf.Session() as sess:
          sess.run(tf.local_variables_initializer())
          saver.restore(sess, latest_ckpt_fp)

          ckpt_summaries = _eval_all(sess)
          ckpt_summaries, ckpt_step = sess.run(
              [summaries, step],
              feed_dict={
                  summary_name_to_placeholder[n]: v
                  for n, v in ckpt_summaries.items()
              })
          summary_writer.add_summary(ckpt_summaries, ckpt_step)
          # Flush now so the result is on disk as soon as "Done" is printed
          # instead of waiting for the writer's timed flush.
          summary_writer.flush()

          saver.save(
              sess, os.path.join(FLAGS.eval_dir, "ckpt"), global_step=ckpt_step)

        print("Done")
        ckpt_fp = latest_ckpt_fp

      time.sleep(1)
  else:
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      saver.restore(sess, FLAGS.ckpt_fp)

      ckpt_summaries = _eval_all(sess)
      ckpt_step = sess.run(step)

      print("-" * 80)
      print("Ckpt: {}".format(FLAGS.ckpt_fp))
      print("Step: {}".format(ckpt_step))
      # Metric names are unique, so sorting the keys gives the same order as
      # sorting the items by key.
      for name in sorted(ckpt_summaries):
        print("{}: {}".format(name, np.mean(ckpt_summaries[name])))