Code example #1
import h5py
import librosa
import torch
import tqdm

# Assumed to be provided by the surrounding project (not shown here):
# AudioDataset, AudioDataLoader, Encoder, Decoder, SampleGenerator,
# helper_functions, save_audio, extract_id.

# NOTE: the original excerpt begins mid-call; this wrapper signature and the
# leading AudioDataset arguments are reconstructed assumptions, not the
# original code.
def run_encoder_smoke_test(data_json, batch_size, max_length_in,
                           max_length_out, num_batches, num_workers=0):
    train_dataset = AudioDataset(data_json, batch_size, max_length_in,
                                 max_length_out, num_batches)
    # NOTE: batch_size must be 1 here; AudioDataset is assumed to emit
    # pre-batched minibatches itself.
    train_loader = AudioDataLoader(train_dataset,
                                   batch_size=1,
                                   num_workers=num_workers)

    # MODEL PART
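    # Deliberately tiny hyperparameters for a quick shape check;
    # input_size=83 matches e.g. 80-dim fbank + 3-dim pitch features.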
    input_size = 83
    hidden_size = 2
    num_layers = 2
    bidirectional = True
    rnn_type = 'lstm'

    encoder = Encoder(input_size,
                      hidden_size,
                      num_layers,
                      bidirectional=bidirectional,
                      rnn_type=rnn_type)
    encoder.cuda()
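    # Smoke test: push a few padded batches through the encoder and print
    # the resulting shapes.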
    for i, data in enumerate(train_loader):
        padded_input, input_lengths, targets = data
        padded_input = padded_input.cuda()
        input_lengths = input_lengths.cuda()
        print(i)
        print(padded_input.size())
        print(input_lengths.size())
        output, hidden = encoder(padded_input, input_lengths)
        print(output)
        print(output.size())
        print("*" * 20)


def main(config):
    print('Starting')

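    # Collect sibling checkpoints named '<name>_<id>.pth' and keep only the
    # ids listed in config.decoders.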
    checkpoints = config.checkpoint.parent.glob(config.checkpoint.name +
                                                '_*.pth')
    checkpoints = [c for c in checkpoints if extract_id(c) in config.decoders]
    assert len(checkpoints) >= 1, "No checkpoints found."

    model_config = torch.load(config.checkpoint.parent / 'args.pth')[0]
    encoder = Encoder(model_config.encoder)
    encoder.load_state_dict(torch.load(checkpoints[0])['encoder_state'])
    encoder.eval()
    encoder = encoder.cuda()

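    # One generator per decoder checkpoint; all decoders share the single
    # encoder loaded above.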
    generators = []
    generator_ids = []
    for checkpoint in checkpoints:
        decoder = Decoder(model_config.decoder)
        decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])
        decoder.eval()
        decoder = decoder.cuda()

        generator = SampleGenerator(decoder,
                                    config.batch_size,
                                    wav_freq=config.rate)

        generators.append(generator)
        generator_ids.append(extract_id(checkpoint))

    xs = []
    assert config.out_dir is not None

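    # Resolve inputs: a single directory is searched recursively for
    # .wav/.h5 files; otherwise the given paths are used as-is.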
    if len(config.sample_dir) == 1 and config.sample_dir[0].is_dir():
        top = config.sample_dir[0]
        file_paths = list(top.glob('**/*.wav')) + list(top.glob('**/*.h5'))
    else:
        file_paths = config.sample_dir

    print("File paths to be used:", file_paths)
    for file_path in file_paths:
        if file_path.suffix == '.wav':
            data, _ = librosa.load(file_path, sr=config.rate)
            data = helper_functions.mu_law(data)
        elif file_path.suffix == '.h5':
            # 'wav' is presumably int16 PCM; scale to [-1, 1) before mu-law.
            with h5py.File(file_path, 'r') as f:
                data = helper_functions.mu_law(f['wav'][:] / (2**15))
            # Trim to a whole number of seconds at config.rate.
            if data.shape[-1] % config.rate != 0:
                data = data[:-(data.shape[-1] % config.rate)]
            assert data.shape[-1] % config.rate == 0
            print(data.shape)
        else:
            raise ValueError(f'Unsupported file type: {file_path}')

        # The first file with no explicit sample_len fixes the length for all
        # subsequent files, so torch.stack below sees equal-length tensors.
        if config.sample_len:
            data = data[:config.sample_len]
        else:
            config.sample_len = len(data)
        xs.append(torch.tensor(data).unsqueeze(0).float().cuda())

    xs = torch.stack(xs).contiguous()
    print(f'xs size: {xs.size()}')

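    # Convert a generated mu-law signal back to a waveform and write it to
    # out_dir/<decoder_id>/<original stem>.wav.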
    def save(x, decoder_idx, filepath):
        wav = helper_functions.inv_mu_law(x.cpu().numpy())
        print(f'x size: {x.shape}')
        print(f'x min: {x.min()}, max: {x.max()}')

        out_path = (config.out_dir / str(decoder_idx) /
                    filepath.with_suffix('.wav').name)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        save_audio(wav.squeeze(), out_path, rate=config.rate)

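    # Encode every input once; the latent codes zz are reused by all decoders.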
    yy = {}
    with torch.no_grad():
        zz = []
        for xs_batch in torch.split(xs, config.batch_size):
            zz += [encoder(xs_batch)]
        zz = torch.cat(zz, dim=0)

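        # For each decoder, synthesize batch by batch; the conditioning is
        # split along time (config.split_size) to bound generator memory.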
        for i, generator_id in enumerate(generator_ids):
            yy[generator_id] = []
            generator = generators[i]
            for zz_batch in torch.split(zz, config.batch_size):
                print("Batch shape:", zz_batch.shape)
                splits = torch.split(zz_batch, config.split_size, dim=-1)
                audio_data = []
                generator.reset()
                for cond in tqdm.tqdm(splits):
                    audio_data += [generator.generate(cond).cpu()]
                audio_data = torch.cat(audio_data, -1)
                yy[generator_id] += [audio_data]
            yy[generator_id] = torch.cat(yy[generator_id], dim=0)

            for sample_result, filepath in zip(yy[generator_id], file_paths):
                save(sample_result, generator_id, filepath)

            del generator
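

# Minimal invocation sketch. The CLI of the original script is not shown, so
# these argparse flags are assumptions mirroring the config fields main()
# reads (checkpoint, decoders, out_dir, sample_dir, batch_size, rate,
# split_size, sample_len).
if __name__ == '__main__':
    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=Path, required=True)
    parser.add_argument('--decoders', type=int, nargs='*', default=[])
    parser.add_argument('--out-dir', type=Path, required=True)
    parser.add_argument('--sample-dir', type=Path, nargs='+', required=True)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--rate', type=int, default=16000)
    parser.add_argument('--split-size', type=int, default=20)
    parser.add_argument('--sample-len', type=int, default=None)
    main(parser.parse_args())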