示例#1
0
    def create_bundle_file(self, bundle_file, bundle_description=None):
        """Serialize this generator into a single GeneratorBundle file.

        The model checkpoint and its metagraph are written to a scratch
        directory, read back as bytes, and packed together with
        `self.details` (plus the optional bundle description) into a
        `generator_pb2.GeneratorBundle`, which is then written to
        `bundle_file`. The scratch directory is removed afterwards, whether
        or not bundling succeeded.

        Args:
          bundle_file: Destination path for the serialized bundle.
          bundle_description: Optional short, human-readable description
              stored in the bundle's `bundle_details`.

        Raises:
          SequenceGeneratorError: if `bundle_file` is empty, the generator
              details carry no id, or the checkpoint/metagraph files were not
              found after writing the checkpoint.
        """
        # Fail fast on inputs that would produce an unusable bundle.
        if not bundle_file:
            raise SequenceGeneratorError('Bundle file location not specified.')
        if not self.details.id:
            raise SequenceGeneratorError(
                'Generator id must be included in GeneratorDetails when creating '
                'a bundle file.')

        # Missing descriptions are tolerated but flagged.
        if not self.details.description:
            tf.logging.warn(
                'Writing bundle file with no generator description.')
        if not bundle_description:
            tf.logging.warn('Writing bundle file with no bundle description.')

        self.initialize()

        scratch_dir = tempfile.mkdtemp()
        try:
            ckpt_path = os.path.join(scratch_dir, 'model.ckpt')
            self._model.write_checkpoint_with_metagraph(ckpt_path)

            # The metagraph is looked up next to the checkpoint under a
            # '.meta' suffix. Both files must exist before they are bundled,
            # and they are checked in this order.
            meta_path = ckpt_path + '.meta'
            expected_files = (
                (ckpt_path, 'Could not read checkpoint file: %s'),
                (meta_path, 'Could not read metagraph file: %s'),
            )
            for path, error_template in expected_files:
                if not os.path.isfile(path):
                    raise SequenceGeneratorError(error_template % path)

            with tf.gfile.Open(ckpt_path, 'rb') as ckpt_in:
                ckpt_bytes = ckpt_in.read()
            with tf.gfile.Open(meta_path, 'rb') as meta_in:
                meta_bytes = meta_in.read()

            bundle = generator_pb2.GeneratorBundle()
            bundle.generator_details.CopyFrom(self.details)
            if bundle_description:
                bundle.bundle_details.description = bundle_description
            bundle.checkpoint_file.append(ckpt_bytes)
            bundle.metagraph_file = meta_bytes

            with tf.gfile.Open(bundle_file, 'wb') as bundle_out:
                bundle_out.write(bundle.SerializeToString())
        finally:
            # Always remove the scratch checkpoint files.
            tf.gfile.DeleteRecursively(scratch_dir)
def read_bundle_file(bundle_file):
    """Read and parse a serialized GeneratorBundle from disk.

    Args:
      bundle_file: Path of the bundle file to read.

    Returns:
      The parsed `generator_pb2.GeneratorBundle`.

    Raises:
      GeneratorBundleParseError: if the file contents cannot be decoded as a
          GeneratorBundle (wraps the protobuf `message.DecodeError`).
    """
    with tf.gfile.Open(bundle_file, 'rb') as f:
        serialized = f.read()

    parsed = generator_pb2.GeneratorBundle()
    try:
        parsed.ParseFromString(serialized)
    except message.DecodeError as e:
        raise GeneratorBundleParseError(e)
    return parsed
    def testUseMatchingGeneratorId(self):
        """Construction succeeds with id 'test_generator' and fails with 'blarg'."""
        # NOTE(review): 'SeuenceGenerator' looks like a misspelling of a
        # SequenceGenerator test helper; it is defined outside this view, so
        # confirm the name there before renaming.
        details = generator_pb2.GeneratorDetails(id='test_generator')
        bundle = generator_pb2.GeneratorBundle(
            generator_details=details,
            checkpoint_file=[b'foo.ckpt'],
            metagraph_file=b'foo.ckpt.meta')

        # The original id is accepted.
        SeuenceGenerator(bundle=bundle)

        # Changing the bundle's own copy of the id must be rejected.
        bundle.generator_details.id = 'blarg'
        with self.assertRaises(sequence_generator.SequenceGeneratorError):
            SeuenceGenerator(bundle=bundle)
    def testSpecifyEitherCheckPointOrBundle(self):
        """Exactly one of `checkpoint` or `bundle` must be supplied."""
        bundle = generator_pb2.GeneratorBundle(
            generator_details=generator_pb2.GeneratorDetails(
                id='test_generator'),
            checkpoint_file=[b'foo.ckpt'],
            metagraph_file=b'foo.ckpt.meta')

        # Supplying both, or neither, is an error.
        rejected = (
            dict(checkpoint='foo.ckpt', bundle=bundle),
            dict(checkpoint=None, bundle=None),
        )
        for kwargs in rejected:
            with self.assertRaises(sequence_generator.SequenceGeneratorError):
                SeuenceGenerator(**kwargs)

        # Supplying exactly one is accepted.
        accepted = (
            dict(checkpoint='foo.ckpt'),
            dict(bundle=bundle),
        )
        for kwargs in accepted:
            SeuenceGenerator(**kwargs)
    def testGetBundleDetails(self):
        """`bundle_details` is None without a bundle, else the bundle's details."""
        # Checkpoint-based generator: no bundle, so no bundle details.
        # assertIsNone checks identity and reports a clearer failure than
        # assertEqual(None, ...).
        seq_gen = SeuenceGenerator(checkpoint='foo.ckpt')
        self.assertIsNone(seq_gen.bundle_details)

        # Bundle-based generator: the details passed in are reported back.
        bundle_details = generator_pb2.GeneratorBundle.BundleDetails(
            description='bundle of joy')
        bundle = generator_pb2.GeneratorBundle(
            generator_details=generator_pb2.GeneratorDetails(
                id='test_generator'),
            bundle_details=bundle_details,
            checkpoint_file=[b'foo.ckpt'],
            metagraph_file=b'foo.ckpt.meta')
        seq_gen = SeuenceGenerator(bundle=bundle)
        self.assertEqual(bundle_details, seq_gen.bundle_details)