def create_bundle_file(self, bundle_file, bundle_description=None):
    """Serialize this generator into one generator_pb2.GeneratorBundle file.

    The bundle holds a copy of `self.details` (which must carry a generator
    id), the optional bundle description, the model checkpoint, and the
    metagraph.

    Args:
      bundle_file: Destination path of the bundle file.
      bundle_description: Optional short, human-readable description of this
          bundle. Only stored when truthy.

    Raises:
      SequenceGeneratorException: if `bundle_file` is empty, the generator id
          is missing, or the checkpoint/metagraph file was not produced.
    """
    # Reject unusable requests before doing any expensive work.
    if not bundle_file:
        raise SequenceGeneratorException(
            'Bundle file location not specified.')
    if not self.details.id:
        raise SequenceGeneratorException(
            'Generator id must be included in GeneratorDetails when creating '
            'a bundle file.')

    # Missing descriptions are tolerated, but flagged in the log.
    if not self.details.description:
        tf.logging.warn('Writing bundle file with no generator description.')
    if not bundle_description:
        tf.logging.warn('Writing bundle file with no bundle description.')

    self.initialize()

    tempdir = None
    try:
        tempdir = tempfile.mkdtemp()
        ckpt_path = os.path.join(tempdir, 'model.ckpt')
        meta_path = ckpt_path + '.meta'
        self._model.write_checkpoint_with_metagraph(ckpt_path)

        # Both artifacts must exist on disk before they can be bundled;
        # the checkpoint is checked first, then the metagraph.
        expected_outputs = (
            (ckpt_path, 'Could not read checkpoint file: %s'),
            (meta_path, 'Could not read metagraph file: %s'),
        )
        for path, error_template in expected_outputs:
            if not os.path.isfile(path):
                raise SequenceGeneratorException(error_template % path)

        bundle = generator_pb2.GeneratorBundle()
        bundle.generator_details.CopyFrom(self.details)
        if bundle_description:
            bundle.bundle_details.description = bundle_description

        with tf.gfile.Open(ckpt_path, 'rb') as ckpt_handle:
            bundle.checkpoint_file.append(ckpt_handle.read())
        with tf.gfile.Open(meta_path, 'rb') as meta_handle:
            bundle.metagraph_file = meta_handle.read()

        with tf.gfile.Open(bundle_file, 'wb') as out_handle:
            out_handle.write(bundle.SerializeToString())
    finally:
        # Remove the scratch checkpoint whether or not bundling succeeded.
        if tempdir is not None:
            tf.gfile.DeleteRecursively(tempdir)
def read_bundle_file(bundle_file):
    """Load a generator_pb2.GeneratorBundle from `bundle_file`.

    Args:
      bundle_file: Path of the serialized bundle to read.

    Returns:
      The parsed generator_pb2.GeneratorBundle.

    Raises:
      GeneratorBundleParseException: if the file contents cannot be decoded
          as a GeneratorBundle (wraps the underlying message.DecodeError).
    """
    with tf.gfile.Open(bundle_file, 'rb') as handle:
        raw_bytes = handle.read()

    parsed = generator_pb2.GeneratorBundle()
    try:
        parsed.ParseFromString(raw_bytes)
    except message.DecodeError as err:
        raise GeneratorBundleParseException(err)
    return parsed
def testUseMatchingGeneratorId(self):
    """Accept a bundle with id 'test_generator'; reject it once the id changes."""
    details = generator_pb2.GeneratorDetails(id='test_generator')
    bundle = generator_pb2.GeneratorBundle(
        generator_details=details,
        checkpoint_file=[b'foo.ckpt'],
        metagraph_file=b'foo.ckpt.meta')

    # With the original id, construction must succeed.
    SeuenceGenerator(bundle=bundle)

    # After changing the bundle's generator id, construction must fail.
    bundle.generator_details.id = 'blarg'
    with self.assertRaises(sequence_generator.SequenceGeneratorError):
        SeuenceGenerator(bundle=bundle)
def testSpecifyEitherCheckPointOrBundle(self):
    """Constructing needs exactly one of `checkpoint` or `bundle`."""
    bundle = generator_pb2.GeneratorBundle(
        generator_details=generator_pb2.GeneratorDetails(id='test_generator'),
        checkpoint_file=[b'foo.ckpt'],
        metagraph_file=b'foo.ckpt.meta')

    # Supplying both, or neither, must be rejected.
    invalid_kwargs = (
        {'checkpoint': 'foo.ckpt', 'bundle': bundle},
        {'checkpoint': None, 'bundle': None},
    )
    for kwargs in invalid_kwargs:
        with self.assertRaises(sequence_generator.SequenceGeneratorError):
            SeuenceGenerator(**kwargs)

    # Supplying exactly one of them must be accepted.
    SeuenceGenerator(checkpoint='foo.ckpt')
    SeuenceGenerator(bundle=bundle)
def create_bundle_file(self, bundle_file, description=None):
    """Write this generator to `bundle_file` as a generator_pb2.GeneratorBundle.

    The single output file holds a copy of `self.details`, an optional bundle
    description, the model checkpoint, and its metagraph.

    Args:
      bundle_file: Path where the bundle file will be written.
      description: A short, human-readable text description of the bundle
          (e.g., training data, hyper parameters, etc.). Stored whenever it
          is not None.

    Raises:
      SequenceGeneratorException: if `bundle_file` is empty, or if the
          checkpoint or metagraph file was not produced.
    """
    if not bundle_file:
        raise SequenceGeneratorException(
            'Bundle file location not specified.')

    self.initialize()

    def _read_bytes(path):
        # Read a whole file as raw bytes through tf.gfile.
        with tf.gfile.Open(path, 'rb') as handle:
            return handle.read()

    workdir = None
    try:
        workdir = tempfile.mkdtemp()
        ckpt_path = os.path.join(workdir, 'model.ckpt')
        self._write_checkpoint_with_metagraph(ckpt_path)

        if not os.path.isfile(ckpt_path):
            raise SequenceGeneratorException(
                'Could not read checkpoint file: %s' % ckpt_path)
        meta_path = ckpt_path + '.meta'
        if not os.path.isfile(meta_path):
            raise SequenceGeneratorException(
                'Could not read metagraph file: %s' % meta_path)

        bundle = generator_pb2.GeneratorBundle()
        bundle.generator_details.CopyFrom(self.details)
        if description is not None:
            bundle.bundle_details.description = description
        bundle.checkpoint_file.append(_read_bytes(ckpt_path))
        bundle.metagraph_file = _read_bytes(meta_path)

        with tf.gfile.Open(bundle_file, 'wb') as out_handle:
            out_handle.write(bundle.SerializeToString())
    finally:
        # Always discard the scratch directory holding the temporary checkpoint.
        if workdir is not None:
            tf.gfile.DeleteRecursively(workdir)
def testGetBundleDetails(self):
    """bundle_details is None without a bundle and equals the bundle's details otherwise."""
    # A checkpoint-based generator has no bundle details.
    checkpoint_gen = SeuenceGenerator(checkpoint='foo.ckpt')
    self.assertEqual(None, checkpoint_gen.bundle_details)

    # A bundle-based generator exposes the details stored in its bundle.
    expected_details = generator_pb2.GeneratorBundle.BundleDetails(
        description='bundle of joy')
    generator_details = generator_pb2.GeneratorDetails(id='test_generator')
    bundle = generator_pb2.GeneratorBundle(
        generator_details=generator_details,
        bundle_details=expected_details,
        checkpoint_file=[b'foo.ckpt'],
        metagraph_file=b'foo.ckpt.meta')
    bundle_gen = SeuenceGenerator(bundle=bundle)
    self.assertEqual(expected_details, bundle_gen.bundle_details)
def main(unused_argv):
    """Restore a trained checkpoint and export it as a generator bundle file.

    Builds the generation graph for FLAGS.model, restores
    `model.ckpt-<FLAGS.checkpoint_number>` from FLAGS.checkpoint_dir, and
    re-saves it as an unsharded V1 checkpoint plus metagraph in a temporary
    directory. Both files are then packed into a generator_pb2.GeneratorBundle
    written to `<FLAGS.checkpoint_dir>/bundle.mag`.

    Args:
      unused_argv: Unused.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.checkpoint_number:
        tf.logging.fatal('--checkpoint_number required')
        return

    model_dict = {'LSTM': LSTMModel, 'LSTMAE': LSTMAE}

    data_config = models.default_configs['performance_with_meta_128']
    # NOTE(review): this appears to modify the shared entry in
    # models.default_configs in place; fine for a one-shot script, confirm
    # before reusing this code elsewhere.
    data_config.hparams.parse(FLAGS.hparams)

    # Build the LSTMConfig arguments once; the classifier settings are only
    # added when a classifier weight was requested.
    config_kwargs = dict(
        encoder_decoder=data_config.encoder_decoder,
        gpu=FLAGS.gpu,
        layers=FLAGS.layers,
        batch_size=1)
    if FLAGS.classifier_weight:
        config_kwargs.update(
            label_classifier_weight=FLAGS.classifier_weight,
            label_classifier_units=4)
    config = LSTMConfig(**config_kwargs)
    model = model_dict[FLAGS.model](config, 'generate')

    checkpoint_dir = FLAGS.checkpoint_dir
    # os.path.join rather than string concatenation, so the bundle lands
    # inside checkpoint_dir even when it has no trailing separator.
    bundle_file = os.path.join(checkpoint_dir, 'bundle.mag')
    source_checkpoint = os.path.join(
        checkpoint_dir, 'model.ckpt-' + str(FLAGS.checkpoint_number))
    tf.logging.info(checkpoint_dir)

    with tf.Graph().as_default() as g:
        model.build_graph_fn()
        # Use the session as a context manager so it is closed when done.
        with tf.Session(graph=g) as sess:
            restore_saver = tf.train.Saver()
            tf.logging.info('Checkpoint used: %s', source_checkpoint)
            restore_saver.restore(sess, source_checkpoint)

            # Bind before mkdtemp so the finally clause never hits an unbound
            # name (masking the real error) if directory creation fails.
            tempdir = None
            try:
                tempdir = tempfile.mkdtemp()
                export_checkpoint = os.path.join(tempdir, 'model.ckpt')
                # The bundle stores the checkpoint as the contents of one
                # file (read below), so re-save unsharded in the V1 format.
                export_saver = tf.train.Saver(
                    sharded=False, write_version=tf.train.SaverDef.V1)
                export_saver.save(sess, export_checkpoint,
                                  meta_graph_suffix='meta',
                                  write_meta_graph=True)
                export_metagraph = export_checkpoint + '.meta'

                bundle = generator_pb2.GeneratorBundle()
                details = generator_pb2.GeneratorDetails(
                    id='performance_with_dynamics',
                    description='Performance RNN with dynamics (compact input)')
                bundle.generator_details.CopyFrom(details)
                with tf.gfile.Open(export_checkpoint, 'rb') as f:
                    bundle.checkpoint_file.append(f.read())
                with tf.gfile.Open(export_metagraph, 'rb') as f:
                    bundle.metagraph_file = f.read()

                tf.logging.info('Writing to: %s', bundle_file)
                with tf.gfile.Open(bundle_file, 'wb') as f:
                    f.write(bundle.SerializeToString())
            finally:
                # Always remove the temporary export directory.
                if tempdir is not None:
                    tf.gfile.DeleteRecursively(tempdir)