Example #1
0
def preprocess_midi_files_under(midi_root, save_dir, num_workers):
    """Preprocess every MIDI file under *midi_root* in parallel and save results.

    Each file is converted by ``preprocess_midi`` in a worker process; each
    result is written to ``save_dir`` as ``<basename>-<md5-of-path>.data``
    via ``torch.save``.

    :param midi_root: directory searched recursively for ``.mid``/``.midi`` files
    :param save_dir: output directory (created if missing)
    :param num_workers: number of worker processes for the pool
    """
    midi_paths = list(
        utils.find_files_by_extensions(midi_root, ['.mid', '.midi']))
    os.makedirs(save_dir, exist_ok=True)
    out_fmt = '{}-{}.data'

    results = []
    # Context manager guarantees the pool is shut down even on abort/error
    # (the original leaked worker processes on every exit path).
    with ProcessPoolExecutor(num_workers) as executor:
        for path in midi_paths:
            try:
                results.append((path, executor.submit(preprocess_midi, path)))
            except KeyboardInterrupt:
                print(' Abort')
                return
            except Exception:  # was a bare except: don't trap SystemExit
                print(' Error')
                continue

        for path, future in Bar('Processing').iter(results):
            print(' ', end='[{}]'.format(path), flush=True)
            name = os.path.basename(path)
            # md5 of the full path disambiguates identical basenames.
            code = hashlib.md5(path.encode()).hexdigest()
            save_path = os.path.join(save_dir, out_fmt.format(name, code))
            torch.save(future.result(), save_path)

    print('Done')
Example #2
0
def preprocess_midi_files_under(midi_root, save_dir):
    """Preprocess every MIDI file under *midi_root*, pickling each non-empty
    result to ``save_dir/<midi filename>.pickle`` and printing length
    statistics (in minutes) of the successfully analyzed files.

    :param midi_root: directory searched recursively for ``.mid``/``.midi`` files
    :param save_dir: output directory (created if missing)
    """
    total_secs = []
    midi_paths = list(utils.find_files_by_extensions(midi_root, ['.mid', '.midi']))
    os.makedirs(save_dir, exist_ok=True)
    i = 0
    for path in Bar('Processing').iter(midi_paths):
        print(' ', end='[{}]'.format(path), flush=True)

        try:
            data = preprocess_midi(path)
            total_secs.append(MidiFile(path).length)
        except KeyboardInterrupt:
            print('Abort')
            return
        except EOFError:
            print('EOF Error')
            continue
        except Exception:  # was a bare except: don't trap SystemExit
            print()
            print('Error in decoding: {}'.format(path))
            continue

        if len(data) > 0:
            i += 1
            with open('{}/{}.pickle'.format(save_dir, path.split('/')[-1]), 'wb') as f:
                pickle.dump(data, f)
    print(f'Successfully Processed {i} files!')
    print(f'Analyzed {len(total_secs)} files!')
    if not total_secs:
        # np.min()/np.max() raise on an empty array; nothing to summarize.
        print('No files could be analyzed; skipping descriptives.')
        return
    print('Dataset Descriptives (in Minutes):')  # fixed mismatched paren
    total_secs = np.array(total_secs)
    print(f'Sum: {total_secs.sum()/60.0}\tMean: {total_secs.mean()/60.0}\n'
          f'Min: {total_secs.min()/60.0}\tMax: {total_secs.max()/60.0}\n'
          f'Stdev: {total_secs.std()/60.0}')
    def process_midi_from_dir(self, midi_root):
        """Collect event/control sequences from every MIDI file under a directory.

        :param midi_root: directory containing the MIDI data files.
        :return: tuple ``(es_seq_list, ctrl_seq_list)`` of accumulated
            sequences, or ``None`` when aborted with Ctrl-C.
        """

        midi_paths = list(
            utils.find_files_by_extensions(midi_root,
                                           ['.mid', '.midi', '.MID']))
        es_seq_list = []
        ctrl_seq_list = []
        for path in Bar('Processing').iter(midi_paths):
            print(' ', end='[{}]'.format(path), flush=True)

            try:
                data = preprocess_midi(path)
                for es_seq, ctrl_seq in data:
                    max_len = par.max_seq
                    for idx in range(max_len + 1):
                        # NOTE(review): neither the loop variables es_seq /
                        # ctrl_seq nor idx are used here -- data[0] / data[1]
                        # are appended max_len+1 times per pair.  This looks
                        # like a bug (probably meant es_seq / ctrl_seq);
                        # confirm the intended behavior before changing.
                        es_seq_list.append(data[0])
                        ctrl_seq_list.append(data[1])

            except KeyboardInterrupt:
                print(' Abort')
                return
            except:
                # NOTE(review): bare except silently drops any failing file.
                print(' Error')
                continue

        return es_seq_list, ctrl_seq_list
def preprocess_midi_files_under(midi_root, save_dir):
    """Preprocess every MIDI file under *midi_root* and pickle the first
    element of each result to ``save_dir/<midi filename>.pickle``.

    :param midi_root: directory searched recursively for ``.mid``/``.midi`` files
    :param save_dir: output directory (created if missing)
    """
    midi_paths = list(
        utils.find_files_by_extensions(midi_root, ['.mid', '.midi']))
    os.makedirs(save_dir, exist_ok=True)

    for path in Bar('Processing').iter(midi_paths):
        print(' ', end='[{}]'.format(path), flush=True)

        try:
            data = preprocess_midi(path)
        except KeyboardInterrupt:
            print(' Abort')
            return
        except Exception:  # was a bare except: don't trap SystemExit
            print(' Error')
            continue

        # Only the event sequence (data[0]) is persisted in this variant.
        with open('{}/{}.pickle'.format(save_dir,
                                        path.split('/')[-1]), 'wb') as f:
            pickle.dump(data[0], f)

    print('Done')
 def __init__(self, dir_path):
     """Index all pickle files under *dir_path* and split them 80/10/10
     into train/eval/test partitions, resetting the sequence cursors."""
     self.files = list(utils.find_files_by_extensions(dir_path, ['.pickle']))
     n_files = len(self.files)
     train_end = int(n_files * 0.8)
     eval_end = int(n_files * 0.9)
     self.file_dict = {
         'train': self.files[:train_end],
         'eval': self.files[train_end:eval_end],
         'test': self.files[eval_end:],
     }
     self._seq_file_name_idx = 0
     self._seq_idx = 0
 def __init__(self, dir_path, splits=(0.8, 0.9)):
     """Index all pickle files under *dir_path* and partition them.

     :param dir_path: directory scanned for ``.pickle`` files.
     :param splits: two cumulative fractions (train end, eval end); the
         remainder becomes the test set.  The default is now a tuple: the
         original mutable list default is shared across all calls, which is
         a classic Python pitfall.  Indexing behavior is unchanged.
     """
     self.files = list(utils.find_files_by_extensions(dir_path, ['.pickle']))
     self.file_dict = {
         'train': self.files[:int(len(self.files) * splits[0])],
         'eval': self.files[int(len(self.files) * splits[0]): int(len(self.files) * splits[1])],
         'test': self.files[int(len(self.files) * splits[1]):],
     }
     self._seq_file_name_idx = 0
     self._seq_idx = 0
Example #7
0
 def __init__(self, root, verbose=False):
     """Load every ``.data`` sample found under *root* into memory.

     Each file holds an ``(event sequence, compressed control sequence)``
     pair; control sequences are decompressed on load.  ``self.avglen``
     ends up as the mean event-sequence length.
     """
     assert os.path.isdir(root), root
     self.root = root
     self.samples = []
     self.seqlens = []
     data_paths = utils.find_files_by_extensions(root, ['.data'])
     if verbose:
         # Wrap in a progress bar when requested.
         data_paths = Bar(root).iter(list(data_paths))
     for data_path in data_paths:
         eventseq, controlseq = torch.load(data_path)
         controlseq = ControlSeq.recover_compressed_array(controlseq)
         assert len(eventseq) == len(controlseq)
         self.samples.append((eventseq, controlseq))
         self.seqlens.append(len(eventseq))
     self.avglen = np.mean(self.seqlens)
Example #8
0
def preprocess_midi_files_under(midi_root, save_dir):
    """Preprocess every MIDI file under *midi_root* and pickle each result
    to ``save_dir/<midi filename>.pickle``, with a tqdm progress bar.

    :param midi_root: directory searched recursively for ``.mid``/``.midi`` files
    :param save_dir: output directory (created if missing)
    """
    midi_paths = list(utils.find_files_by_extensions(midi_root, ['.mid', '.midi']))
    os.makedirs(save_dir, exist_ok=True)

    for path in tqdm(midi_paths, desc='MIDI Paths in {}'.format(midi_root)):
        try:
            data = preprocess_midi(path)
        except KeyboardInterrupt:
            print(' Abort')
            return
        except EOFError:
            print('EOF Error')
            # Bug fix: without this `continue`, a truncated file fell through
            # and pickled stale data from the previous iteration -- or raised
            # NameError on the very first file.
            continue

        with open('{}/{}.pickle'.format(save_dir,path.split('/')[-1]), 'wb') as f:
            pickle.dump(data, f)
Example #9
0
def main(_):
    """Sample events from a trained CharRNN and write them to a MIDI file.

    Optionally loads a control sequence from ``FLAGS.control``, restores the
    model from ``FLAGS.checkpoint_path``, samples 1000 events primed with the
    first 100 loaded events, and writes ``output/output-000.mid``.
    """
    if os.path.isfile(FLAGS.control) or os.path.isdir(FLAGS.control):
        if os.path.isdir(FLAGS.control):
            files = list(utils.find_files_by_extensions(FLAGS.control))
            assert len(files) > 0, 'no file in "{control}"'.format(
                control=FLAGS.control)
            control = np.random.choice(files)
        # NOTE(review): when FLAGS.control is a directory, the randomly
        # chosen file above is ignored -- torch.load still receives
        # FLAGS.control (the directory itself), which will fail.  Confirm
        # the intended path handling.
        events, compressed_controls = torch.load(FLAGS.control)
        controls = ControlSeq.recover_compressed_array(compressed_controls)
        max_len = FLAGS.max_length
        if FLAGS.max_length == 0:
            # Default the generation length to the control sequence length.
            max_len = controls.shape[0]

        # NOTE(review): ndarray.repeat(1, 1) repeats each element once along
        # axis 1 -- a no-op copy.  This looks ported from torch's
        # Tensor.repeat; verify the intended tiling.
        control = np.expand_dims(controls, 1).repeat(1, 1)
        control = 'control sequence from "{control}"'.format(control=control)

    # NOTE(review): if FLAGS.control is neither a file nor a directory,
    # `max_len` (and `events`, used below) are never assigned, so this line
    # raises NameError instead of the intended assertion message.
    assert max_len > 0, 'either max length or control sequence length should be given'

    #FLAGS.start_string = FLAGS.start_string.decode('utf-8')

    # Resolve a checkpoint directory to its latest checkpoint file.
    if os.path.isdir(FLAGS.checkpoint_path):
        FLAGS.checkpoint_path =\
            tf.train.latest_checkpoint(FLAGS.checkpoint_path)

    model = CharRNN(EventSeq.dim(),
                    ControlSeq.dim(),
                    sampling=True,
                    lstm_size=FLAGS.lstm_size,
                    num_layers=FLAGS.num_layers,
                    use_embedding=FLAGS.use_embedding,
                    embedding_size=FLAGS.embedding_size)
    model.sess.run(tf.global_variables_initializer())
    model.load(FLAGS.checkpoint_path)

    # Prime sampling with the first 100 loaded events.
    outputs = model.sample(1000,
                           prime=events[0:100],
                           vocab_size=EventSeq.dim())

    outputs = outputs.reshape([-1, 1])
    print(outputs)
    name = 'output-{i:03d}.mid'.format(i=0)
    path = os.path.join("output/", name)
    n_notes = utils.event_indeces_to_midi_file(outputs[:, 0], path)
    print('===> {path} ({n_notes} notes)'.format(path=path, n_notes=n_notes))
Example #10
0
def preprocess_midi_files_under(midi_root, save_dir):
    """Preprocess every MIDI file under *midi_root* and write each result with
    ``torch.save`` to ``save_dir/<name>.data``, where ``<name>`` is the file's
    basename up to its first dot.

    :param midi_root: directory searched recursively for ``.mid``/``.midi`` files
    :param save_dir: output directory (created if missing)
    """
    midi_paths = list(utils.find_files_by_extensions(midi_root, ['.mid', '.midi']))
    os.makedirs(save_dir, exist_ok=True)
    out_fmt = '{}.data'

    for path in Bar('Processing').iter(midi_paths):
        print(' ', end='[{}]'.format(path), flush=True)
        try:
            data = preprocess_midi(path)
        except KeyboardInterrupt:
            print(' Abort')
            return
        except Exception:  # was a bare except: don't trap SystemExit
            print(' Error')
            continue
        # Basename up to the first '.' (same result as split('.')[0]).
        # The original also computed an md5 of the path but never used it.
        name = os.path.basename(path).partition('.')[0]
        save_path = os.path.join(save_dir, out_fmt.format(name))
        torch.save(data, save_path)
    print('Done')
def preprocess_midi_files_under(midi_root, save_dir):
    """Preprocess MIDI files under *midi_root*, skipping paths whose pickle
    output already exists, and write each result to
    ``save_dir/<midi filename>.pickle``.

    :param midi_root: directory searched recursively for ``.mid``/``.midi`` files
    :param save_dir: output directory (created if missing)
    """
    midi_paths = list(
        utils.find_files_by_extensions(midi_root, ['.mid', '.midi']))
    os.makedirs(save_dir, exist_ok=True)

    for path in Bar('Processing').iter(midi_paths):
        print(' ', end='[{}]'.format(path), flush=True)

        output_path = '{}/{}.pickle'.format(save_dir, path.split('/')[-1])
        if os.path.isfile(output_path):
            # Resume support: skip files that were already processed.
            continue

        try:
            data = preprocess_midi(path)
        except KeyboardInterrupt:
            print(' Abort')
            return
        except EOFError:
            print('EOF Error')
            # Bug fix: without this `continue`, a truncated file fell through
            # and pickled stale data from a previous iteration -- or raised
            # NameError if it was the first file processed.
            continue

        with open(output_path, 'wb') as f:
            pickle.dump(data, f)
# Resolve generation options from the parsed CLI namespace `opt`.
use_beam_search = opt.beam_size > 0
beam_size = opt.beam_size
temperature = opt.temperature
init_zero = opt.init_zero

# Beam search and greedy sampling are mutually exclusive; mark the unused
# mode's setting as disabled (these strings feed a status printout).
if use_beam_search:
    greedy_ratio = 'DISABLED'
else:
    beam_size = 'DISABLED'

assert os.path.isfile(sess_path), f'"{sess_path}" is not a file'

if control is not None:
    if os.path.isfile(control) or os.path.isdir(control):
        if os.path.isdir(control):
            # Pick a random control file from the directory.
            files = list(utils.find_files_by_extensions(control))
            assert len(files) > 0, f'no file in "{control}"'
            control = np.random.choice(files)
        # Each .data file stores (event sequence, compressed control sequence);
        # the event sequence is not needed here.
        _, compressed_controls = torch.load(control)
        controls = ControlSeq.recover_compressed_array(compressed_controls)
        if max_len == 0:
            # Default the generation length to the control sequence length.
            max_len = controls.shape[0]
        controls = torch.tensor(controls, dtype=torch.float32)
        # Tile the control sequence across the batch dimension.
        controls = controls.unsqueeze(1).repeat(1, batch_size, 1).to(device)
        control = f'control sequence from "{control}"'

    else:
        # Otherwise `control` is an inline spec: "<pitch histogram>;<note density>".
        pitch_histogram, note_density = control.split(';')
        pitch_histogram = list(filter(len, pitch_histogram.split(',')))
        if len(pitch_histogram) == 0:
            # An empty histogram means uniform over the 12 pitch classes.
            pitch_histogram = np.ones(12) / 12
Example #13
0
 def __init__(self, dir_path):
     """Collect every ``.pickle`` file under *dir_path* and reset the
     sequence cursors used during iteration."""
     pickle_paths = utils.find_files_by_extensions(dir_path, ['.pickle'])
     self.files = list(pickle_paths)
     self._seq_file_name_idx = 0
     self._seq_idx = 0
Example #14
0
    "An item is a file path or a 'key=value' formatted string. "
    "The type of a value is determined by applying int(), float(), and str() "
    "to it sequencially.")
args = parser.parse_args()
config.load(args.model_dir, args.configs, initialize=True)

# Resolve the conditioning MIDI file: either an explicit path, or a
# substring matched against every MIDI file in the known library dirs.
condition_file = None
# Bug fix: the original used `args.condition_file is ""` -- an identity
# comparison against a string literal, which is not guaranteed to hold even
# for empty strings and raises SyntaxWarning on modern Python.
if not args.condition_file:
    condition_file = None
else:
    if os.path.exists(args.condition_file):
        condition_file = args.condition_file
    else:
        print("Partial searching for ", args.condition_file)
        for midi_path in midi_lib:
            for full_fn in find_files_by_extensions(midi_path,
                                                    ['.mid', '.midi']):
                if args.condition_file in full_fn:
                    condition_file = full_fn
        # Bug fix: if no file matched, the original crashed below with an
        # opaque AttributeError on None; fail with a clear message instead.
        if condition_file is None:
            raise FileNotFoundError(
                'no MIDI file matching {!r} found'.format(args.condition_file))
    condition_fn = condition_file.split('\\')[-1].split('.')[0]
# Select the compute device: CUDA when available, otherwise CPU.
config.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Timestamped TensorBoard log directory for this generation run.
current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
gen_log_dir = 'logs/mt_decoder/generate_' + current_time + '/generate'
gen_summary_writer = SummaryWriter(gen_log_dir)

mt = MusicTransformer(embedding_dim=config.embedding_dim,
                      vocab_size=config.vocab_size,