예제 #1
0
class VAEDecoder():
    """Decode MusicVAE latent vectors with a trained model and save MIDI files."""

    def __init__(self,
                 model_path=None,
                 config_name=None,
                 output_dir=None,
                 max_batch_size=8,
                 config_map=configs.CONFIG_MAP):
        """Load the trained model selected by ``config_name``.

        Args:
            model_path: Checkpoint directory or file; ``~`` is expanded.
                Required.
            config_name: Key into ``config_map`` selecting the model config.
            output_dir: Directory ``_write`` saves MIDI files into; ``~`` is
                expanded. May be None if ``_write`` is never used.
            max_batch_size: Batch size passed to ``TrainedModel``.
            config_map: Mapping from config names to config objects.

        Raises:
            ValueError: If ``config_name`` is not in ``config_map`` or
                ``model_path`` is None.
        """
        if config_name not in config_map:
            raise ValueError('Invalid config name: %s' % config_name)
        if model_path is None:
            raise ValueError('`model_path` is required.')
        self.config = config_map[config_name]
        # NOTE: this mutates the shared config object in `config_map`.
        self.config.data_converter.max_tensors_per_item = None
        # Only `_write` needs an output directory, so it may be left unset.
        self.output_dir = (os.path.expanduser(output_dir)
                           if output_dir is not None else None)

        checkpoint_path = os.path.expanduser(model_path)

        self.model = TrainedModel(self.config,
                                  batch_size=max_batch_size,
                                  checkpoint_dir_or_path=checkpoint_path)

        # Latent dimensionality, used to validate the input of `decode`.
        self.z_size = self.model._config.hparams.z_size

    def decode(self, z: np.ndarray):
        """Decode a batch of latent vectors.

        Args:
            z: Array-like of shape ``(n_samples, z_size)``.

        Returns:
            The result of ``TrainedModel.decode`` for ``z`` with the config's
            ``max_seq_len`` as length.

        Raises:
            ValueError: If ``z`` is not 2-D with ``z_size`` columns.
        """
        z = np.asarray(z)
        if z.ndim != 2 or z.shape[1] != self.z_size:
            raise ValueError('z has inappropriate shape: {0}.\n'
                             'Should be: (n_samples, {1})'.format(
                                 z.shape, self.z_size))
        return self._decode(z)

    def _decode(self, z):
        """Run the model's decoder on ``z`` for ``max_seq_len`` steps."""
        return self.model.decode(z, length=self.config.hparams.max_seq_len)

    def _write(self, midi):
        """Write each NoteSequence in ``midi`` to ``output_dir`` as MIDI.

        Files are named ``<timestamp>-<index>-<count>.mid`` with the index and
        count zero-padded to three digits.

        Raises:
            ValueError: If no ``output_dir`` was given at construction.
        """
        if self.output_dir is None:
            raise ValueError('`output_dir` must be set to write MIDI files.')
        date_and_time = time.strftime('%Y-%m-%d_%H%M%S')
        count = len(midi)
        for i, m in enumerate(midi):
            # Build the file name directly instead of substituting into the
            # joined path, so a `*` inside output_dir is left untouched.
            filename = '{}-{:03d}-{:03d}.mid'.format(date_and_time, i, count)
            mm.sequence_proto_to_midi_file(
                m, os.path.join(self.output_dir, filename))
예제 #2
0
def run(config_map):
    """Load a MusicVAE model and write generated MIDI files to `--output_dir`.

    The behavior depends on `--mode`:
      * `sample`: sample `--num_outputs` sequences with `model.sample`.
      * `interpolate`: encode `--input_midi_1` and `--input_midi_2` and decode
        `--num_outputs` points spherically interpolated (`_slerp`) between
        their means.
      * `isample`: encode `--input_image` with `model.vae` and `--input_midi_1`
        with the music encoder, add the two means and the two scales, and
        decode a sample from the resulting diagonal Gaussian.

    Args:
      config_map: Dictionary mapping configuration name to Config object.

    Raises:
      ValueError: if required flags are missing or invalid, or an input file
        cannot be found or read.
    """
    date_and_time = time.strftime('%Y-%m-%d_%H%M%S')

    # The parentheses matter: unparenthesized, `a is None == b is None` is a
    # chained comparison that is true only when *both* flags are missing, so
    # passing both flags would slip through.
    if (FLAGS.run_dir is None) == (FLAGS.checkpoint_file is None):
        raise ValueError(
            'Exactly one of `--run_dir` or `--checkpoint_file` must be specified.'
        )
    if FLAGS.output_dir is None:
        raise ValueError('`--output_dir` is required.')
    tf.gfile.MakeDirs(FLAGS.output_dir)
    if FLAGS.mode not in ('sample', 'interpolate', 'isample'):
        raise ValueError('Invalid value for `--mode`: %s' % FLAGS.mode)

    if FLAGS.config not in config_map:
        raise ValueError('Invalid config name: %s' % FLAGS.config)
    config = config_map[FLAGS.config]
    config.data_converter.max_tensors_per_item = None

    def _load_midi(flag_value, input_number):
        """Parse the MIDI file named by a flag; ValueError if it is missing."""
        path = os.path.expanduser(flag_value)
        if not os.path.exists(path):
            raise ValueError('Input MIDI %d not found: %s' %
                             (input_number, flag_value))
        return mm.midi_file_to_note_sequence(path)

    def _check_extract_examples(input_ns, path, input_number):
        """Make sure each input returns exactly one example from the converter.

        Exits the program if none can be extracted; if several can, writes
        them all to `--output_dir` first so the user can pick one.
        """
        tensors = config.data_converter.to_tensors(input_ns).outputs
        if not tensors:
            print(
                'MusicVAE configs have very specific input requirements. Could not '
                'extract any valid inputs from `%s`. Try another MIDI file.'
                % path)
            sys.exit()
        elif len(tensors) > 1:
            pattern = ('%s_input%d-extractions_%s-*-of-%03d.mid' %
                       (FLAGS.config, input_number, date_and_time, len(tensors)))
            basename = os.path.join(FLAGS.output_dir, pattern)
            for i, ns in enumerate(
                    config.data_converter.to_notesequences(tensors)):
                # Substitute the index into the file name only, so a `*` in
                # the directory path is never rewritten.
                mm.sequence_proto_to_midi_file(
                    ns, os.path.join(FLAGS.output_dir,
                                     pattern.replace('*', '%03d' % i)))
            print(
                '%d valid inputs extracted from `%s`. Outputting these potential '
                'inputs as `%s`. Call script again with one of these instead.'
                % (len(tensors), path, basename))
            sys.exit()

    if FLAGS.mode == 'interpolate':
        if FLAGS.input_midi_1 is None or FLAGS.input_midi_2 is None:
            raise ValueError(
                '`--input_midi_1` and `--input_midi_2` must be specified in '
                '`interpolate` mode.')
        input_1 = _load_midi(FLAGS.input_midi_1, 1)
        input_2 = _load_midi(FLAGS.input_midi_2, 2)
        logging.info(
            'Attempting to extract examples from input MIDIs using config `%s`...',
            FLAGS.config)
        _check_extract_examples(input_1, FLAGS.input_midi_1, 1)
        _check_extract_examples(input_2, FLAGS.input_midi_2, 2)
    elif FLAGS.mode == 'isample':
        # Validate with real exceptions (asserts vanish under `python -O`) and
        # do it before loading the model.
        if FLAGS.input_image is None:
            raise ValueError('Provide an image to sample from')
        if FLAGS.input_midi_1 is None:
            raise ValueError('Provide a music to sample from')
        image = cv2.imread(os.path.expanduser(FLAGS.input_image))
        if image is None:  # cv2.imread returns None instead of raising.
            raise ValueError('Could not read input image: %s' %
                             FLAGS.input_image)
        # Scale pixel values by 1/255, resize (cv2 dsize is (width, height))
        # and add a leading batch axis.
        image = np.asarray(cv2.resize(image / 255., (320, 240)))
        image = np.expand_dims(image, axis=0)
        logging.info('Input image batch shape: %s', image.shape)
        input_1 = _load_midi(FLAGS.input_midi_1, 1)
        logging.info(
            'Attempting to extract examples from input MIDIs using config `%s`...',
            FLAGS.config)
        _check_extract_examples(input_1, FLAGS.input_midi_1, 1)

    logging.info('Loading model...')
    if FLAGS.run_dir:
        checkpoint_dir_or_path = os.path.expanduser(
            os.path.join(FLAGS.run_dir, 'train'))
    else:
        checkpoint_dir_or_path = os.path.expanduser(FLAGS.checkpoint_file)
    model = TrainedModel(config,
                         batch_size=min(FLAGS.max_batch_size,
                                        FLAGS.num_outputs),
                         checkpoint_dir_or_path=checkpoint_dir_or_path)

    if FLAGS.mode == 'interpolate':
        logging.info('Interpolating...')
        _, mu, _ = model.encode([input_1, input_2])
        z = np.array([
            _slerp(mu[0], mu[1], t)
            for t in np.linspace(0, 1, FLAGS.num_outputs)
        ])
        results = model.decode(length=config.hparams.max_seq_len,
                               z=z,
                               temperature=FLAGS.temperature)
    elif FLAGS.mode == 'sample':
        logging.info('Sampling...')
        results = model.sample(n=FLAGS.num_outputs,
                               length=config.hparams.max_seq_len,
                               temperature=FLAGS.temperature)
    else:  # 'isample'
        logging.info('Sampling z from image vae...')
        # Entering the session makes it (and its graph) the default, which
        # the `.eval()` calls below rely on; leaving the block closes it.
        with model._sess:
            _, mu_music, sigma_music = model.encode([input_1])
            dataset = tf.data.Dataset.from_tensors(image.astype(np.float32))
            image_tensor = dataset.repeat().make_one_shot_iterator().get_next()
            mu, sigma = model.vae.encode(image_tensor, config.hparams)
            mu = mu.eval()
            sigma = sigma.eval()
            latent = ds.MultivariateNormalDiag(
                loc=mu + mu_music,
                scale_diag=sigma + sigma_music).sample().eval()
            results = model.decode(length=config.hparams.max_seq_len,
                                   z=latent,
                                   temperature=FLAGS.temperature)
        logging.info('Decoded results: %s', results)

    # Count what was actually produced; in `isample` mode this need not equal
    # `--num_outputs`.
    results = list(results)
    num_results = len(results)
    pattern = ('%s_%s_%s-*-of-%03d.mid' %
               (FLAGS.config, FLAGS.mode, date_and_time, num_results))
    basename = os.path.join(FLAGS.output_dir, pattern)
    logging.info('Outputting %d files as `%s`...', num_results, basename)
    for i, ns in enumerate(results):
        mm.sequence_proto_to_midi_file(
            ns, os.path.join(FLAGS.output_dir, pattern.replace('*', '%03d' % i)))

    logging.info('Done.')
def run(config_map):
  """Load a MusicVAE model and write sampled or interpolated MIDI files.

  In `sample` mode, `--num_outputs` sequences are sampled from the model. In
  `interpolate` mode, `--input_midi_1` and `--input_midi_2` are encoded and
  `--num_outputs` sequences are decoded from points spherically interpolated
  (`_slerp`) between their means. Files are written to `--output_dir`.

  Args:
    config_map: Dictionary mapping configuration name to Config object.

  Raises:
    ValueError: if required flags are missing or invalid.
  """
  date_and_time = time.strftime('%Y-%m-%d_%H%M%S')

  # The parentheses matter: unparenthesized, `a is None == b is None` is a
  # chained comparison that is true only when *both* flags are missing, so
  # passing both flags would slip through.
  if (FLAGS.run_dir is None) == (FLAGS.checkpoint_file is None):
    raise ValueError(
        'Exactly one of `--run_dir` or `--checkpoint_file` must be specified.')
  if FLAGS.output_dir is None:
    raise ValueError('`--output_dir` is required.')
  tf.gfile.MakeDirs(FLAGS.output_dir)
  if FLAGS.mode not in ('sample', 'interpolate'):
    raise ValueError('Invalid value for `--mode`: %s' % FLAGS.mode)

  if FLAGS.config not in config_map:
    raise ValueError('Invalid config name: %s' % FLAGS.config)
  config = config_map[FLAGS.config]
  config.data_converter.max_tensors_per_item = None

  if FLAGS.mode == 'interpolate':
    if FLAGS.input_midi_1 is None or FLAGS.input_midi_2 is None:
      raise ValueError(
          '`--input_midi_1` and `--input_midi_2` must be specified in '
          '`interpolate` mode.')
    input_midi_1 = os.path.expanduser(FLAGS.input_midi_1)
    input_midi_2 = os.path.expanduser(FLAGS.input_midi_2)
    if not os.path.exists(input_midi_1):
      raise ValueError('Input MIDI 1 not found: %s' % FLAGS.input_midi_1)
    if not os.path.exists(input_midi_2):
      raise ValueError('Input MIDI 2 not found: %s' % FLAGS.input_midi_2)
    input_1 = note_seq.midi_file_to_note_sequence(input_midi_1)
    input_2 = note_seq.midi_file_to_note_sequence(input_midi_2)

    def _check_extract_examples(input_ns, path, input_number):
      """Make sure each input returns exactly one example from the converter.

      Exits the program if none can be extracted; if several can, writes them
      all to `--output_dir` first so the user can pick one.
      """
      tensors = config.data_converter.to_tensors(input_ns).outputs
      if not tensors:
        print(
            'MusicVAE configs have very specific input requirements. Could not '
            'extract any valid inputs from `%s`. Try another MIDI file.' % path)
        sys.exit()
      elif len(tensors) > 1:
        pattern = ('%s_input%d-extractions_%s-*-of-%03d.mid' %
                   (FLAGS.config, input_number, date_and_time, len(tensors)))
        basename = os.path.join(FLAGS.output_dir, pattern)
        for i, ns in enumerate(config.data_converter.from_tensors(tensors)):
          # Substitute the index into the file name only, so a `*` in the
          # directory path is never rewritten.
          note_seq.sequence_proto_to_midi_file(
              ns, os.path.join(FLAGS.output_dir,
                               pattern.replace('*', '%03d' % i)))
        print(
            '%d valid inputs extracted from `%s`. Outputting these potential '
            'inputs as `%s`. Call script again with one of these instead.' %
            (len(tensors), path, basename))
        sys.exit()
    logging.info(
        'Attempting to extract examples from input MIDIs using config `%s`...',
        FLAGS.config)
    _check_extract_examples(input_1, FLAGS.input_midi_1, 1)
    _check_extract_examples(input_2, FLAGS.input_midi_2, 2)

  logging.info('Loading model...')
  if FLAGS.run_dir:
    checkpoint_dir_or_path = os.path.expanduser(
        os.path.join(FLAGS.run_dir, 'train'))
  else:
    checkpoint_dir_or_path = os.path.expanduser(FLAGS.checkpoint_file)
  model = TrainedModel(
      config, batch_size=min(FLAGS.max_batch_size, FLAGS.num_outputs),
      checkpoint_dir_or_path=checkpoint_dir_or_path)

  if FLAGS.mode == 'interpolate':
    logging.info('Interpolating...')
    _, mu, _ = model.encode([input_1, input_2])
    z = np.array([
        _slerp(mu[0], mu[1], t) for t in np.linspace(0, 1, FLAGS.num_outputs)])
    results = model.decode(
        length=config.hparams.max_seq_len,
        z=z,
        temperature=FLAGS.temperature)
  else:  # 'sample'
    logging.info('Sampling...')
    results = model.sample(
        n=FLAGS.num_outputs,
        length=config.hparams.max_seq_len,
        temperature=FLAGS.temperature)

  pattern = ('%s_%s_%s-*-of-%03d.mid' %
             (FLAGS.config, FLAGS.mode, date_and_time, FLAGS.num_outputs))
  basename = os.path.join(FLAGS.output_dir, pattern)
  logging.info('Outputting %d files as `%s`...', FLAGS.num_outputs, basename)
  for i, ns in enumerate(results):
    note_seq.sequence_proto_to_midi_file(
        ns, os.path.join(FLAGS.output_dir, pattern.replace('*', '%03d' % i)))

  logging.info('Done.')
예제 #4
0
# Parse the fourth input MIDI. `input_midi_4`, the other inputs, `model`,
# `config`, `arguments` and `temperature` are defined earlier in the script,
# outside this excerpt.
input_4 = mm.midi_file_to_note_sequence(input_midi_4)
# Input-extraction checks, currently disabled:
#_check_extract_examples(input_1, path_midi_1, 1)
#_check_extract_examples(input_2, path_midi_2, 2)

# Encode all four inputs; only the sampled latents `z` are used below.
#_, mu, _ = model.encode([input_1, input_2])
z, mu, _ = model.encode([input_1, input_2, input_3, input_4])

# Get the new 'z' with the interpolation values: a weighted sum of the four
# latents. `arguments` looks like a docopt-style dict of CLI values -- TODO
# confirm.
z_new = z[0] * float(arguments['<val1>']) + z[1] * float(
    arguments['<val2>']) + z[2] * float(arguments['<val3>']) + z[3] * float(
        arguments['<val4>'])
# Second latent: the mix shifted by 0.05 in every dimension, presumably to
# decode a slightly varied continuation -- TODO confirm intent.
z_new_2 = z_new + 0.05
# Add a leading batch axis of size 1 before decoding.
z_new = np.expand_dims(z_new, axis=0)
z_new_2 = np.expand_dims(z_new_2, axis=0)
# Decode both latents and join the first decoded sequence of each.
result = model.decode(length=config.hparams.max_seq_len,
                      z=z_new,
                      temperature=temperature)
result_2 = model.decode(length=config.hparams.max_seq_len,
                        z=z_new_2,
                        temperature=temperature)
seq_result = mm.sequences_lib.concatenate_sequences([result[0], result_2[0]])

print(seq_result)

# OPTIONAL - comment this out if you don't want to output a MIDI file.
# NOTE(review): the output path is absolute and machine-specific.
mm.sequence_proto_to_midi_file(
    seq_result,
    '/Users/prang/code/gitlab-acid/acids-live/mathieu_MusicVAE/MVAE_output/test.mid'
)

# Accumulator for a text rendering of seq_result; presumably filled in by
# code that is not part of this excerpt.
str_out = ''
예제 #5
0
def run(config_map):
  """Load a MusicVAE model and write sampled or interpolated MIDI files.

  In `sample` mode, `--num_outputs` sequences are sampled from the model. In
  `interpolate` mode, `--input_midi_1` and `--input_midi_2` are encoded and
  `--num_outputs` sequences are decoded from points spherically interpolated
  (`_slerp`) between their means. Files are written to `--output_dir`.

  Args:
    config_map: Dictionary mapping configuration name to Config object.

  Raises:
    ValueError: if required flags are missing or invalid.
  """
  date_and_time = time.strftime('%Y-%m-%d_%H%M%S')

  # The parentheses matter: unparenthesized, `a is None == b is None` is a
  # chained comparison that is true only when *both* flags are missing, so
  # passing both flags would slip through.
  if (FLAGS.run_dir is None) == (FLAGS.checkpoint_file is None):
    raise ValueError(
        'Exactly one of `--run_dir` or `--checkpoint_file` must be specified.')
  if FLAGS.output_dir is None:
    raise ValueError('`--output_dir` is required.')
  tf.gfile.MakeDirs(FLAGS.output_dir)
  if FLAGS.mode not in ('sample', 'interpolate'):
    raise ValueError('Invalid value for `--mode`: %s' % FLAGS.mode)

  if FLAGS.config not in config_map:
    raise ValueError('Invalid config name: %s' % FLAGS.config)
  config = config_map[FLAGS.config]
  config.data_converter.max_tensors_per_item = None

  if FLAGS.mode == 'interpolate':
    if FLAGS.input_midi_1 is None or FLAGS.input_midi_2 is None:
      raise ValueError(
          '`--input_midi_1` and `--input_midi_2` must be specified in '
          '`interpolate` mode.')
    input_midi_1 = os.path.expanduser(FLAGS.input_midi_1)
    input_midi_2 = os.path.expanduser(FLAGS.input_midi_2)
    if not os.path.exists(input_midi_1):
      raise ValueError('Input MIDI 1 not found: %s' % FLAGS.input_midi_1)
    if not os.path.exists(input_midi_2):
      raise ValueError('Input MIDI 2 not found: %s' % FLAGS.input_midi_2)
    input_1 = mm.midi_file_to_note_sequence(input_midi_1)
    input_2 = mm.midi_file_to_note_sequence(input_midi_2)

    def _check_extract_examples(input_ns, path, input_number):
      """Make sure each input returns exactly one example from the converter.

      Exits the program if none can be extracted; if several can, writes them
      all to `--output_dir` first so the user can pick one.
      """
      tensors = config.data_converter.to_tensors(input_ns).outputs
      if not tensors:
        print(
            'MusicVAE configs have very specific input requirements. Could not '
            'extract any valid inputs from `%s`. Try another MIDI file.' % path)
        sys.exit()
      elif len(tensors) > 1:
        pattern = ('%s_input%d-extractions_%s-*-of-%03d.mid' %
                   (FLAGS.config, input_number, date_and_time, len(tensors)))
        basename = os.path.join(FLAGS.output_dir, pattern)
        for i, ns in enumerate(config.data_converter.to_notesequences(tensors)):
          # Substitute the index into the file name only, so a `*` in the
          # directory path is never rewritten.
          mm.sequence_proto_to_midi_file(
              ns, os.path.join(FLAGS.output_dir,
                               pattern.replace('*', '%03d' % i)))
        print(
            '%d valid inputs extracted from `%s`. Outputting these potential '
            'inputs as `%s`. Call script again with one of these instead.' %
            (len(tensors), path, basename))
        sys.exit()
    logging.info(
        'Attempting to extract examples from input MIDIs using config `%s`...',
        FLAGS.config)
    _check_extract_examples(input_1, FLAGS.input_midi_1, 1)
    _check_extract_examples(input_2, FLAGS.input_midi_2, 2)

  logging.info('Loading model...')
  if FLAGS.run_dir:
    checkpoint_dir_or_path = os.path.expanduser(
        os.path.join(FLAGS.run_dir, 'train'))
  else:
    checkpoint_dir_or_path = os.path.expanduser(FLAGS.checkpoint_file)
  model = TrainedModel(
      config, batch_size=min(FLAGS.max_batch_size, FLAGS.num_outputs),
      checkpoint_dir_or_path=checkpoint_dir_or_path)

  if FLAGS.mode == 'interpolate':
    logging.info('Interpolating...')
    _, mu, _ = model.encode([input_1, input_2])
    z = np.array([
        _slerp(mu[0], mu[1], t) for t in np.linspace(0, 1, FLAGS.num_outputs)])
    results = model.decode(
        length=config.hparams.max_seq_len,
        z=z,
        temperature=FLAGS.temperature)
  else:  # 'sample'
    logging.info('Sampling...')
    results = model.sample(
        n=FLAGS.num_outputs,
        length=config.hparams.max_seq_len,
        temperature=FLAGS.temperature)

  pattern = ('%s_%s_%s-*-of-%03d.mid' %
             (FLAGS.config, FLAGS.mode, date_and_time, FLAGS.num_outputs))
  basename = os.path.join(FLAGS.output_dir, pattern)
  logging.info('Outputting %d files as `%s`...', FLAGS.num_outputs, basename)
  for i, ns in enumerate(results):
    mm.sequence_proto_to_midi_file(
        ns, os.path.join(FLAGS.output_dir, pattern.replace('*', '%03d' % i)))

  logging.info('Done.')
예제 #6
0
파일: MVAE_server.py 프로젝트: carsault/ICS
class MVAEServer(OSCServer):
    '''
    OSC server that decodes weighted mixtures of MusicVAE style encodings.

    On construction it loads a MusicVAE checkpoint and encodes one example
    MIDI file per style listed in ``style_name``. The ``/decode`` OSC message
    mixes four of those latent vectors with the given weights and replies on
    ``/decode`` with the decoded notes as text.

    Example :
    >>> server = MVAEServer(1234, 1235) # Creating server
    >>> server.run() # Running server

    '''

    # Defaults used when the matching keyword argument is not passed.
    DEFAULT_MODEL_PATH = "/Users/carsault/Dropbox/work/code/gitlab/cat-mel_2bar_big.tar"
    DEFAULT_CONFIG_NAME = 'cat-mel_2bar_big'
    # Style names, index-aligned with STYLE_FILES (one example MIDI each).
    STYLE_NAMES = ('blues', 'classic', 'country', 'jazz', 'poprock', 'world',
                   'game', 'empty', 'RnB')
    STYLE_FILES = ("./MVAE_input_valid/Blues1.mid",
                   "./MVAE_input_valid/classic1.mid",
                   "./MVAE_input_valid/country1.mid",
                   "./MVAE_input_valid/jazz1.mid",
                   "./MVAE_input_valid/poprock1.mid",
                   "./MVAE_input_valid/World1.mid",
                   "./MVAE_input_valid/game1.mid",
                   "./MVAE_input_valid/empty1.mid",
                   "./MVAE_input_valid/RnB1.mid")
    # Added to the mixed latent to obtain the second decoded segment.
    SECOND_SEGMENT_OFFSET = 0.05

    def __init__(self, *args, **kwargs):
        """Load the model, encode the style examples, then init the OSC server.

        Positional ``args`` are forwarded to ``OSCServer.__init__``. Optional
        keyword arguments ``model_path``, ``config_name`` and ``temperature``
        override the class defaults; any other keyword arguments are ignored.
        """
        self.temperature = kwargs.get('temperature', 1)
        self._modelpath = kwargs.get('model_path', self.DEFAULT_MODEL_PATH)
        self.config_name = kwargs.get('config_name', self.DEFAULT_CONFIG_NAME)
        self.config = configs.CONFIG_MAP[self.config_name]
        self.config.data_converter.max_tensors_per_item = None
        checkpoint_dir_or_path = os.path.expanduser(self._modelpath)
        print('Loading model')
        self.model = TrainedModel(self.config, batch_size=1,
                                  checkpoint_dir_or_path=checkpoint_dir_or_path)
        # Encode each style example; input_z_list[i] belongs to style_name[i].
        self.style_name = list(self.STYLE_NAMES)
        self.input_z_list = []
        for style_file in self.STYLE_FILES:
            note_sequence = mm.midi_file_to_note_sequence(
                os.path.expanduser(style_file))
            z, _, _ = self.model.encode([note_sequence])
            self.input_z_list.append(z[0])
        # Init the OSC server last, once the model state above is ready.
        super(MVAEServer, self).__init__(*args)
        self.print('Server is ready.')

    def init_bindings(self, osc_attributes=None):
        """Register the handled OSC messages.

        ``osc_attributes`` is accepted for signature compatibility but not
        used: ``self.osc_attributes`` (presumably set by ``OSCServer``) is
        forwarded to the parent instead, as before.
        """
        super(MVAEServer, self).init_bindings(self.osc_attributes)
        self.dispatcher.map('/decode', osc_parse(self.decode))

    def decode(self, id1, w1, id2, w2, id3, w3, id4, w4):
        """Decode a weighted mix of four style encodings and send the notes.

        Args:
            id1, id2, id3, id4: Style names; each must be in ``style_name``.
            w1, w2, w3, w4: Weights for those styles (converted with float()).

        Sends on ``/decode`` a space-separated string of
        ``start duration pitch`` triples, with times divided by 2.

        Raises:
            ValueError: If a style name is not in ``style_name``.
        """
        z1 = self.input_z_list[self.style_name.index(id1)]
        z2 = self.input_z_list[self.style_name.index(id2)]
        z3 = self.input_z_list[self.style_name.index(id3)]
        z4 = self.input_z_list[self.style_name.index(id4)]
        z_new = z1 * float(w1) + z2 * float(w2) + z3 * float(w3) + z4 * float(w4)
        z_new_2 = z_new + self.SECOND_SEGMENT_OFFSET
        # Add a leading batch axis of size 1 for model.decode.
        z_new = np.expand_dims(z_new, axis=0)
        z_new_2 = np.expand_dims(z_new_2, axis=0)
        length = self.config.hparams.max_seq_len
        result = self.model.decode(length=length, z=z_new,
                                   temperature=self.temperature)
        result_2 = self.model.decode(length=length, z=z_new_2,
                                     temperature=self.temperature)
        seq_result = mm.sequences_lib.concatenate_sequences(
            [result[0], result_2[0]])
        fields = []
        for note in seq_result.notes:
            start = note.start_time / 2
            fields.extend((str(start), str(note.end_time / 2 - start),
                           str(note.pitch)))
        self.send('/decode', ' '.join(fields))

    # Return current model state
    def get_state(self):
        """Send the current model's latent/regression dims on ``/state``.

        NOTE(review): ``_model`` is never assigned in this class (the model
        lives in ``self.model``), so when it is absent or None the dims are
        reported as None instead of crashing.
        """
        model = getattr(self, '_model', None)
        if model is not None:
            latent_dims = model.latent_dims
            regression_dims = model.regression_dims
        else:
            latent_dims = regression_dims = None
        state = {'latent_dims': latent_dims,
                 'regression_dims': regression_dims}
        state_str = dict2str(state)
        self.send('/state', state_str)
        self.print('Server is ready.')
        return state

    ###################
    # Core functionalities (load, encode, decode)
    ###################

    def load_preset(self, hash_v):
        """Load a stored preset by hash and send its params and z position.

        NOTE(review): relies on ``self.analysis``, ``self.dataset``,
        ``self.param_names`` and ``self.freeze_mode``, none of which are set in
        this class -- presumably carried over from another server; confirm
        before use.
        """
        # Retrieve correct index
        l_idx = self.analysis['hash_loaders'][hash_v]
        cur_file = self.dataset[l_idx[0]].dataset.datadir + '/raw/' + self.dataset[l_idx[0]].dataset.data_files[l_idx[1]]
        loaded = np.load(cur_file, allow_pickle=True)
        params = loaded['param'].item()
        params = torch.Tensor([params[p] for p in self.param_names])
        out_list = []
        # Flatten params into alternating [name, value, ...] for OSC.
        for p in range(params.shape[0]):
            out_list.append(self.param_names[p])
            out_list.append(float(params[p]))
        self.send('/params', out_list)
        cur_z = self.analysis['final_z'][l_idx[2]]
        if self.freeze_mode:
            self.prev_z = torch.Tensor(1, cur_z.shape[0])
            self.prev_z[0] = cur_z
            print(self.prev_z[0])
        # Resend full z position as alternating [name, value, ...].
        out_list = []
        for p in range(cur_z.shape[0]):
            out_list.append('x%d' % (p))
            out_list.append(float(cur_z[self.analysis['d_idx'][p]]))
        self.send('/z_pos', out_list)
예제 #7
0
def run(config_map):
    '''
    Fit a MidiMe model on MusicVAE encodings of example MIDIs and sample it.

    Every file in `--example_midi_dir` is parsed, passed through
    `trimSilence`, reduced to a single tempo event at time 0 and split with
    `getChunks` using the config's `max_seq_len`. The chunks are encoded with
    the MusicVAE model, `train_model` fits MidiMe on the latent codes, and
    `--num_outputs` MidiMe samples are decoded and written to `--output_dir`.

    Args:
        config_map: Dictionary mapping configuration name to Config object.

    Raises:
        ValueError: if required flags are missing or invalid, a file in
            `--example_midi_dir` cannot be parsed, or the directory is empty.
    '''
    date_and_time = time.strftime('%Y-%m-%d_%H%M%S')

    # The parentheses matter: unparenthesized, `a is None == b is None` is a
    # chained comparison that is true only when *both* flags are missing, so
    # passing both flags would slip through.
    if (FLAGS.run_dir is None) == (FLAGS.checkpoint_file is None):
        raise ValueError(
            'Exactly one of `--run_dir` or `--checkpoint_file` must be specified.'
        )
    if FLAGS.output_dir is None:
        raise ValueError('`--output_dir` is required.')
    tf.gfile.MakeDirs(FLAGS.output_dir)

    if FLAGS.example_midi_dir is None:
        raise ValueError('example_midi_dir is required.')

    if FLAGS.config not in config_map:
        raise ValueError('Invalid config name: %s' % FLAGS.config)
    config = config_map[FLAGS.config]
    config.data_converter.max_tensors_per_item = None

    config_midime = MidiMeConfig
    if FLAGS.config_midime is not None:
        update(config_midime, FLAGS.config_midime)

    logging.info('Loading model...')
    if FLAGS.run_dir:
        checkpoint_dir_or_path = os.path.expanduser(
            os.path.join(FLAGS.run_dir, 'train'))
    else:
        checkpoint_dir_or_path = os.path.expanduser(FLAGS.checkpoint_file)
    model = TrainedModel(config,
                         batch_size=min(FLAGS.max_batch_size,
                                        FLAGS.num_outputs),
                         checkpoint_dir_or_path=checkpoint_dir_or_path)

    example_midi_dir = os.path.expanduser(FLAGS.example_midi_dir)
    example_sequences = []
    # Sorted so the training data order does not depend on listing order.
    for file_in_dir in sorted(tf.gfile.ListDirectory(example_midi_dir)):
        full_file_path = os.path.join(example_midi_dir, file_in_dir)
        try:
            sequence = mm.midi_file_to_note_sequence(full_file_path)
        except Exception as err:
            # Keep the original parse error attached as the cause.
            raise ValueError('%s' % full_file_path) from err
        example_sequences.append(sequence)
    if not example_sequences:
        raise ValueError('No MIDI files found in %s' % example_midi_dir)

    trimSilence(example_sequences)
    # Keep only the first tempo event and move it to time 0; sequences
    # without any tempo event are left as they are.
    for sequence in example_sequences:
        if sequence.tempos:
            sequence.tempos[0].time = 0
            del sequence.tempos[1:]

    chunks = getChunks(example_sequences, config.hparams.max_seq_len)

    latent = model.encode(chunks)[0]
    midime = train_model(latent, config_midime)

    s = midime.sample(FLAGS.num_outputs)
    samples = model.decode(s, config.hparams.max_seq_len)

    pattern = ('%s_%s-*-of-%03d.mid' %
               (FLAGS.config, date_and_time, FLAGS.num_outputs))
    basename = os.path.join(FLAGS.output_dir, pattern)
    logging.info('Outputting %d files as `%s`...', FLAGS.num_outputs, basename)

    for i, ns in enumerate(samples):
        # Substitute the index into the file name only, never the directory.
        mm.sequence_proto_to_midi_file(
            ns, os.path.join(FLAGS.output_dir, pattern.replace('*', '%03d' % i)))

    logging.info('Done.')
예제 #8
0
model = TrainedModel(config,
                     batch_size=min(FLAGS.max_batch_size, FLAGS.num_outputs),
                     checkpoint_dir_or_path=checkpoint_dir_or_path)

# Encode the input. `model.encode` takes a collection of NoteSequences and
# returns the encoded `(z, mu, sigma)` tuple; its `assert_same_length` option
# raises AssertionError if the extracted sequences differ in length.
logging.info('Encoding...')
z, mu, sigma = model.encode([input_1])

# Reconstruct the input by decoding its sampled latent vector `z`.
results = model.decode(length=config.hparams.max_seq_len,
                       z=z,
                       temperature=FLAGS.temperature)

# One result per encoded input, which need not equal FLAGS.num_outputs, so
# name and log the files by the actual count.
results = list(results)
num_results = len(results)
pattern = ('%s_%s_%s-*-of-%03d.mid' %
           (FLAGS.config, FLAGS.mode, date_and_time, num_results))
basename = os.path.join(FLAGS.output_dir, pattern)
logging.info('Outputting %d files as `%s`...', num_results, basename)
for i, ns in enumerate(results):
    # Substitute the index into the file name only, never the directory.
    mm.sequence_proto_to_midi_file(
        ns, os.path.join(FLAGS.output_dir, pattern.replace('*', '%03d' % i)))

logging.info('Done.')