예제 #1
0
    def __init__(self,
                 generator_name,
                 num_bars_to_generate,
                 hparams,
                 checkpoint=None,
                 bundle_file=None):
        """Creates and initializes a sequence generator.

        Exactly one source must be supplied: either `checkpoint` (optionally
        with `generator_name` and `hparams`), or `bundle_file` on its own, in
        which case the generator name is taken from the bundle's
        `generator_details.id`.

        Args:
          generator_name: Key into _GENERATOR_FACTORY_MAP selecting the
              generator factory. Must be falsy when `bundle_file` is given.
          num_bars_to_generate: Stored as `self._num_bars_to_generate`.
          hparams: Passed through as `hparams=` to the factory's
              `create_generator`. Must be falsy when `bundle_file` is given.
          checkpoint: Passed through as `checkpoint=` to `create_generator`.
          bundle_file: Path of a generator bundle file, read with
              `sequence_generator_bundle.read_bundle_file`.

        Raises:
          GeneratorException: If neither a checkpoint nor a bundle file is
              supplied; if a bundle file is combined with a checkpoint,
              generator name, or hparams; or if the resolved generator name is
              not a key of _GENERATOR_FACTORY_MAP.
        """
        self._num_bars_to_generate = num_bars_to_generate

        if not checkpoint and not bundle_file:
            raise GeneratorException(
                'No generator checkpoint or bundle location supplied.')
        if (checkpoint or generator_name or hparams) and bundle_file:
            raise GeneratorException(
                'Cannot specify both bundle file and checkpoint, generator_name, '
                'or hparams.')

        bundle = None
        if bundle_file:
            bundle = sequence_generator_bundle.read_bundle_file(bundle_file)
            # A bundle carries its own generator id; it replaces the (required
            # to be empty) generator_name argument.
            generator_name = bundle.generator_details.id

        if generator_name not in _GENERATOR_FACTORY_MAP:
            # Format the message eagerly: exception constructors do not apply
            # %-style formatting to extra positional arguments, so passing
            # generator_name separately would leave a literal '%s' in the
            # message. The 1-tuple keeps this safe even if the name is a tuple.
            raise GeneratorException(
                'Invalid generator name given: %s' % (generator_name,))

        generator = _GENERATOR_FACTORY_MAP[generator_name].create_generator(
            checkpoint=checkpoint, bundle=bundle, hparams=hparams)
        generator.initialize()

        self._generator = generator
예제 #2
0
def get_bundle():
    """Reads the generator bundle named by the bundle_file flag.

    Returns:
      The generator_pb2.GeneratorBundle parsed from the bundle file, or None
      when the save_generator_bundle flag is set or no bundle_file is set.
    """
    if not should_save_generator_bundle():
        path = get_bundle_file()
        if path is not None:
            return sequence_generator_bundle.read_bundle_file(path)
    # Either a bundle is being saved or no bundle file was configured.
    return None