Esempio n. 1
0
def main(unused_argv):
  """Runs sampling from FLAGS.checkpoint and writes all artifacts to disk.

  Builds a Generator over instantiate_model(FLAGS.checkpoint) with
  FLAGS.strategy, calls run_generation with FLAGS.gen_batch_size and
  FLAGS.piece_length, and writes results into a new directory
  "<FLAGS.generation_output_dir>/<label>". The label is built from a
  timestamp, the strategy, the architecture, FLAGS.temperature,
  FLAGS.piece_length and generator.time_taken. Files written there:
    * intermediate_steps.npz -- output of generator.logger.dump;
    * midi/                  -- the generated pieces, via save_midis;
    * generated_result.npy   -- generator.pianorolls;
    * context.npy            -- the first logged context pianorolls found in
                                intermediate_steps.npz, if any. For
                                strategies whose name contains "harm", the
                                context is also saved as "<label>_prime" MIDI.

  Args:
    unused_argv: unused; all configuration comes from FLAGS.

  Raises:
    ValueError: if FLAGS.checkpoint is None or empty.
  """
  # `not` rejects both None and the empty string.
  if not FLAGS.checkpoint:
    raise ValueError(
        "Need to provide a path to checkpoint directory.")
  wmodel = instantiate_model(FLAGS.checkpoint)
  generator = Generator(wmodel, FLAGS.strategy)
  midi_outs = generator.run_generation(
      gen_batch_size=FLAGS.gen_batch_size, piece_length=FLAGS.piece_length)

  # Creates a folder for storing the process of the sampling.
  label = "sample_%s_%s_%s_T%g_l%i_%.2fmin" % (lib_util.timestamp(),
                                               FLAGS.strategy,
                                               generator.hparams.architecture,
                                               FLAGS.temperature,
                                               FLAGS.piece_length,
                                               generator.time_taken)
  basepath = os.path.join(FLAGS.generation_output_dir, label)
  print("basepath:", basepath)
  tf.gfile.MakeDirs(basepath)

  # Stores all the (intermediate) steps.
  intermediate_steps_path = os.path.join(basepath, "intermediate_steps.npz")
  with lib_util.timing("writing_out_sample_npz"):
    print("Writing intermediate steps to", intermediate_steps_path)
    generator.logger.dump(intermediate_steps_path)

  # Saves the generated pieces as midi files.
  midi_path = os.path.join(basepath, "midi")
  tf.gfile.MakeDirs(midi_path)
  print("Made directory %s" % midi_path)
  save_midis(midi_outs, midi_path, label)

  result_npy_save_path = os.path.join(basepath, "generated_result.npy")
  print("Writing final result to", result_npy_save_path)
  # np.save emits bytes, so the file is opened in binary mode.
  with tf.gfile.Open(result_npy_save_path, "wb") as p:
    np.save(p, generator.pianorolls)

  # Save the context as npy, and the prime as midi in harmonization mode.
  # Scans the stored npz for the logged context step.
  print("Reading to check", intermediate_steps_path)
  # Binary mode is required: a text-mode GFile returns decoded str on read,
  # which np.load cannot parse. allow_pickle permits object arrays in case
  # the logger dump contains any (TODO confirm); the file was written by
  # this function just above, so it is not external input.
  with tf.gfile.Open(intermediate_steps_path, "rb") as p:
    steps = np.load(p, allow_pickle=True)
    for key in steps.keys():
      if not re.match(r"0_root/.*?_strategy/.*?_context/0_pianorolls", key):
        continue
      context_rolls = steps[key]
      context_fpath = os.path.join(basepath, "context.npy")
      print("Writing context to", context_fpath)
      with lib_util.atomic_file(context_fpath) as context_p:
        np.save(context_p, context_rolls)
      if "harm" in FLAGS.strategy:
        # Only synthesize the one prime if in Midi-melody-prime mode.
        primes = context_rolls
        if "Melody" in FLAGS.strategy:
          primes = [context_rolls[0]]
        prime_midi_outs = get_midi_from_pianorolls(primes, generator.decoder)
        save_midis(prime_midi_outs, midi_path, label + "_prime")
      # Only the first matching context entry is used.
      break
  print("Done")
Esempio n. 2
0
def evaluate_paths(paths, evaluator, unused_hparams, eval_logdir):
    """Evaluates negative loglikelihood of pianorolls from given paths.

    Each path is loaded with get_path_pianorolls and scored with
    lib_evaluation.evaluate. The returned mapping is saved with
    np.savez_compressed to
    "<eval_logdir>/<file stem>__eval_samples_<timestamp>_<FLAGS.unit>_
    ensemble<FLAGS.ensemble_size>_chrono<FLAGS.chronological>.npz".
    A fresh timestamp is taken for every path.

    Args:
      paths: iterable of paths to pianoroll files.
      evaluator: passed through unchanged to lib_evaluation.evaluate.
      unused_hparams: not used by this function.
      eval_logdir: directory that receives one .npz per path.
    """
    for sample_path in paths:
        run_desc = 'eval_samples_%s_%s_ensemble%s_chrono%s' % (
            lib_util.timestamp(), FLAGS.unit, FLAGS.ensemble_size,
            FLAGS.chronological)
        stem = os.path.splitext(os.path.basename(sample_path))[0]
        out_fpath = os.path.join(eval_logdir, '%s__%s.npz' % (stem, run_desc))

        stats = lib_evaluation.evaluate(evaluator,
                                        get_path_pianorolls(sample_path))
        tf.logging.info('Writing evaluation statistics to %s', out_fpath)
        with lib_util.atomic_file(out_fpath) as out_file:
            np.savez_compressed(out_file, **stats)
Esempio n. 3
0
def evaluate_fold(fold, evaluator, hparams, eval_logdir, checkpoint_dir):
  """Writes to file the neg. loglikelihood of given fold (train/valid/test).

  Loads the fold's pianorolls with get_fold_pianorolls and scores them with
  lib_evaluation.evaluate. The returned mapping is saved with
  np.savez_compressed to
  "<eval_logdir>/<checkpoint name>__eval_<timestamp>_<fold><fold_index>_
  <FLAGS.unit>_ensemble<FLAGS.ensemble_size>_chrono<FLAGS.chronological>.npz".
  The fold index part is empty when FLAGS.fold_index is None.

  Args:
    fold: name of the fold to evaluate; also used in the output filename.
    evaluator: passed through unchanged to lib_evaluation.evaluate.
    hparams: passed to get_fold_pianorolls.
    eval_logdir: directory that receives the .npz file.
    checkpoint_dir: its last path component prefixes the output filename.
  """
  fold_index = '' if FLAGS.fold_index is None else FLAGS.fold_index
  eval_run_name = 'eval_%s_%s%s_%s_ensemble%s_chrono%s' % (
      lib_util.timestamp(), fold, fold_index, FLAGS.unit,
      FLAGS.ensemble_size, FLAGS.chronological)
  # Strip trailing separators first: basename("ckpt/dir/") is "", which would
  # drop the checkpoint name from the log filename.
  checkpoint_name = os.path.basename(checkpoint_dir.rstrip('/' + os.sep))
  log_fname = '%s__%s.npz' % (checkpoint_name, eval_run_name)
  log_fpath = os.path.join(eval_logdir, log_fname)

  pianorolls = get_fold_pianorolls(fold, hparams)

  rval = lib_evaluation.evaluate(evaluator, pianorolls)
  # Lazy %-style argument: formatting is deferred to the logging framework.
  tf.logging.info('Writing to path: %s', log_fpath)
  with lib_util.atomic_file(log_fpath) as p:
    np.savez_compressed(p, **rval)
Esempio n. 4
0
def evaluate_fold(fold, evaluator, hparams, eval_logdir, checkpoint_dir):
  """Writes to file the neg. loglikelihood of given fold (train/valid/test).

  Loads the fold's pianorolls with get_fold_pianorolls and scores them with
  lib_evaluation.evaluate. The returned mapping is saved with
  np.savez_compressed to
  "<eval_logdir>/<checkpoint name>__eval_<timestamp>_<fold><fold_index>_
  <FLAGS.unit>_ensemble<FLAGS.ensemble_size>_chrono<FLAGS.chronological>.npz".
  The fold index part is empty when FLAGS.fold_index is None.

  Args:
    fold: name of the fold to evaluate; also used in the output filename.
    evaluator: passed through unchanged to lib_evaluation.evaluate.
    hparams: passed to get_fold_pianorolls.
    eval_logdir: directory that receives the .npz file.
    checkpoint_dir: its last path component prefixes the output filename.
  """
  fold_index = '' if FLAGS.fold_index is None else FLAGS.fold_index
  eval_run_name = 'eval_%s_%s%s_%s_ensemble%s_chrono%s' % (
      lib_util.timestamp(), fold, fold_index, FLAGS.unit,
      FLAGS.ensemble_size, FLAGS.chronological)
  # Strip trailing separators first: basename("ckpt/dir/") is "", which would
  # drop the checkpoint name from the log filename.
  checkpoint_name = os.path.basename(checkpoint_dir.rstrip('/' + os.sep))
  log_fname = '%s__%s.npz' % (checkpoint_name, eval_run_name)
  log_fpath = os.path.join(eval_logdir, log_fname)

  pianorolls = get_fold_pianorolls(fold, hparams)

  rval = lib_evaluation.evaluate(evaluator, pianorolls)
  # Lazy %-style argument: formatting is deferred to the logging framework.
  tf.logging.info('Writing to path: %s', log_fpath)
  with lib_util.atomic_file(log_fpath) as p:
    np.savez_compressed(p, **rval)
Esempio n. 5
0
def evaluate_paths(paths, evaluator, unused_hparams, eval_logdir):
  """Evaluates negative loglikelihood of pianorolls from given paths.

  For every path, the pianorolls returned by get_path_pianorolls are scored
  with lib_evaluation.evaluate. The resulting mapping is written via
  np.savez_compressed (inside lib_util.atomic_file) to
  "<eval_logdir>/<file stem>__eval_samples_<timestamp>_<FLAGS.unit>_
  ensemble<FLAGS.ensemble_size>_chrono<FLAGS.chronological>.npz".

  Args:
    paths: iterable of paths to pianoroll files.
    evaluator: passed through unchanged to lib_evaluation.evaluate.
    unused_hparams: not used by this function.
    eval_logdir: directory that receives one .npz per path.
  """
  for src in paths:
    # A new timestamp is taken per path, so each output name is distinct.
    run_name = 'eval_samples_%s_%s_ensemble%s_chrono%s' % (
        lib_util.timestamp(), FLAGS.unit, FLAGS.ensemble_size,
        FLAGS.chronological)
    stem, _ = os.path.splitext(os.path.basename(src))
    target = os.path.join(eval_logdir, '%s__%s.npz' % (stem, run_name))

    results = lib_evaluation.evaluate(evaluator, get_path_pianorolls(src))
    tf.logging.info('Writing evaluation statistics to %s', target)
    with lib_util.atomic_file(target) as fh:
      np.savez_compressed(fh, **results)
Esempio n. 6
0
def main(unused_argv):
  """Generates samples from FLAGS.checkpoint and writes them to disk.

  Uses TFGenerator(FLAGS.checkpoint) when FLAGS.tfsample is set. Otherwise
  it uses a Generator over instantiate_model(FLAGS.checkpoint) with
  FLAGS.strategy. Results go into a new directory
  "<FLAGS.generation_output_dir>/<label>". The label is built from a
  timestamp, the strategy, the architecture, FLAGS.temperature,
  FLAGS.piece_length and generator.time_taken. Always written:
    * midi/                  -- the generated pieces, via save_midis;
    * generated_result.npy   -- generator.pianorolls.
  Additionally, when FLAGS.tfsample is not set:
    * intermediate_steps.npz -- output of generator.logger.dump;
    * context.npy            -- the first logged context pianorolls found in
                                that dump, if any. For strategies whose name
                                contains "harm", the context is also saved as
                                "<label>_prime" MIDI.

  Args:
    unused_argv: unused; all configuration comes from FLAGS.

  Raises:
    ValueError: if FLAGS.checkpoint is None or empty.
  """
  # `not` rejects both None and the empty string.
  if not FLAGS.checkpoint:
    raise ValueError(
        "Need to provide a path to checkpoint directory.")

  if FLAGS.tfsample:
    generator = TFGenerator(FLAGS.checkpoint)
  else:
    wmodel = instantiate_model(FLAGS.checkpoint)
    generator = Generator(wmodel, FLAGS.strategy)
  midi_outs = generator.run_generation(
      gen_batch_size=FLAGS.gen_batch_size, piece_length=FLAGS.piece_length)

  # Creates a folder for storing the process of the sampling.
  label = "sample_%s_%s_%s_T%g_l%i_%.2fmin" % (lib_util.timestamp(),
                                               FLAGS.strategy,
                                               generator.hparams.architecture,
                                               FLAGS.temperature,
                                               FLAGS.piece_length,
                                               generator.time_taken)
  basepath = os.path.join(FLAGS.generation_output_dir, label)
  tf.logging.info("basepath: %s", basepath)
  tf.gfile.MakeDirs(basepath)

  # Saves the generated pieces as midi files.
  midi_path = os.path.join(basepath, "midi")
  tf.gfile.MakeDirs(midi_path)
  tf.logging.info("Made directory %s", midi_path)
  save_midis(midi_outs, midi_path, label)

  result_npy_save_path = os.path.join(basepath, "generated_result.npy")
  tf.logging.info("Writing final result to %s", result_npy_save_path)
  # np.save emits bytes, so the file is opened in binary mode.
  with tf.gfile.Open(result_npy_save_path, "wb") as p:
    np.save(p, generator.pianorolls)

  # The remaining steps use generator.logger and generator.decoder and are
  # only performed for the non-tfsample generator.
  if FLAGS.tfsample:
    tf.logging.info("Done")
    return

  # Stores all the (intermediate) steps.
  intermediate_steps_path = os.path.join(basepath, "intermediate_steps.npz")
  with lib_util.timing("writing_out_sample_npz"):
    tf.logging.info("Writing intermediate steps to %s", intermediate_steps_path)
    generator.logger.dump(intermediate_steps_path)

  # Save the context as npy, and the prime as midi in harmonization mode.
  # Scans the stored npz for the logged context step.
  tf.logging.info("Reading to check %s", intermediate_steps_path)
  # Binary mode is required: a text-mode GFile returns decoded str on read,
  # which np.load cannot parse. allow_pickle permits object arrays in case
  # the logger dump contains any (TODO confirm); the file was written by
  # this function just above, so it is not external input.
  with tf.gfile.Open(intermediate_steps_path, "rb") as p:
    steps = np.load(p, allow_pickle=True)
    for key in steps.keys():
      if not re.match(r"0_root/.*?_strategy/.*?_context/0_pianorolls", key):
        continue
      context_rolls = steps[key]
      context_fpath = os.path.join(basepath, "context.npy")
      tf.logging.info("Writing context to %s", context_fpath)
      with lib_util.atomic_file(context_fpath) as context_p:
        np.save(context_p, context_rolls)
      if "harm" in FLAGS.strategy:
        # Only synthesize the one prime if in Midi-melody-prime mode.
        primes = context_rolls
        if "Melody" in FLAGS.strategy:
          primes = [context_rolls[0]]
        prime_midi_outs = get_midi_from_pianorolls(primes, generator.decoder)
        save_midis(prime_midi_outs, midi_path, label + "_prime")
      # Only the first matching context entry is used.
      break
  tf.logging.info("Done")
Esempio n. 7
0
def main(checkpoint,
         tfsample=True,
         strategy='igibbs',
         gen_batch_size="3",
         piece_length="32",
         temperature=0.99,
         generation_output_dir=None,
         prime_midi_melody_fpath=None):
    """Generates samples from a checkpoint and writes them to disk.

    Results go into a new directory "<generation_output_dir>/<label>". The
    label is built from a timestamp, the strategy, the architecture,
    temperature, piece_length and generator.time_taken. Always written:
    midi/ (via save_midis) and generated_result.npy (generator.pianorolls).
    When `tfsample` is false, also written: intermediate_steps.npz and, if
    the dump contains a logged context, context.npy plus (for "harm"
    strategies) "<label>_prime" MIDI files.

    Args:
      checkpoint: path to the checkpoint directory; must be non-empty.
      tfsample: if true, sample with TFGenerator(checkpoint); otherwise use
        Generator over instantiate_model(checkpoint) with `strategy`.
      strategy: passed to Generator and used in the label. If it contains
        "harm", the prime is also saved as MIDI.
      gen_batch_size: passed to run_generation; coerced with int(), so
        string values such as the default are accepted.
      piece_length: passed to run_generation and used in the label; coerced
        with int().
      temperature: passed to run_generation and used in the label; coerced
        with float().
      generation_output_dir: directory under which the output directory is
        created; required.
      prime_midi_melody_fpath: optional MIDI file loaded with pretty_midi and
        passed to run_generation as `midi_in`. None or "" means no prime.

    Raises:
      ValueError: if `checkpoint` is None/empty or `generation_output_dir`
        is None.
    """
    # `not` rejects both None and the empty string.
    if not checkpoint:
        raise ValueError("Need to provide a path to checkpoint directory.")
    if generation_output_dir is None:
        # Fail before the expensive generation rather than in os.path.join.
        raise ValueError("Need to provide a generation output directory.")

    # The label formats piece_length with "%i" and temperature with "%g",
    # both of which raise TypeError on str. The string defaults would
    # otherwise crash only after generation had finished.
    gen_batch_size = int(gen_batch_size)
    piece_length = int(piece_length)
    temperature = float(temperature)

    midi_file = None
    # Truthiness covers both None and "" (an `is not ''` identity check on a
    # string literal is unreliable).
    if prime_midi_melody_fpath:
        midi_file = pretty_midi.PrettyMIDI(prime_midi_melody_fpath)
    if tfsample:
        generator = TFGenerator(checkpoint)
    else:
        wmodel = instantiate_model(checkpoint)
        generator = Generator(wmodel, strategy)
    midi_outs = generator.run_generation(temperature=temperature,
                                         gen_batch_size=gen_batch_size,
                                         piece_length=piece_length,
                                         midi_in=midi_file)

    # Creates a folder for storing the process of the sampling.
    label = "sample_%s_%s_%s_T%g_l%i_%.2fmin" % (
        lib_util.timestamp(), strategy, generator.hparams.architecture,
        temperature, piece_length, generator.time_taken)
    basepath = os.path.join(generation_output_dir, label)
    tf.logging.info("basepath: %s", basepath)
    tf.gfile.MakeDirs(basepath)

    # Saves the generated pieces as midi files.
    midi_path = os.path.join(basepath, "midi")
    tf.gfile.MakeDirs(midi_path)
    tf.logging.info("Made directory %s", midi_path)
    save_midis(midi_outs, midi_path, label)

    result_npy_save_path = os.path.join(basepath, "generated_result.npy")
    tf.logging.info("Writing final result to %s", result_npy_save_path)
    # np.save emits bytes, so the file is opened in binary mode.
    with tf.gfile.Open(result_npy_save_path, "wb") as p:
        np.save(p, generator.pianorolls)

    # The remaining steps use generator.logger and generator.decoder and are
    # only performed for the non-tfsample generator.
    if tfsample:
        tf.logging.info("Done")
        return

    # Stores all the (intermediate) steps.
    intermediate_steps_path = os.path.join(basepath, "intermediate_steps.npz")
    with lib_util.timing("writing_out_sample_npz"):
        tf.logging.info("Writing intermediate steps to %s",
                        intermediate_steps_path)
        generator.logger.dump(intermediate_steps_path)

    # Save the context as npy, and the prime as midi in harmonization mode.
    # Scans the stored npz for the logged context step.
    tf.logging.info("Reading to check %s", intermediate_steps_path)
    with tf.gfile.Open(intermediate_steps_path, "rb") as p:
        # The file was written by this function just above, so enabling
        # pickle here does not load external input.
        steps = np.load(p, allow_pickle=True, encoding='latin1')
        for key in steps.keys():
            if not re.match(r"0_root/.*?_strategy/.*?_context/0_pianorolls",
                            key):
                continue
            context_rolls = steps[key]
            context_fpath = os.path.join(basepath, "context.npy")
            tf.logging.info("Writing context to %s", context_fpath)
            with lib_util.atomic_file(context_fpath) as context_p:
                np.save(context_p, context_rolls)
            if "harm" in strategy:
                # Only synthesize the one prime if in Midi-melody-prime mode.
                primes = context_rolls
                if "Melody" in strategy:
                    primes = [context_rolls[0]]
                prime_midi_outs = get_midi_from_pianorolls(
                    primes, generator.decoder)
                save_midis(prime_midi_outs, midi_path, label + "_prime")
            # Only the first matching context entry is used.
            break
    tf.logging.info("Done")