# Export script: runs a trained ReversibleModel over each fold of the MAPS
# piano dataset and stores the per-sequence results as .pkl files.
import argparse
import os

import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler

# Project-local imports. The module paths below are assumptions; adjust them
# to match the actual package layout of this repository.
import utils
from datasets import Spec2MidiDataset, SqueezingDataset, get_dataset_individually
from model import ReversibleModel
from export import export


# Not called by main() below; kept as a reusable helper.
def get_data_loaders(direction, base_directory, fold_file, instrument_filename,
                     context, audio_options, batch_size):
    print('-' * 30)
    print('getting data loaders:')
    print('direction', direction)
    print('base_directory', base_directory)
    print('fold_file', fold_file)
    print('instrument_filename', instrument_filename)

    clazz = Spec2MidiDataset
    datasets = get_dataset_individually(base_directory, fold_file, instrument_filename,
                                        context, audio_options, clazz)
    loaders = []
    for dataset in datasets:
        # Remember the source filenames before wrapping the dataset.
        audiofilename = dataset.audiofilename
        midifilename = dataset.midifilename
        dataset = SqueezingDataset(dataset)
        print('len(dataset)', len(dataset))

        # Sequential (non-shuffled) sampling keeps frames in temporal order.
        sampler = SequentialSampler(dataset)
        loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                            drop_last=True)
        loaders.append((fold_file, audiofilename, midifilename, loader))
    return loaders
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('checkpoint')
    parser.add_argument('output_directory')
    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    direction = 'spec2labels'
    print('direction', direction)

    # Training hyperparameters kept for reference; n_epochs, meta_epoch and
    # gamma are not used by this export script.
    n_epochs = 512
    meta_epoch = 12
    batch_size = 32
    gamma = 0.96

    model = ReversibleModel(
        device=device,
        batch_size=batch_size,
        depth=5,
        ndim_tot=256,
        ndim_x=144,
        ndim_y=185,
        ndim_z=9,
        clamp=2,
        zeros_noise_scale=3e-2,  # very magic, much hack!
        y_noise_scale=3e-2)
    model.to(device)

    print('loading checkpoint')
    # map_location lets a checkpoint trained on GPU load on a CPU-only machine.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()  # inference only; disable dropout / batch-norm updates

    audio_options = dict(
        spectrogram_type='LogarithmicFilteredSpectrogram',
        filterbank='LogarithmicFilterbank',
        num_channels=1,
        sample_rate=44100,
        frame_size=4096,
        fft_size=4096,
        hop_size=441 * 4,  # 25 fps
        num_bands=24,
        fmin=30,
        fmax=10000.0,
        fref=440.0,
        norm_filters=True,
        unique_filters=True,
        circular_shift=False,
        add=1.)
    context = dict(frame_size=1, hop_size=1, origin='center')

    print('loading data')
    base_directory = './data/maps_piano/data'
    fold_directory = './splits/maps-non-overlapping'

    utils.ensure_directory_exists(args.output_directory)

    for fold in ['train', 'valid', 'test']:
        fold_output_directory = os.path.join(args.output_directory, fold)
        if not os.path.exists(fold_output_directory):
            os.makedirs(fold_output_directory)

        print('fold', fold)
        print('fold_output_directory', fold_output_directory)

        sequences = get_dataset_individually(
            base_directory=base_directory,
            fold_filename=os.path.join(fold_directory, fold),
            instrument_filename=os.path.join(fold_directory, 'instruments'),
            context=context,
            audio_options=audio_options,
            clazz=Spec2MidiDataset)

        for sequence in sequences:
            print('sequence.audiofilename', sequence.audiofilename)
            print('sequence.midifilename', sequence.midifilename)

            # Derive '<output_directory>/<fold>/<audio basename>.pkl'.
            output_filename = os.path.basename(sequence.audiofilename)
            output_filename = os.path.splitext(output_filename)[0]
            output_filename = os.path.join(fold_output_directory,
                                           output_filename + '.pkl')
            print('output_filename', output_filename)

            # The sampler must index the same dataset handed to the DataLoader,
            # so wrap the sequence once and sample from the wrapped dataset
            # (the original sampled from the unwrapped sequence).
            dataset = SqueezingDataset(sequence)
            loader = DataLoader(dataset,
                                batch_size=batch_size,
                                sampler=SequentialSampler(dataset),
                                drop_last=True)

            result = export(device, model, loader)
            result['audiofilename'] = sequence.audiofilename
            result['midifilename'] = sequence.midifilename
            torch.save(result, output_filename)
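
# Entry-point guard so the module can be imported without side effects and run as
#   python <this_script>.py <checkpoint> <output_directory>
# (the script filename is illustrative; use the actual filename).
if __name__ == '__main__':
    main()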