Example #1
from torch.utils.data import DataLoader, SequentialSampler

# Spec2MidiDataset, SqueezingDataset and get_dataset_individually are
# project-local helpers from the surrounding repository.


def get_data_loaders(direction, base_directory, fold_file, instrument_filename,
                     context, audio_options, batch_size):

    print('-' * 30)
    print('getting data loaders:')
    print('direction', direction)
    print('base_directory', base_directory)
    print('fold_file', fold_file)
    print('instrument_filename', instrument_filename)

    clazz = Spec2MidiDataset

    datasets = get_dataset_individually(base_directory, fold_file,
                                        instrument_filename, context,
                                        audio_options, clazz)
    loaders = []
    for dataset in datasets:
        audiofilename = dataset.audiofilename
        midifilename = dataset.midifilename
        dataset = SqueezingDataset(dataset)
        print('len(dataset)', len(dataset))

        # deterministic, in-order iteration over the file's frames
        sampler = SequentialSampler(dataset)

        # drop_last=True discards a final partial batch, so every batch
        # holds exactly batch_size frames
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            sampler=sampler,
                            drop_last=True)
        loaders.append((fold_file, audiofilename, midifilename, loader))

    return loaders
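
For context, a minimal usage sketch. The paths, the 'spec2labels' direction and the audio options are taken from Example #2 below; what each batch contains depends on what SqueezingDataset yields.

audio_options = dict(spectrogram_type='LogarithmicFilteredSpectrogram',
                     filterbank='LogarithmicFilterbank', num_channels=1,
                     sample_rate=44100, frame_size=4096, fft_size=4096,
                     hop_size=441 * 4, num_bands=24, fmin=30, fmax=10000.0,
                     fref=440.0, norm_filters=True, unique_filters=True,
                     circular_shift=False, add=1.)

loaders = get_data_loaders(direction='spec2labels',
                           base_directory='./data/maps_piano/data',
                           fold_file='./splits/maps-non-overlapping/test',
                           instrument_filename='./splits/maps-non-overlapping/instruments',
                           context=dict(frame_size=1, hop_size=1, origin='center'),
                           audio_options=audio_options,
                           batch_size=32)

for fold_file, audiofilename, midifilename, loader in loaders:
    for batch in loader:
        pass  # one in-order batch of frames from this audio file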
Example #2
import argparse
import os

import torch
from torch.utils.data import DataLoader, SequentialSampler

# ReversibleModel, Spec2MidiDataset, SqueezingDataset,
# get_dataset_individually, export and utils are project-local imports
# from the surrounding repository.


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('checkpoint')
    parser.add_argument('output_directory')
    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    direction = 'spec2labels'
    print('direction', direction)

    n_epochs = 512
    meta_epoch = 12
    batch_size = 32
    gamma = 0.96

    model = ReversibleModel(
        device=device,
        batch_size=batch_size,
        depth=5,
        ndim_tot=256,
        ndim_x=144,
        ndim_y=185,
        ndim_z=9,
        clamp=2,
        zeros_noise_scale=3e-2,  # very magic, much hack!
        y_noise_scale=3e-2)
    model.to(device)

    print('loading checkpoint')
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint)

    # spectrogram extraction options (madmom-style spectrogram/filterbank names)
    audio_options = dict(
        spectrogram_type='LogarithmicFilteredSpectrogram',
        filterbank='LogarithmicFilterbank',
        num_channels=1,
        sample_rate=44100,
        frame_size=4096,
        fft_size=4096,
        hop_size=441 * 4,  # 25 fps
        num_bands=24,
        fmin=30,
        fmax=10000.0,
        fref=440.0,
        norm_filters=True,
        unique_filters=True,
        circular_shift=False,
        add=1.)
    context = dict(frame_size=1, hop_size=1, origin='center')

    print('loading data')
    base_directory = './data/maps_piano/data'
    fold_directory = './splits/maps-non-overlapping'

    utils.ensure_directory_exists(args.output_directory)

    for fold in ['train', 'valid', 'test']:
        fold_output_directory = os.path.join(args.output_directory, fold)
        os.makedirs(fold_output_directory, exist_ok=True)

        print('fold', fold)
        print('fold_output_directory', fold_output_directory)

        sequences = get_dataset_individually(
            base_directory=base_directory,
            fold_filename=os.path.join(fold_directory, fold),
            instrument_filename=os.path.join(fold_directory, 'instruments'),
            context=context,
            audio_options=audio_options,
            clazz=Spec2MidiDataset)

        for sequence in sequences:
            print('sequence.audiofilename', sequence.audiofilename)
            print('sequence.midifilename', sequence.midifilename)
            output_filename = os.path.basename(sequence.audiofilename)
            output_filename = os.path.splitext(output_filename)[0]
            output_filename = os.path.join(fold_output_directory,
                                           output_filename + '.pkl')

            print('output_filename', output_filename)

            # wrap once and sample over the wrapped dataset, so the sampler's
            # indices match the length of the dataset actually being loaded
            dataset = SqueezingDataset(sequence)
            loader = DataLoader(dataset,
                                batch_size=batch_size,
                                sampler=SequentialSampler(dataset),
                                drop_last=True)

            result = export(device, model, loader)
            result['audiofilename'] = sequence.audiofilename
            result['midifilename'] = sequence.midifilename
            torch.save(result, output_filename)
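
Each exported .pkl can be read back with torch.load. A minimal sketch; the path is a hypothetical placeholder, and any keys beyond audiofilename and midifilename depend on what export() returns.

result = torch.load('./exports/test/MAPS_example.pkl')  # hypothetical path
print(result['audiofilename'])
print(result['midifilename'])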