max_length_out, num_batches) # NOTE: must set batch_size=1 here. train_loader = AudioDataLoader(train_dataset, batch_size=1, num_workers=num_workers) # MODEL PART input_size = 83 hidden_size = 2 num_layers = 2 bidirectional = True rnn_type = 'lstm' encoder = Encoder(input_size, hidden_size, num_layers, bidirectional=bidirectional, rnn_type=rnn_type) encoder.cuda() for i, (data) in enumerate(train_loader): padded_input, input_lengths, targets = data padded_input = padded_input.cuda() input_lengths = input_lengths.cuda() print(i) print(padded_input.size()) print(input_lengths.size()) output, hidden = encoder(padded_input, input_lengths) print(output) print(output.size()) print("*" * 20)
def main(config):
    """Encode the input audio files and re-generate them with every selected decoder.

    Steps, in order:
      1. Find the ``<checkpoint.name>_*.pth`` files next to ``config.checkpoint``
         whose ``extract_id`` is listed in ``config.decoders``.
      2. Build one encoder (weights taken from the first matching checkpoint)
         and one ``SampleGenerator`` per matching checkpoint.
      3. Load every ``.wav`` / ``.h5`` input, pass it through
         ``helper_functions.mu_law`` and stack the results into one tensor.
      4. Encode the stacked inputs, let each generator synthesise audio from
         the encodings, and write the results via ``save_audio`` to
         ``config.out_dir / <decoder id> / <input file name>.wav``.

    Args:
        config: Object providing the attributes read here: ``checkpoint``,
            ``decoders``, ``batch_size``, ``rate``, ``out_dir``,
            ``sample_dir`` (sequence of paths, or a single directory that is
            searched recursively), ``sample_len`` and ``split_size``.
            NOTE: ``config.sample_len`` is overwritten with the length of the
            first loaded file when it is falsy on entry.

    Raises:
        AssertionError: If no checkpoint matches or ``config.out_dir`` is None.
        RuntimeError: If there are no input files to process.
        Exception: If an input file is neither ``.wav`` nor ``.h5``.
    """
    print('Starting')

    checkpoints = config.checkpoint.parent.glob(config.checkpoint.name + '_*.pth')
    checkpoints = [c for c in checkpoints if extract_id(c) in config.decoders]
    assert len(checkpoints) >= 1, "No checkpoints found."

    # NOTE(security): torch.load unpickles its input; only point this script
    # at checkpoint files from a trusted source.
    model_config = torch.load(config.checkpoint.parent / 'args.pth')[0]
    encoder = Encoder(model_config.encoder)
    encoder.load_state_dict(torch.load(checkpoints[0])['encoder_state'])
    encoder.eval()
    encoder = encoder.cuda()

    # One generator per selected checkpoint; generator_ids[i] names generators[i].
    generators = []
    generator_ids = []
    for checkpoint in checkpoints:
        decoder = Decoder(model_config.decoder)
        decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])
        decoder.eval()
        decoder = decoder.cuda()
        generator = SampleGenerator(decoder, config.batch_size, wav_freq=config.rate)
        generators.append(generator)
        generator_ids.append(extract_id(checkpoint))

    xs = []
    assert config.out_dir is not None

    # A single directory argument is expanded recursively; otherwise the
    # given paths are used as they are.
    if len(config.sample_dir) == 1 and config.sample_dir[0].is_dir():
        top = config.sample_dir[0]
        file_paths = list(top.glob('**/*.wav')) + list(top.glob('**/*.h5'))
    else:
        file_paths = config.sample_dir
    print("File paths to be used:", file_paths)

    if len(file_paths) == 0:
        # Fail with a readable message instead of the opaque error that
        # torch.stack raises for an empty list further down.
        raise RuntimeError('No input files found in config.sample_dir.')

    for file_path in file_paths:
        if file_path.suffix == '.wav':
            data, _ = librosa.load(file_path, sr=config.rate)
            data = helper_functions.mu_law(data)
        elif file_path.suffix == '.h5':
            # Read the samples into memory and close the HDF5 handle right
            # away so that handles do not pile up across many input files.
            with h5py.File(file_path, 'r') as h5_file:
                raw = h5_file['wav'][:]
            data = helper_functions.mu_law(raw / (2**15))
            # Drop the trailing samples that do not fill a whole multiple of
            # config.rate.
            if data.shape[-1] % config.rate != 0:
                data = data[:-(data.shape[-1] % config.rate)]
            assert data.shape[-1] % config.rate == 0
            print(data.shape)
        else:
            raise Exception(f'Unsupported filetype {file_path}')

        if config.sample_len:
            data = data[:config.sample_len]
        else:
            # The first file's length becomes the limit for all later files.
            # NOTE(review): torch.stack below needs equally sized tensors, so
            # a later file shorter than this limit will presumably make the
            # stack fail -- confirm whether inputs are guaranteed long enough.
            config.sample_len = len(data)
        xs.append(torch.tensor(data).unsqueeze(0).float().cuda())

    xs = torch.stack(xs).contiguous()
    print(f'xs size: {xs.size()}')

    def save(x, decoder_idx, filepath):
        """Apply inv_mu_law to one generated sample and write it as a .wav file.

        The target is ``config.out_dir / str(decoder_idx)`` with the input
        file's name and a ``.wav`` suffix.
        """
        wav = helper_functions.inv_mu_law(x.cpu().numpy())
        print(f'X size: {x.shape}')
        print(f'X min: {x.min()}, max: {x.max()}')
        save_audio(wav.squeeze(),
                   config.out_dir / str(decoder_idx) / filepath.with_suffix('.wav').name,
                   rate=config.rate)

    yy = {}
    with torch.no_grad():
        # Encode all inputs once; every generator reuses the same encodings.
        zz = []
        for xs_batch in torch.split(xs, config.batch_size):
            zz += [encoder(xs_batch)]
        zz = torch.cat(zz, dim=0)

        for i, generator_id in enumerate(generator_ids):
            yy[generator_id] = []
            generator = generators[i]
            for zz_batch in torch.split(zz, config.batch_size):
                print("Batch shape:", zz_batch.shape)
                # Generate in chunks of split_size along the last axis and
                # join the chunks again afterwards.
                splits = torch.split(zz_batch, config.split_size, -1)
                audio_data = []
                generator.reset()
                for cond in tqdm.tqdm(splits):
                    audio_data += [generator.generate(cond).cpu()]
                audio_data = torch.cat(audio_data, -1)
                yy[generator_id] += [audio_data]
            yy[generator_id] = torch.cat(yy[generator_id], dim=0)

            for sample_result, filepath in zip(yy[generator_id], file_paths):
                save(sample_result, generator_id, filepath)

            # `del generator` alone only unbinds the local name; the list
            # entry must be cleared too or the generator stays referenced
            # for the rest of the run.
            generators[i] = None
            del generator