Example #1
  def __init__(self, checkpoint=None, bundle=None):
    details = generator_pb2.GeneratorDetails(
        id='test_generator',
        description='Test Generator')

    super(TestSequenceGenerator, self).__init__(
        TestModel(), details, checkpoint, bundle)
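For context, this __init__ belongs to a small test subclass. A minimal sketch of the assumed surrounding definitions (BaseSequenceGenerator as the base class, a TestModel stub, and older-style Magenta import paths are all assumptions inferred from the super() call, not shown in the example):

from magenta.models.shared import sequence_generator
from magenta.protobuf import generator_pb2


class TestModel(object):
  # Stand-in for the model object that BaseSequenceGenerator expects;
  # a real test would supply a proper model stub here.
  pass


class TestSequenceGenerator(sequence_generator.BaseSequenceGenerator):

  def __init__(self, checkpoint=None, bundle=None):
    details = generator_pb2.GeneratorDetails(
        id='test_generator',
        description='Test Generator')
    super(TestSequenceGenerator, self).__init__(
        TestModel(), details, checkpoint, bundle)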
Example #2
def create_generator(checkpoint, steps_per_beat=4, hparams=None):
    melody_encoder_decoder = basic_rnn_encoder_decoder.MelodyEncoderDecoder()
    details = generator_pb2.GeneratorDetails(id='basic_rnn',
                                             description='Basic RNN Generator')
    return melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        details, checkpoint, melody_encoder_decoder,
        basic_rnn_graph.build_graph, steps_per_beat,
        {} if hparams is None else hparams)
Example #3
def create_generator(train_dir, steps_per_beat=4, hparams=None):
    melody_encoder_decoder = (
        attention_rnn_encoder_decoder.MelodyEncoderDecoder())
    details = generator_pb2.GeneratorDetails(
        id='attention_rnn', description='Attention RNN Generator')
    return melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        details, train_dir, melody_encoder_decoder,
        attention_rnn_graph.build_graph, steps_per_beat,
        {} if hparams is None else hparams)
Example #4
def main(unused_argv):
    melody_encoder_decoder = basic_rnn_encoder_decoder.MelodyEncoderDecoder()
    details = generator_pb2.GeneratorDetails(id='basic_rnn',
                                             description='Basic RNN Generator')
    with melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
            details, melody_rnn_generate.get_train_dir(),
            melody_encoder_decoder, basic_rnn_graph.build_graph,
            melody_rnn_generate.get_steps_per_beat(),
            melody_rnn_generate.get_hparams()) as generator:
        melody_rnn_generate.run_with_flags(generator)
Example #5
def create_generator(checkpoint,
                     bundle,
                     steps_per_quarter=4,
                     hparams=None,
                     generator_id=DEFAULT_ID):
    melody_encoder_decoder = basic_rnn_encoder_decoder.MelodyEncoderDecoder()
    details = generator_pb2.GeneratorDetails(id=generator_id,
                                             description='Basic RNN Generator')
    return melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        details, checkpoint, bundle, melody_encoder_decoder,
        basic_rnn_graph.build_graph, steps_per_quarter,
        {} if hparams is None else hparams)
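A hedged usage sketch for this factory; the checkpoint path is made up for illustration, and the with-statement relies on the returned generator being usable as a context manager, exactly as Example #4 does:

# Hypothetical checkpoint path; bundle=None means "load from the checkpoint".
with create_generator(checkpoint='/tmp/basic_rnn/model.ckpt',
                      bundle=None,
                      steps_per_quarter=4) as generator:
    # The initialized generator can now drive generation, e.g. via
    # melody_rnn_generate.run_with_flags(generator) as in Example #4.
    pass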
Example #6
    def testUseMatchingGeneratorId(self):
        bundle = generator_pb2.GeneratorBundle(
            generator_details=generator_pb2.GeneratorDetails(
                id='test_generator'),
            checkpoint_file=[b'foo.ckpt'],
            metagraph_file=b'foo.ckpt.meta')

        SequenceGenerator(bundle=bundle)

        bundle.generator_details.id = 'blarg'

        with self.assertRaises(sequence_generator.SequenceGeneratorError):
            SequenceGenerator(bundle=bundle)
Example #7
def create_generator(checkpoint, bundle, steps_per_quarter=4, hparams=None):
  melody_encoder_decoder = attention_rnn_encoder_decoder.MelodyEncoderDecoder()
  details = generator_pb2.GeneratorDetails(
      id='attention_rnn',
      description='Attention RNN Generator')
  return melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
      details,
      checkpoint,
      bundle,
      melody_encoder_decoder,
      attention_rnn_graph.build_graph,
      steps_per_quarter,
      {} if hparams is None else hparams)
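The same factory can instead be given a pre-built .mag bundle; a sketch, assuming a hypothetical bundle path and using only the protobuf and tf.gfile calls that appear again in Example #10 below:

# Hypothetical bundle path, parsed into a GeneratorBundle proto.
bundle = generator_pb2.GeneratorBundle()
with tf.gfile.Open('/tmp/attention_rnn.mag', 'rb') as f:
  bundle.ParseFromString(f.read())
generator = create_generator(checkpoint=None, bundle=bundle)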
Example #8
    def testSpecifyEitherCheckPointOrBundle(self):
        bundle = generator_pb2.GeneratorBundle(
            generator_details=generator_pb2.GeneratorDetails(
                id='test_generator'),
            checkpoint_file=[b'foo.ckpt'],
            metagraph_file=b'foo.ckpt.meta')

        with self.assertRaises(sequence_generator.SequenceGeneratorError):
            SequenceGenerator(checkpoint='foo.ckpt', bundle=bundle)
        with self.assertRaises(sequence_generator.SequenceGeneratorError):
            SequenceGenerator(checkpoint=None, bundle=None)

        SequenceGenerator(checkpoint='foo.ckpt')
        SequenceGenerator(bundle=bundle)
Example #9
    def testGetBundleDetails(self):
        # Test with non-bundle generator.
        seq_gen = SequenceGenerator(checkpoint='foo.ckpt')
        self.assertEqual(None, seq_gen.bundle_details)

        # Test with bundle-based generator.
        bundle_details = generator_pb2.GeneratorBundle.BundleDetails(
            description='bundle of joy')
        bundle = generator_pb2.GeneratorBundle(
            generator_details=generator_pb2.GeneratorDetails(
                id='test_generator'),
            bundle_details=bundle_details,
            checkpoint_file=[b'foo.ckpt'],
            metagraph_file=b'foo.ckpt.meta')
        seq_gen = SequenceGenerator(bundle=bundle)
        self.assertEqual(bundle_details, seq_gen.bundle_details)
Example #10
def main(unused_argv):

    tf.logging.set_verbosity(tf.logging.INFO)

    if not FLAGS.checkpoint_number:
        tf.logging.fatal('--checkpoint_number required')
        return

    model_dict = {'LSTM': LSTMModel, 'LSTMAE': LSTMAE}
    checkpoint_dir = FLAGS.checkpoint_dir

    data_config = models.default_configs['performance_with_meta_128']
    data_config.hparams.parse(FLAGS.hparams)
    encoder_decoder = data_config.encoder_decoder

    layers = FLAGS.layers
    batch_size = 1

    if FLAGS.classifier_weight:
        config = LSTMConfig(encoder_decoder=encoder_decoder,
                            gpu=FLAGS.gpu,
                            layers=layers,
                            batch_size=batch_size,
                            label_classifier_weight=FLAGS.classifier_weight,
                            label_classifier_units=4)
    else:
        config = LSTMConfig(encoder_decoder=encoder_decoder,
                            gpu=FLAGS.gpu,
                            layers=layers,
                            batch_size=batch_size)

    model = model_dict[FLAGS.model](config, 'generate')
    number = FLAGS.checkpoint_number
    bundle_file = checkpoint_dir + 'bundle.mag'
    checkpoint_filename = os.path.join(checkpoint_dir,
                                       'model.ckpt-' + str(number))
    metagraph_filename = os.path.join(checkpoint_dir,
                                      'model.ckpt-' + str(number) + '.meta')

    tf.logging.info(checkpoint_dir)
    with tf.Graph().as_default() as g:

        model.build_graph_fn()
        sess = tf.Session(graph=g)
        saver = tf.train.Saver()

        tf.logging.info('Checkpoint used: %s', checkpoint_filename)
        saver.restore(sess, checkpoint_filename)

        # Initialize before the try block so the finally-clause check on
        # tempdir is always valid, even if mkdtemp() fails.
        tempdir = None
        try:
            tempdir = tempfile.mkdtemp()
            checkpoint_filename = os.path.join(tempdir, 'model.ckpt')
            saver = tf.train.Saver(sharded=False,
                                   write_version=tf.train.SaverDef.V1)
            saver.save(sess,
                       checkpoint_filename,
                       meta_graph_suffix='meta',
                       write_meta_graph=True)
            metagraph_filename = checkpoint_filename + '.meta'
            bundle = generator_pb2.GeneratorBundle()

            details = generator_pb2.GeneratorDetails(
                id='performance_with_dynamics',
                description='Performance RNN with dynamics (compact input)')
            bundle.generator_details.CopyFrom(details)

            with tf.gfile.Open(checkpoint_filename, 'rb') as f:
                bundle.checkpoint_file.append(f.read())
            with tf.gfile.Open(metagraph_filename, 'rb') as f:
                bundle.metagraph_file = f.read()

            tf.logging.info('Writing to: ' + bundle_file)
            with tf.gfile.Open(bundle_file, 'wb') as f:
                f.write(bundle.SerializeToString())

        finally:
            if tempdir is not None:
                tf.gfile.DeleteRecursively(tempdir)
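As a quick sanity check, the .mag file written above can be parsed back with the standard protobuf API; a sketch that reuses the bundle_file path computed in main():

# Re-read the serialized bundle and log the generator id it carries.
readback = generator_pb2.GeneratorBundle()
with tf.gfile.Open(bundle_file, 'rb') as f:
    readback.ParseFromString(f.read())
tf.logging.info('Wrote bundle for generator: %s',
                readback.generator_details.id)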