def main(unused_argv):
    """Generate melodies from the bundle named by --bundle_file.

    Loads ``<tgt.MODEL_DIR>/<bundle_file>.mag``. The model config comes from,
    in order of preference:
      1. ``<tgt.MODEL_DIR>/<bundle_file>.hparams`` whose first line is
         ``<config><TAB><hparams>``;
      2. the melody_rnn default config named ``bundle_file``;
      3. the config/hparams command-line flags.
    Generated files are written to ``tgt.GENERATED_DIR``.
    """
    tf.logging.set_verbosity(mg.FLAGS.log)
    if not mg.FLAGS.bundle_file:
        tf.logging.fatal("--bundle_file is required")
        return

    model_path = os.path.join(tgt.MODEL_DIR, mg.FLAGS.bundle_file + ".mag")
    hparams_path = os.path.join(tgt.MODEL_DIR, mg.FLAGS.bundle_file + ".hparams")

    if tf.gfile.Exists(hparams_path):
        tf.logging.info("Model parameter is read from file: %s", hparams_path)
        with tf.gfile.Open(hparams_path) as f:
            # readline() keeps the line terminator; strip it so it does not
            # end up inside the last hparams value. maxsplit=1 keeps any
            # further tabs inside the hparams string.
            config_name, hparams = f.readline().rstrip("\r\n").split("\t", 1)
        melody_rnn_config_flags.FLAGS.config = config_name
        melody_rnn_config_flags.FLAGS.hparams = hparams
        config = melody_rnn_config_flags.config_from_flags()
    elif mg.FLAGS.bundle_file in melody_rnn_model.default_configs:
        tf.logging.info("Model parameter is set by default")
        config = melody_rnn_model.default_configs[mg.FLAGS.bundle_file]
    else:
        tf.logging.info("Model parameter is read from arguments: %s",
                        mg.FLAGS.hparams)
        config = melody_rnn_config_flags.config_from_flags()

    generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        model=melody_rnn_model.MelodyRnnModel(config),
        details=config.details,
        steps_per_quarter=config.steps_per_quarter,
        bundle=magenta.music.read_bundle_file(model_path))

    mg.FLAGS.output_dir = tgt.GENERATED_DIR
    mg.run_with_flags(generator)
def configure_sequence_generator(trained_model_name, bundle_file):
    """Configure and return a trained sequence generator.

    Additional models will be supported in the future, and this configuration
    tool will become more generalized. Currently only 'melody_rnn_generator'
    is supported; it is built from the default 'basic_rnn' melody_rnn config.

    Args:
        trained_model_name: name of trained model.
        bundle_file: filename of magenta bundle file (.mag).

    Returns:
        sequence_generator: a sequence generator to execute, or None when
        ``trained_model_name`` is not supported (a message is printed).
    """
    if trained_model_name != 'melody_rnn_generator':
        print("Model {0} not supported.".format(trained_model_name))
        return None

    # Read the bundle only once the model name is known to be supported, so
    # an unsupported name does not trigger (and possibly fail on) file I/O.
    bundle = magenta.music.read_bundle_file(bundle_file)
    melody_rnn_config = melody_rnn_model.default_configs['basic_rnn']
    model = melody_rnn_model.MelodyRnnModel(melody_rnn_config)
    return melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        model=model,
        details=melody_rnn_config.details,
        steps_per_quarter=melody_rnn_config.steps_per_quarter,
        bundle=bundle)
def synthesize(midi_file, model='basic', num_steps=2000, max_primer_notes=32,
               temperature=1.0, beam_size=1, branch_factor=1,
               steps_per_quarter=16, **kwargs):
    """Continue a MIDI primer with a melody_rnn bundle and write both to disk.

    The primer is written to 'primer.mid' and the generated sequence to
    'synthesis.mid', both relative to the current working directory.

    Parameters
    ----------
    midi_file : object
        MIDI input handed to ``parse_midi_file``, which must return a
        ``(NoteSequence, qpm)`` pair (presumably a file path -- verify
        against that helper).
    model : str, optional
        Model prefix: selects the ``'<model>_rnn'`` default config and loads
        the ``'<model>_rnn.mag'`` bundle from the working directory.
    num_steps : int, optional
        End of the generated section, in steps counted from time 0 (so the
        primer's length is included).
    max_primer_notes : int, optional
        Currently unused.
    temperature : float, optional
        Sampling temperature passed to the generator.
    beam_size : int, optional
        Beam size for beam search.
    branch_factor : int, optional
        Branch factor for beam search.
    steps_per_quarter : int, optional
        Step resolution given to the generator and used to convert
        ``num_steps`` to seconds.
    **kwargs
        Forwarded to ``parse_midi_file``.

    Returns
    -------
    The generated NoteSequence (also written to 'synthesis.mid').
    """
    config_name = '{}_rnn'.format(model)
    config = melody_rnn_model.default_configs[config_name]
    bundle_file = '{}.mag'.format(config_name)
    generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        model=melody_rnn_model.MelodyRnnModel(config),
        details=config.details,
        steps_per_quarter=steps_per_quarter,
        bundle=magenta.music.read_bundle_file(bundle_file))

    # Get a protobuf of the MIDI file
    seq, qpm = parse_midi_file(midi_file, **kwargs)

    seconds_per_step = 60.0 / qpm / steps_per_quarter
    total_seconds = num_steps * seconds_per_step
    # Start one step after the last primer note ends. A primer without notes
    # would make max() raise ValueError, so fall back to time 0.
    last_end_time = (max(n.end_time for n in seq.notes)
                     if seq.notes else 0.0)

    opt = generator_pb2.GeneratorOptions()
    opt.generate_sections.add(start_time=last_end_time + seconds_per_step,
                              end_time=total_seconds)
    opt.args['temperature'].float_value = temperature
    opt.args['beam_size'].int_value = beam_size
    opt.args['branch_factor'].int_value = branch_factor
    opt.args['steps_per_iteration'].int_value = 1
    print(opt)

    generated = generator.generate(seq, opt)
    magenta.music.sequence_proto_to_midi_file(seq, 'primer.mid')
    magenta.music.sequence_proto_to_midi_file(generated, 'synthesis.mid')
    return generated
def _generate_melody(self):
    """Extend ``self.accumulated_primer_melody`` using ``self.melody_bundle``.

    Reads ``self.melody_bundle``, ``self.temperature``,
    ``self.accumulated_primer_melody`` and ``self.max_robot_length``.
    Sets ``self.generated_melody`` to the pitches of the newly generated
    notes (primer notes removed, at most ``max_robot_length`` of them) and
    clears ``self.accumulated_primer_melody``.
    """
    melody_config_id = self.melody_bundle.generator_details.id
    melody_config = melody_rnn_model.default_configs[melody_config_id]
    generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        model=melody_rnn_model.MelodyRnnModel(melody_config),
        details=melody_config.details,
        steps_per_quarter=melody_config.steps_per_quarter,
        checkpoint=melody_rnn_generate.get_checkpoint(),
        bundle=self.melody_bundle)

    generator_options = generator_pb2.GeneratorOptions()
    generator_options.args['temperature'].float_value = self.temperature
    generator_options.args['beam_size'].int_value = 1
    generator_options.args['branch_factor'].int_value = 1
    generator_options.args['steps_per_iteration'].int_value = 1

    primer_melody = magenta.music.Melody(self.accumulated_primer_melody)
    qpm = magenta.music.DEFAULT_QUARTERS_PER_MINUTE
    primer_sequence = primer_melody.to_sequence(qpm=qpm)
    seconds_per_step = 60.0 / qpm / generator.steps_per_quarter

    # Set the start time to begin on the next step after the last note ends.
    last_end_time = (max(n.end_time for n in primer_sequence.notes)
                     if primer_sequence.notes else 0)
    # The whole output spans three times the primer's length.
    # NOTE(review): an empty primer gives end_time 0, which is before
    # start_time -- confirm callers never reach here with no primer notes.
    melody_total_seconds = last_end_time * 3
    generator_options.generate_sections.add(
        start_time=last_end_time + seconds_per_step,
        end_time=melody_total_seconds)

    generated_sequence = generator.generate(primer_sequence, generator_options)
    pitches = [n.pitch for n in generated_sequence.notes]

    # Get rid of the primer melody. This assumes the output begins with the
    # primer's notes. Count primer *notes*, not primer events: Melody event
    # lists may contain non-note values such as -2 (e.g. "[60, -2, 60, -2]"),
    # which produce no note in the sequence.
    pitches = pitches[len(primer_sequence.notes):]
    # Make sure generated melody is not too long.
    self.generated_melody = pitches[:self.max_robot_length]
    self.accumulated_primer_melody = []
def main(unused_argv):
    """Package the checkpoint in ``<tgt.MODEL_DIR>/logdir/train`` as a bundle.

    Reads the model config from the first line of ``logdir/train/hparams``
    (``<config><TAB><hparams>``), writes ``<tgt.MODEL_DIR>/<name>.mag`` and
    copies the hparams file to ``<tgt.MODEL_DIR>/<name>.hparams``. ``<name>``
    is --bundle_file if given, otherwise ``magenta_<YYYYmmdd_HHMMSS>``.
    """
    tf.logging.set_verbosity(FLAGS.log)

    train_dir = os.path.join(tgt.MODEL_DIR, "logdir/train")
    hparams_path = os.path.join(train_dir, "hparams")
    with tf.gfile.Open(hparams_path) as f:
        # readline() keeps the line terminator; strip it so it does not end
        # up inside the last hparams value.
        config_name, hparams = f.readline().rstrip("\r\n").split("\t", 1)
    melody_rnn_config_flags.FLAGS.config = config_name
    melody_rnn_config_flags.FLAGS.hparams = hparams
    config = melody_rnn_config_flags.config_from_flags()

    generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        model=melody_rnn_model.MelodyRnnModel(config),
        details=config.details,
        steps_per_quarter=config.steps_per_quarter,
        checkpoint=train_dir)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    bundle_name = (FLAGS.bundle_file if FLAGS.bundle_file
                   else "magenta_" + timestamp)
    bundle_path = os.path.join(tgt.MODEL_DIR, bundle_name + ".mag")
    if FLAGS.bundle_description is None:
        tf.logging.warning('No bundle description provided.')
    tf.logging.info('Saving generator bundle to %s', bundle_path)
    generator.create_bundle_file(bundle_path, FLAGS.bundle_description)

    # Name the hparams copy from the base name rather than with
    # str.replace(".mag", ...) on the file name, which would also rewrite any
    # ".mag" occurring earlier in the name (e.g. "x.magic" -> "x.hparamsic").
    hparams_copy = os.path.join(tgt.MODEL_DIR, bundle_name + ".hparams")
    tf.gfile.Copy(hparams_path, hparams_copy, overwrite=True)
def main(unused_argv):
    """Either save a generator bundle or run generation, depending on flags.

    With a bundle available, the config is the default config named by the
    bundle's generator id, with --hparams applied on top; otherwise it is
    built from the config flags.
    """
    tf.logging.set_verbosity(FLAGS.log)

    bundle = get_bundle()
    if not bundle:
        config = melody_rnn_config_flags.config_from_flags()
    else:
        config = melody_rnn_model.default_configs[bundle.generator_details.id]
        config.hparams.parse(FLAGS.hparams)

    generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        model=melody_rnn_model.MelodyRnnModel(config),
        details=config.details,
        steps_per_quarter=config.steps_per_quarter,
        checkpoint=get_checkpoint(),
        bundle=bundle)

    if not FLAGS.save_generator_bundle:
        run_with_flags(generator)
        return

    target = os.path.expanduser(FLAGS.bundle_file)
    if FLAGS.bundle_description is None:
        tf.logging.warning('No bundle description provided.')
    tf.logging.info('Saving generator bundle to %s', target)
    generator.create_bundle_file(target, FLAGS.bundle_description)
def create_midi_test(self, midi_data):
    """Generate about 3 seconds of melody and save it under ./midifile/.

    Only the tempo of ``midi_data`` is used. Generation starts from an empty
    NoteSequence carrying the primer's qpm, or 120 when the primer has no
    usable tempo. The output file is named ``<YYYY-mm-dd_HHMMSS>_1.mid``.

    Args:
        midi_data: MIDI data accepted by magenta.music.midi_to_sequence_proto.
    """
    bundle_name = 'attention_rnn'
    config = magenta.models.melody_rnn.melody_rnn_model.default_configs[
        bundle_name]
    bundle_file = magenta.music.read_bundle_file(
        os.path.abspath(bundle_name + '.mag'))
    steps_per_quarter = 4
    generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        model=melody_rnn_model.MelodyRnnModel(config),
        details=config.details,
        steps_per_quarter=steps_per_quarter,
        bundle=bundle_file)

    qpm = 120
    primer_sequence = magenta.music.midi_to_sequence_proto(midi_data)
    if primer_sequence.tempos and primer_sequence.tempos[0].qpm:
        qpm = primer_sequence.tempos[0].qpm

    # NOTE(review): the primer's notes are not used, only its tempo --
    # confirm that generating from scratch is intended.
    input_sequence = music_pb2.NoteSequence()
    input_sequence.tempos.add().qpm = qpm

    total_seconds = 3
    generator_options = generator_pb2.GeneratorOptions()
    generator_options.generate_sections.add(start_time=0,
                                            end_time=total_seconds)

    date_and_time = time.strftime('%Y-%m-%d_%H%M%S')
    generated_sequence = generator.generate(input_sequence, generator_options)

    output_dir = "./midifile"
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    midi_filename = '%s_%s.mid' % (date_and_time, '1')
    midi_path = os.path.join(output_dir, midi_filename)
    magenta.music.sequence_proto_to_midi_file(generated_sequence, midi_path)
    # %d requires a number; passing the string "1" makes the logging module
    # report a formatting error instead of emitting this message.
    tf.logging.info('Wrote %d MIDI files to %s', 1, "midi folder")
def create_midi(self, midi_data):
    """Generate about 3 seconds of melody and return it as a temporary MIDI file.

    Only the tempo of ``midi_data`` is used. Generation starts from an empty
    NoteSequence carrying the primer's qpm, or 120 when the primer has no
    usable tempo.

    Args:
        midi_data: MIDI data accepted by magenta.music.midi_to_sequence_proto.

    Returns:
        A ``tempfile.NamedTemporaryFile`` holding the MIDI bytes, rewound to
        offset 0. The file is deleted when the caller closes it, so keep it
        open while it is in use.
    """
    bundle_name = 'attention_rnn'
    config = magenta.models.melody_rnn.melody_rnn_model.default_configs[
        bundle_name]
    bundle_file = magenta.music.read_bundle_file(
        os.path.abspath(bundle_name + '.mag'))
    steps_per_quarter = 4
    generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        model=melody_rnn_model.MelodyRnnModel(config),
        details=config.details,
        steps_per_quarter=steps_per_quarter,
        bundle=bundle_file)

    qpm = 120
    primer_sequence = magenta.music.midi_to_sequence_proto(midi_data)
    if primer_sequence.tempos and primer_sequence.tempos[0].qpm:
        qpm = primer_sequence.tempos[0].qpm

    # NOTE(review): the primer's notes are not used, only its tempo --
    # confirm that generating from scratch is intended.
    input_sequence = music_pb2.NoteSequence()
    input_sequence.tempos.add().qpm = qpm

    total_seconds = 3
    generator_options = generator_pb2.GeneratorOptions()
    generator_options.generate_sections.add(start_time=0,
                                            end_time=total_seconds)

    generated_sequence = generator.generate(input_sequence, generator_options)

    # NOTE(review): writing through output.name while the handle is still
    # open works on Unix but not on Windows (see tempfile docs).
    output = tempfile.NamedTemporaryFile()
    magenta.music.midi_io.sequence_proto_to_midi_file(
        generated_sequence, output.name)
    output.seek(0)
    return output
def main(unused_argv):
    """Build a generator from flags, then save it as a bundle or run it.

    Uses the older positional MelodyRnnSequenceGenerator signature:
    (config, steps_per_quarter, checkpoint, bundle).
    """
    config = melody_rnn_config.config_from_flags()
    generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        config, FLAGS.steps_per_quarter, get_checkpoint(), get_bundle())

    if not FLAGS.save_generator_bundle:
        run_with_flags(generator)
        return

    target = os.path.expanduser(FLAGS.bundle_file)
    tf.logging.info('Saving generator bundle to %s', target)
    generator.create_bundle_file(target)
def load_1(self, bundle_name):
    """Load a melody_rnn bundle into ``self.generator`` and report on outlet 1.

    The bundle is read from ``<script_dir>/<bundle_name>.mag``, where
    ``script_dir`` is a module-level name defined outside this method. The
    default config with the same name is used, and steps_per_quarter is
    fixed at 4.
    """
    name = str(bundle_name)
    melody_config = magenta.models.melody_rnn.melody_rnn_model.default_configs[name]
    bundle = magenta.music.read_bundle_file(
        os.path.join(script_dir, name + '.mag'))

    model = melody_rnn_model.MelodyRnnModel(melody_config)
    self.generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        model=model,
        details=melody_config.details,
        steps_per_quarter=4,
        bundle=bundle)
    self._outlet(1, "loaded")
def __init__(self, bundle_path: str):
    """Load a MelodyRnnSequenceGenerator from a bundle file.

    Args:
        bundle_path (str): Path to the generator bundle; ``~`` is expanded.
            The bundle's generator id selects the melody_rnn default config.
    """
    bundle = sequence_generator_bundle.read_bundle_file(
        os.path.expanduser(bundle_path))
    config = melody_rnn_model.default_configs[bundle.generator_details.id]

    model = melody_rnn_model.MelodyRnnModel(config)
    self.generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        model=model,
        details=config.details,
        steps_per_quarter=config.steps_per_quarter,
        checkpoint=None,
        bundle=bundle)
def test_create_midi_default(self):
    """Continue '180827_02_midi.mid' with attention_rnn and write 'new2.mid'.

    Generation starts one step after the primer's last note and ends at
    4.0 seconds. The primer's tempo is used when present, otherwise 120 qpm.
    """
    bundle_name = 'attention_rnn'
    config = magenta.models.melody_rnn.melody_rnn_model.default_configs[
        bundle_name]
    bundle_file = magenta.music.read_bundle_file(
        os.path.abspath(bundle_name + '.mag'))
    steps_per_quarter = 4
    generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        model=melody_rnn_model.MelodyRnnModel(config),
        details=config.details,
        steps_per_quarter=steps_per_quarter,
        bundle=bundle_file)

    primer_sequence = magenta.music.midi_file_to_sequence_proto(
        '180827_02_midi.mid')
    qpm = 120
    if primer_sequence.tempos and primer_sequence.tempos[0].qpm:
        qpm = primer_sequence.tempos[0].qpm

    # Set the start time to begin on the next step after the last note ends.
    last_end_time = (max(n.end_time for n in primer_sequence.notes)
                     if primer_sequence.notes else 0)
    total_seconds = 4.0
    generator_options = generator_pb2.GeneratorOptions()
    generator_options.generate_sections.add(
        start_time=last_end_time + _steps_to_seconds(1, qpm),
        end_time=total_seconds)

    generated_sequence = generator.generate(primer_sequence, generator_options)
    magenta.music.sequence_proto_to_midi_file(generated_sequence, 'new2.mid')
    # %d requires a number; a str argument makes logging report a
    # formatting error instead of emitting the message.
    tf.logging.info('Wrote %d MIDI files to %s', 1, "midi folder")
def __init__(self, config, checkpoint=None, bundle_filename=None,
             steps_per_quarter=4):
    """Wrap a MelodyRnnSequenceGenerator built from a config.

    Args:
        config: A MelodyRnnConfig containing the MelodyEncoderDecoder and
            HParams to use.
        checkpoint: Where to search for the most recent model checkpoint.
        bundle_filename: Optional path of a generator_pb2.GeneratorBundle
            file containing both the model checkpoint and metagraph. When
            None, no bundle is loaded.
        steps_per_quarter: Quantization precision for the melody, in steps
            per quarter note.
    """
    bundle = (None if bundle_filename is None
              else magenta.music.read_bundle_file(bundle_filename))
    self._generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        config, steps_per_quarter, checkpoint, bundle)
def make_generator(bundle_name):
    """Build a MelodyRnnSequenceGenerator for ``../models/<bundle_name>.mag``.

    Paths are relative to this module's directory. The model config comes
    from ``../models/<bundle_name>.hparams`` when that file exists (first
    line: ``<config><TAB><hparams>``). Otherwise the melody_rnn default
    config of the same name is used. steps_per_quarter is the module-level
    STEPS_PER_QUARTER.

    Raises:
        ValueError: if neither source provides a config. ValueError is a
            subclass of Exception, so existing ``except Exception`` handlers
            still catch it.
    """
    models_dir = os.path.join(os.path.dirname(__file__), "..", "models")
    model_path = os.path.join(models_dir, bundle_name + ".mag")
    hparams_path = os.path.join(models_dir, bundle_name + ".hparams")

    if tf.gfile.Exists(hparams_path):
        with tf.gfile.Open(hparams_path) as f:
            # readline() keeps the line terminator; strip it so it does not
            # end up inside the last hparams value.
            config_name, hparams = f.readline().rstrip("\r\n").split("\t", 1)
        melody_rnn_config_flags.FLAGS.config = config_name
        melody_rnn_config_flags.FLAGS.hparams = hparams
        config = melody_rnn_config_flags.config_from_flags()
    elif bundle_name in melody_rnn_model.default_configs:
        config = melody_rnn_model.default_configs[bundle_name]
    else:
        raise ValueError("can not define the model config.")

    return melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        model=melody_rnn_model.MelodyRnnModel(config),
        details=config.details,
        steps_per_quarter=STEPS_PER_QUARTER,
        bundle=magenta.music.read_bundle_file(model_path))
from magenta.models.melody_rnn import melody_rnn_model from magenta.models.melody_rnn import melody_rnn_sequence_generator from magenta.protobuf import generator_pb2 from magenta.protobuf import music_pb2 BUNDLE_NAME = 'attention_rnn' config = magenta.models.melody_rnn.melody_rnn_model.default_configs[ BUNDLE_NAME] bundle_file = magenta.music.read_bundle_file( os.path.abspath(BUNDLE_NAME + '.mag')) steps_per_quarter = 4 generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator( model=melody_rnn_model.MelodyRnnModel(config), details=config.details, steps_per_quarter=steps_per_quarter, bundle=bundle_file) def _steps_to_seconds(steps, qpm): return steps * 60.0 / qpm / steps_per_quarter def generate_midi(midi_data, total_seconds=10): primer_sequence = magenta.music.midi_io.midi_to_sequence_proto(midi_data) # predict the tempo if len(primer_sequence.notes) > 4: estimated_tempo = midi_data.estimate_tempo() if estimated_tempo > 240:
def call_melody_rnn(primer_melody):
    """Generate a continuation of ``primer_melody`` with a melody_rnn model.

    All currently registered tf.app.flags are deleted and the melody_rnn
    generation flags are re-defined, with ``primer_melody`` as the default of
    the 'primer_melody' flag. The model comes from the bundle or checkpoint
    resolved by the module-level ``get_bundle()`` / ``get_checkpoint()``.

    Args:
        primer_melody: A string representation of a Python list of
            magenta.music.Melody event values, e.g. "[60, -2, 60, -2]".

    Returns:
        The generated NoteSequence, or None when the priming sequence is
        longer than the requested generation length (a fatal message is
        logged in that case).
    """
    # Remove every registered flag, presumably so the DEFINE_* calls below do
    # not clash with flags defined by an earlier call.
    for flag_name in list(tf.app.flags.FLAGS._flags()):
        tf.app.flags.FLAGS.__delattr__(flag_name)

    FLAGS = tf.app.flags.FLAGS
    tf.app.flags.DEFINE_string(
        'run_dir', None,
        'Path to the directory where the latest checkpoint will be loaded from.'
    )
    tf.app.flags.DEFINE_string(
        'checkpoint_file', None,
        'Path to the checkpoint file. run_dir will take priority over this flag.'
    )
    tf.app.flags.DEFINE_string(
        'bundle_file', "/Users/yuhaomao/Downloads/lookback_rnn.mag",
        'Path to the bundle file. If specified, this will take priority over '
        'run_dir and checkpoint_file, unless save_generator_bundle is True, in '
        'which case both this flag and either run_dir or checkpoint_file are '
        'required')
    tf.app.flags.DEFINE_boolean(
        'save_generator_bundle', False,
        'If true, instead of generating a sequence, will save this generator as a '
        'bundle file in the location specified by the bundle_file flag')
    tf.app.flags.DEFINE_string(
        'bundle_description', None,
        'A short, human-readable text description of the bundle (e.g., training '
        'data, hyper parameters, etc.).')
    tf.app.flags.DEFINE_string(
        'output_dir', '/tmp/melody_rnn/generated',
        'The directory where MIDI files will be saved to.')
    tf.app.flags.DEFINE_integer(
        'num_outputs', 1,
        'The number of melodies to generate. One MIDI file will be created for '
        'each.')
    tf.app.flags.DEFINE_integer(
        'num_steps', 16,
        'The total number of steps the generated melodies should be, priming '
        'melody length + generated steps. Each step is a 16th of a bar.')
    tf.app.flags.DEFINE_string(
        'primer_melody', primer_melody,
        'A string representation of a Python list of '
        'magenta.music.Melody event values. For example: '
        '"[60, -2, 60, -2, 67, -2, 67, -2]". If specified, this melody will be '
        'used as the priming melody. If a priming melody is not specified, '
        'melodies will be generated from scratch.')
    tf.app.flags.DEFINE_string(
        'primer_midi', '',
        'The path to a MIDI file containing a melody that will be used as a '
        'priming melody. If a primer melody is not specified, melodies will be '
        'generated from scratch.')
    tf.app.flags.DEFINE_float(
        'qpm', 60,
        'The quarters per minute to play generated output at. If a primer MIDI is '
        'given, the qpm from that will override this flag. If qpm is None, qpm '
        'will default to 120.')
    tf.app.flags.DEFINE_float(
        'temperature', 1.0,
        'The randomness of the generated melodies. 1.0 uses the unaltered softmax '
        'probabilities, greater than 1.0 makes melodies more random, less than 1.0 '
        'makes melodies less random.')
    tf.app.flags.DEFINE_integer(
        'beam_size', 1,
        'The beam size to use for beam search when generating melodies.')
    tf.app.flags.DEFINE_integer(
        'branch_factor', 1,
        'The branch factor to use for beam search when generating melodies.')
    tf.app.flags.DEFINE_integer(
        'steps_per_iteration', 1,
        'The number of melody steps to take per beam search iteration.')
    tf.app.flags.DEFINE_string(
        'log', 'INFO',
        'The threshold for what messages will be logged DEBUG, INFO, WARN, ERROR, '
        'or FATAL.')
    tf.app.flags.DEFINE_string(
        'hparams', "",
        'Hyperparameter overrides, '
        'represented as a string containing comma-separated '
        'hparam_name=value pairs.')

    tf.logging.set_verbosity(FLAGS.log)

    bundle = get_bundle()
    if bundle:
        config_id = bundle.generator_details.id
        config = melody_rnn_model.default_configs[config_id]
        config.hparams.parse(FLAGS.hparams)
    else:
        config = melody_rnn_config_flags.config_from_flags()

    generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator(
        model=melody_rnn_model.MelodyRnnModel(config),
        details=config.details,
        steps_per_quarter=config.steps_per_quarter,
        checkpoint=get_checkpoint(),
        bundle=bundle)

    primer_midi = None
    if FLAGS.primer_midi:
        primer_midi = os.path.expanduser(FLAGS.primer_midi)

    primer_sequence = None
    qpm = FLAGS.qpm if FLAGS.qpm else magenta.music.DEFAULT_QUARTERS_PER_MINUTE
    if FLAGS.primer_melody:
        melody = magenta.music.Melody(ast.literal_eval(FLAGS.primer_melody))
        primer_sequence = melody.to_sequence(qpm=qpm)
    elif primer_midi:
        primer_sequence = magenta.music.midi_file_to_sequence_proto(primer_midi)
        if primer_sequence.tempos and primer_sequence.tempos[0].qpm:
            qpm = primer_sequence.tempos[0].qpm
    else:
        tf.logging.warning(
            'No priming sequence specified. Defaulting to a single middle C.')
        melody = magenta.music.Melody([60])
        primer_sequence = melody.to_sequence(qpm=qpm)

    seconds_per_step = 60.0 / qpm / generator.steps_per_quarter
    total_seconds = FLAGS.num_steps * seconds_per_step

    generator_options = generator_pb2.GeneratorOptions()
    if primer_sequence:
        input_sequence = primer_sequence
        # Set the start time to begin on the next step after the last note ends.
        if primer_sequence.notes:
            last_end_time = max(n.end_time for n in primer_sequence.notes)
        else:
            last_end_time = 0
        generate_section = generator_options.generate_sections.add(
            start_time=last_end_time + seconds_per_step,
            end_time=total_seconds)

        if generate_section.start_time >= generate_section.end_time:
            tf.logging.fatal(
                'Priming sequence is longer than the total number of steps '
                'requested: Priming sequence length: %s, Generation length '
                'requested: %s',
                generate_section.start_time, total_seconds)
            return None
    else:
        input_sequence = music_pb2.NoteSequence()
        input_sequence.tempos.add().qpm = qpm
        generator_options.generate_sections.add(
            start_time=0,
            end_time=total_seconds)

    generator_options.args['temperature'].float_value = FLAGS.temperature
    generator_options.args['beam_size'].int_value = FLAGS.beam_size
    generator_options.args['branch_factor'].int_value = FLAGS.branch_factor
    generator_options.args[
        'steps_per_iteration'].int_value = FLAGS.steps_per_iteration
    tf.logging.debug('input_sequence: %s', input_sequence)
    tf.logging.debug('generator_options: %s', generator_options)

    # Reuse the generator built above and call the public generate(). That
    # is the entry point used everywhere else in this file; the private
    # _generate() presumably skips the lazy initialization (loading the
    # bundle/checkpoint) that generate() performs. Return the result instead
    # of discarding it.
    return generator.generate(input_sequence, generator_options)
"""Get the training dir or checkpoint path to be used by the model.""" if ((FLAGS.run_dir or FLAGS.checkpoint_file) and FLAGS.bundle_file and not FLAGS.save_generator_bundle): raise magenta.music.SequenceGeneratorError( 'Cannot specify both bundle_file and run_dir or checkpoint_file') if FLAGS.run_dir: train_dir = os.path.join(os.path.expanduser(FLAGS.run_dir), 'train') return train_dir elif FLAGS.checkpoint_file: return os.path.expanduser(FLAGS.checkpoint_file) else: return None tf.logging.set_verbosity(FLAGS.log) bundle = get_bundle() if bundle: config_id = bundle.generator_details.id config = melody_rnn_model.default_configs[config_id] config.hparams.parse(FLAGS.hparams) generator = melody_rnn_sequence_generator.MelodyRnnSequenceGenerator( model=melody_rnn_model.MelodyRnnModel(config), details=config.details, steps_per_quarter=config.steps_per_quarter, checkpoint=get_checkpoint(), bundle=bundle) print("11111") print(generator) print(type(generator))