Example #1
    def create_bundle_file(self, bundle_file, bundle_description=None):
        """Writes a generator_pb2.GeneratorBundle file in the specified location.

    Saves the checkpoint, metagraph, and generator id in one file.

    Args:
      bundle_file: Location to write the bundle file.
      bundle_description: A short, human-readable string description of this
          bundle.

    Raises:
      SequenceGeneratorException: if there is an error creating the bundle file.
    """
        if not bundle_file:
            raise SequenceGeneratorException(
                'Bundle file location not specified.')
        if not self.details.id:
            raise SequenceGeneratorException(
                'Generator id must be included in GeneratorDetails when creating '
                'a bundle file.')

        if not self.details.description:
            tf.logging.warn(
                'Writing bundle file with no generator description.')
        if not bundle_description:
            tf.logging.warn('Writing bundle file with no bundle description.')

        self.initialize()

        tempdir = None
        try:
            tempdir = tempfile.mkdtemp()
            checkpoint_filename = os.path.join(tempdir, 'model.ckpt')

            self._model.write_checkpoint_with_metagraph(checkpoint_filename)

            if not os.path.isfile(checkpoint_filename):
                raise SequenceGeneratorException(
                    'Could not read checkpoint file: %s' %
                    (checkpoint_filename))
            metagraph_filename = checkpoint_filename + '.meta'
            if not os.path.isfile(metagraph_filename):
                raise SequenceGeneratorException(
                    'Could not read metagraph file: %s' % (metagraph_filename))

            bundle = generator_pb2.GeneratorBundle()
            bundle.generator_details.CopyFrom(self.details)
            if bundle_description:
                bundle.bundle_details.description = bundle_description
            with tf.gfile.Open(checkpoint_filename, 'rb') as f:
                bundle.checkpoint_file.append(f.read())
            with tf.gfile.Open(metagraph_filename, 'rb') as f:
                bundle.metagraph_file = f.read()

            with tf.gfile.Open(bundle_file, 'wb') as f:
                f.write(bundle.SerializeToString())
        finally:
            if tempdir is not None:
                tf.gfile.DeleteRecursively(tempdir)
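
A minimal call-site sketch for create_bundle_file follows; the generator class name and the file paths are illustrative assumptions, not part of the example above.

# Hypothetical usage sketch: MyGenerator and both paths are assumed names.
generator = MyGenerator(checkpoint='/tmp/train/model.ckpt')
generator.create_bundle_file(
    '/tmp/my_generator.mag',
    bundle_description='Example bundle written for illustration.')
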
def read_bundle_file(bundle_file):
  """Reads and parses a generator_pb2.GeneratorBundle from the given path.

  Raises:
    GeneratorBundleParseException: if the file cannot be parsed as a
        GeneratorBundle proto.
  """
  bundle = generator_pb2.GeneratorBundle()
  with tf.gfile.Open(bundle_file, 'rb') as f:
    try:
      bundle.ParseFromString(f.read())
    except message.DecodeError as e:
      raise GeneratorBundleParseException(e)
  return bundle
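
Once a .mag file has been written, it can be parsed back with read_bundle_file; this round-trip sketch only touches fields that appear in the examples above, and the path is an assumption.

# Hypothetical round-trip sketch; the path is an assumption.
bundle = read_bundle_file('/tmp/my_generator.mag')
print(bundle.generator_details.id)
print(bundle.bundle_details.description)
print(len(bundle.checkpoint_file[0]), len(bundle.metagraph_file))
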
    def testUseMatchingGeneratorId(self):
        bundle = generator_pb2.GeneratorBundle(
            generator_details=generator_pb2.GeneratorDetails(
                id='test_generator'),
            checkpoint_file=[b'foo.ckpt'],
            metagraph_file=b'foo.ckpt.meta')

        SequenceGenerator(bundle=bundle)

        bundle.generator_details.id = 'blarg'

        with self.assertRaises(sequence_generator.SequenceGeneratorError):
            SequenceGenerator(bundle=bundle)
    def testSpecifyEitherCheckpointOrBundle(self):
        bundle = generator_pb2.GeneratorBundle(
            generator_details=generator_pb2.GeneratorDetails(
                id='test_generator'),
            checkpoint_file=[b'foo.ckpt'],
            metagraph_file=b'foo.ckpt.meta')

        with self.assertRaises(sequence_generator.SequenceGeneratorError):
            SequenceGenerator(checkpoint='foo.ckpt', bundle=bundle)
        with self.assertRaises(sequence_generator.SequenceGeneratorError):
            SequenceGenerator(checkpoint=None, bundle=None)

        SequenceGenerator(checkpoint='foo.ckpt')
        SequenceGenerator(bundle=bundle)
Example #5
    def create_bundle_file(self, bundle_file, description=None):
        """Writes a generator_pb2.GeneratorBundle file in the specified location.

    Saves the checkpoint, metagraph, and generator id in one file.

    Args:
      bundle_file: Location to write the bundle file.
      description: A short, human-readable text description of the bundle (e.g.,
          training data, hyper parameters, etc.).

    Raises:
      SequenceGeneratorException: if there is an error creating the bundle file.
    """
        if not bundle_file:
            raise SequenceGeneratorException(
                'Bundle file location not specified.')

        self.initialize()

        tempdir = None
        try:
            tempdir = tempfile.mkdtemp()
            checkpoint_filename = os.path.join(tempdir, 'model.ckpt')

            self._write_checkpoint_with_metagraph(checkpoint_filename)

            if not os.path.isfile(checkpoint_filename):
                raise SequenceGeneratorException(
                    'Could not read checkpoint file: %s' %
                    (checkpoint_filename))
            metagraph_filename = checkpoint_filename + '.meta'
            if not os.path.isfile(metagraph_filename):
                raise SequenceGeneratorException(
                    'Could not read metagraph file: %s' % (metagraph_filename))

            bundle = generator_pb2.GeneratorBundle()
            bundle.generator_details.CopyFrom(self.details)
            if description is not None:
                bundle.bundle_details.description = description
            with tf.gfile.Open(checkpoint_filename, 'rb') as f:
                bundle.checkpoint_file.append(f.read())
            with tf.gfile.Open(metagraph_filename, 'rb') as f:
                bundle.metagraph_file = f.read()

            with tf.gfile.Open(bundle_file, 'wb') as f:
                f.write(bundle.SerializeToString())
        finally:
            if tempdir is not None:
                tf.gfile.DeleteRecursively(tempdir)
    def testGetBundleDetails(self):
        # Test with non-bundle generator.
        seq_gen = SequenceGenerator(checkpoint='foo.ckpt')
        self.assertEqual(None, seq_gen.bundle_details)

        # Test with bundle-based generator.
        bundle_details = generator_pb2.GeneratorBundle.BundleDetails(
            description='bundle of joy')
        bundle = generator_pb2.GeneratorBundle(
            generator_details=generator_pb2.GeneratorDetails(
                id='test_generator'),
            bundle_details=bundle_details,
            checkpoint_file=[b'foo.ckpt'],
            metagraph_file=b'foo.ckpt.meta')
        seq_gen = SequenceGenerator(bundle=bundle)
        self.assertEqual(bundle_details, seq_gen.bundle_details)
Example #7
def main(unused_argv):

    tf.logging.set_verbosity(tf.logging.INFO)

    if not FLAGS.checkpoint_number:
        tf.logging.fatal('--checkpoint_number required')
        return

    model_dict = {'LSTM': LSTMModel, 'LSTMAE': LSTMAE}
    checkpoint_dir = FLAGS.checkpoint_dir

    data_config = models.default_configs['performance_with_meta_128']
    data_config.hparams.parse(FLAGS.hparams)
    encoder_decoder = data_config.encoder_decoder

    layers = FLAGS.layers
    batch_size = 1

    if FLAGS.classifier_weight:
        config = LSTMConfig(encoder_decoder=encoder_decoder,
                            gpu=FLAGS.gpu,
                            layers=layers,
                            batch_size=batch_size,
                            label_classifier_weight=FLAGS.classifier_weight,
                            label_classifier_units=4)
    else:
        config = LSTMConfig(encoder_decoder=encoder_decoder,
                            gpu=FLAGS.gpu,
                            layers=layers,
                            batch_size=batch_size)

    model = model_dict[FLAGS.model](config, 'generate')
    number = FLAGS.checkpoint_number
    bundle_file = os.path.join(checkpoint_dir, 'bundle.mag')
    checkpoint_filename = os.path.join(checkpoint_dir,
                                       'model.ckpt-' + str(number))
    metagraph_filename = os.path.join(checkpoint_dir,
                                      'model.ckpt-' + str(number) + '.meta')

    tf.logging.info(checkpoint_dir)
    with tf.Graph().as_default() as g:

        model.build_graph_fn()
        sess = tf.Session(graph=g)
        saver = tf.train.Saver()

        tf.logging.info('Checkpoint used: %s', checkpoint_filename)
        saver.restore(sess, checkpoint_filename)

        tempdir = None
        try:
            tempdir = tempfile.mkdtemp()
            checkpoint_filename = os.path.join(tempdir, 'model.ckpt')
            saver = tf.train.Saver(sharded=False,
                                   write_version=tf.train.SaverDef.V1)
            saver.save(sess,
                       checkpoint_filename,
                       meta_graph_suffix='meta',
                       write_meta_graph=True)
            metagraph_filename = checkpoint_filename + '.meta'
            bundle = generator_pb2.GeneratorBundle()

            details = generator_pb2.GeneratorDetails(
                id='performance_with_dynamics',
                description='Performance RNN with dynamics (compact input)')
            bundle.generator_details.CopyFrom(details)

            with tf.gfile.Open(checkpoint_filename, 'rb') as f:
                bundle.checkpoint_file.append(f.read())
            with tf.gfile.Open(metagraph_filename, 'rb') as f:
                bundle.metagraph_file = f.read()

            tf.logging.info('Writing to: ' + bundle_file)
            with tf.gfile.Open(bundle_file, 'wb') as f:
                f.write(bundle.SerializeToString())

        finally:
            if tempdir is not None:
                tf.gfile.DeleteRecursively(tempdir)
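
As a quick sanity check once main() has run, the written bundle can be parsed back with the same proto API used above; this is a sketch, and the path mirrors the bundle_file computed in main().

# Hypothetical verification sketch: parse the freshly written bundle and log its id.
bundle = generator_pb2.GeneratorBundle()
with tf.gfile.Open(os.path.join(FLAGS.checkpoint_dir, 'bundle.mag'), 'rb') as f:
    bundle.ParseFromString(f.read())
tf.logging.info('Bundle generator id: %s', bundle.generator_details.id)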