def create_bundle_file(self, bundle_file, bundle_description=None):
  """Serializes this generator into a generator_pb2.GeneratorBundle file.

  The checkpoint, the metagraph, and the generator details (which must carry
  a generator id) are packed together into a single file.

  Args:
    bundle_file: Path where the bundle file will be written.
    bundle_description: Optional short, human-readable description that is
        stored in the bundle's details when given.

  Raises:
    SequenceGeneratorError: If `bundle_file` is empty, the generator details
        have no id, or the checkpoint/metagraph files are missing after the
        model writes them.
  """
  # Hard requirements: a destination and a generator id.
  if not bundle_file:
    raise SequenceGeneratorError('Bundle file location not specified.')
  if not self.details.id:
    raise SequenceGeneratorError(
        'Generator id must be included in GeneratorDetails when creating '
        'a bundle file.')

  # Descriptions are optional, but their absence is worth a warning.
  if not self.details.description:
    tf.logging.warn(
        'Writing bundle file with no generator description.')
  if not bundle_description:
    tf.logging.warn('Writing bundle file with no bundle description.')

  # NOTE(review): presumably prepares self._model so a checkpoint can be
  # written below -- confirm against initialize().
  self.initialize()

  scratch_dir = None
  try:
    # Stage the checkpoint and metagraph in a throwaway directory.
    scratch_dir = tempfile.mkdtemp()
    ckpt_path = os.path.join(scratch_dir, 'model.ckpt')
    meta_path = ckpt_path + '.meta'

    self._model.write_checkpoint_with_metagraph(ckpt_path)

    # Both artifacts must exist; the checkpoint is checked first.
    expected_outputs = (
        (ckpt_path, 'Could not read checkpoint file: %s'),
        (meta_path, 'Could not read metagraph file: %s'),
    )
    for output_path, error_template in expected_outputs:
      if not os.path.isfile(output_path):
        raise SequenceGeneratorError(error_template % output_path)

    bundle = generator_pb2.GeneratorBundle()
    bundle.generator_details.CopyFrom(self.details)
    if bundle_description:
      bundle.bundle_details.description = bundle_description

    with tf.gfile.Open(ckpt_path, 'rb') as ckpt_fp:
      bundle.checkpoint_file.append(ckpt_fp.read())
    with tf.gfile.Open(meta_path, 'rb') as meta_fp:
      bundle.metagraph_file = meta_fp.read()

    with tf.gfile.Open(bundle_file, 'wb') as out_fp:
      out_fp.write(bundle.SerializeToString())
  finally:
    # Always clean up the staging directory, even on failure.
    if scratch_dir is not None:
      tf.gfile.DeleteRecursively(scratch_dir)
def read_bundle_file(bundle_file):
  """Loads a generator_pb2.GeneratorBundle from a serialized file.

  Args:
    bundle_file: Path of the serialized bundle file to read.

  Returns:
    The parsed generator_pb2.GeneratorBundle message.

  Raises:
    GeneratorBundleParseError: If the file contents cannot be decoded as a
        GeneratorBundle.
  """
  with tf.gfile.Open(bundle_file, 'rb') as bundle_fp:
    serialized = bundle_fp.read()

  result = generator_pb2.GeneratorBundle()
  try:
    result.ParseFromString(serialized)
  except message.DecodeError as decode_err:
    # Surface protobuf decode failures as the module's own error type.
    raise GeneratorBundleParseError(decode_err)
  return result
def testUseMatchingGeneratorId(self):
  """A bundle whose generator id differs must be rejected.

  Construction succeeds with id 'test_generator' and must raise
  SequenceGeneratorError once the bundle's id is changed to 'blarg'.
  """
  details = generator_pb2.GeneratorDetails(id='test_generator')
  bundle = generator_pb2.GeneratorBundle(
      generator_details=details,
      checkpoint_file=[b'foo.ckpt'],
      metagraph_file=b'foo.ckpt.meta')

  # Accepted id: construction must not raise.
  SeuenceGenerator(bundle=bundle)

  # Different id: construction must be rejected.
  bundle.generator_details.id = 'blarg'
  with self.assertRaises(sequence_generator.SequenceGeneratorError):
    SeuenceGenerator(bundle=bundle)
def testSpecifyEitherCheckPointOrBundle(self):
  """Exactly one of `checkpoint` or `bundle` must be supplied.

  Passing both, or passing neither, raises SequenceGeneratorError; passing
  just one of them succeeds.
  """
  bundle = generator_pb2.GeneratorBundle(
      generator_details=generator_pb2.GeneratorDetails(id='test_generator'),
      checkpoint_file=[b'foo.ckpt'],
      metagraph_file=b'foo.ckpt.meta')

  # Invalid combinations, checked in order: both given, then neither.
  invalid_kwargs = (
      dict(checkpoint='foo.ckpt', bundle=bundle),
      dict(checkpoint=None, bundle=None),
  )
  for kwargs in invalid_kwargs:
    with self.assertRaises(sequence_generator.SequenceGeneratorError):
      SeuenceGenerator(**kwargs)

  # Valid combinations: exactly one source.
  SeuenceGenerator(checkpoint='foo.ckpt')
  SeuenceGenerator(bundle=bundle)
def testGetBundleDetails(self):
  """Checks `bundle_details` for checkpoint- and bundle-based generators.

  A generator built from a checkpoint reports no bundle details. A generator
  built from a bundle reports the BundleDetails stored in that bundle.
  """
  # Checkpoint-based generator: no bundle, hence no bundle details.
  seq_gen = SeuenceGenerator(checkpoint='foo.ckpt')
  # assertIsNone checks identity with None and gives a clearer failure
  # message than assertEqual(None, ...).
  self.assertIsNone(seq_gen.bundle_details)

  # Bundle-based generator: details must equal those stored in the bundle.
  bundle_details = generator_pb2.GeneratorBundle.BundleDetails(
      description='bundle of joy')
  bundle = generator_pb2.GeneratorBundle(
      generator_details=generator_pb2.GeneratorDetails(id='test_generator'),
      bundle_details=bundle_details,
      checkpoint_file=[b'foo.ckpt'],
      metagraph_file=b'foo.ckpt.meta')
  seq_gen = SeuenceGenerator(bundle=bundle)
  self.assertEqual(bundle_details, seq_gen.bundle_details)