def get_voice_file(idx, duration, quantize_type):
    """Gets one of the last VCTK voices."""
    BASE_PATH = "/projects/grail/audiovisual/datasets/VCTK-Corpus/wav48/test"
    assert idx in range(100)
    if idx < 25:
        speaker_path = os.path.join(BASE_PATH, "p345")
    elif idx < 50:
        speaker_path = os.path.join(BASE_PATH, "p361")
    elif idx < 75:
        speaker_path = os.path.join(BASE_PATH, "p362")
    elif idx < 100:
        speaker_path = os.path.join(BASE_PATH, "p374")

    file_list = list(Path(speaker_path).rglob('*.wav'))
    curr_file = random.choice(file_list)

    y, sr = librosa.core.load(curr_file, sr=22050)
    y /= abs(y).max()

    # Take a window of `duration` samples centered in the clip
    start_idx = len(y) // 2
    y = y[int(start_idx - duration / 2):int(start_idx + duration / 2)]

    # Mu-law or linear quantization
    if quantize_type == 0:
        quantized = P.mulaw_quantize(y, hparams.quantize_channels - 1)
    elif quantize_type == 1:
        quantized = linear_quantize(y, hparams.quantize_channels - 1)
    else:
        raise ValueError("Unsupported quantize_type: {}".format(quantize_type))
    return quantized


def get_piano_file(idx, duration, quantize_type):
    """Gets one of the test supra piano samples."""
    BASE_PATH = "/projects/grail/audiovisual/datasets/supra-rw-mp3/test"
    file_list = list(Path(BASE_PATH).rglob("*.mp3"))
    curr_file = random.choice(file_list)

    y, sr = librosa.core.load(curr_file, sr=22050)
    y /= abs(y).max()

    # Take a random window of `duration` samples
    num_samples = y.shape[0]
    start_idx = random.randint(0, num_samples - duration)
    y = y[start_idx:start_idx + duration]

    # Mu-law or linear quantization
    if quantize_type == 0:
        quantized = P.mulaw_quantize(y, hparams.quantize_channels - 1)
    elif quantize_type == 1:
        quantized = linear_quantize(y, hparams.quantize_channels - 1)
    else:
        raise ValueError("Unsupported quantize_type: {}".format(quantize_type))
    return quantized


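# `linear_quantize` / `inv_linear_quantize` are imported from elsewhere in this
# repo and are not shown in this file. For orientation only, a generic uniform
# quantizer over [-1, 1] looks roughly like the sketch below; the exact
# offset/scaling used by the repo's helpers may differ (note that main() shifts
# the mixture by +1.0 before quantizing). The names here are hypothetical and
# assume the module-level `import numpy as np` used throughout this file.
def _example_linear_quantize(y, mu):
    """Illustrative uniform quantizer mapping [-1, 1] to integer bins [0, mu]."""
    y = np.clip(y, -1.0, 1.0)
    return np.floor((y + 1.0) / 2.0 * mu + 0.5).astype(np.int64)


def _example_inv_linear_quantize(q, mu):
    """Illustrative inverse of the uniform quantizer above."""
    return 2.0 * np.asarray(q, dtype=np.float64) / mu - 1.0

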
def main(args):
    model0 = ModelWrapper()
    model1 = ModelWrapper()
    receptive_field = model0.receptive_field

    writing_dir = args["<output-dir>"]
    os.makedirs(writing_dir, exist_ok=True)
    print("writing dir: {}".format(writing_dir))

    source1 = librosa.core.load(args["<input-file1>"], sr=22050, mono=True)[0]
    source2 = librosa.core.load(args["<input-file2>"], sr=22050, mono=True)[0]
    mixed = source1 + source2

    # Increase the volume of the mixture to avoid artifacts from linear encoding
    mixed /= abs(mixed).max()
    mixed *= 1.4
    mixed = linear_quantize(mixed + 1.0, hparams.quantize_channels - 1)

    global SAMPLE_SIZE
    if SAMPLE_SIZE == -1:
        SAMPLE_SIZE = int(mixed.shape[0])
    mixed = mixed[:SAMPLE_SIZE]
    mixed = torch.FloatTensor(mixed).reshape(1, -1).to(device)

    # Write inputs
    mixed_out = inv_linear_quantize(mixed[0].detach().cpu().numpy(),
                                    hparams.quantize_channels - 1) - 1.0
    mixed_out = np.clip(mixed_out, -1, 1)
    sf.write(join(writing_dir, "mixed.wav"), mixed_out, hparams.sample_rate)

    # Initialize with noise
    x0 = torch.FloatTensor(np.random.uniform(0, 512, size=(1, SAMPLE_SIZE))).to(device)
    x0[:] = mixed - 127.0
    x0 = F.pad(x0, (receptive_field, receptive_field), "constant", 127)
    x0.requires_grad = True

    x1 = torch.FloatTensor(np.random.uniform(0, 512, size=(1, SAMPLE_SIZE))).to(device)
    x1[:] = 127.
    x1 = F.pad(x1, (receptive_field, receptive_field), "constant", 127)
    x1.requires_grad = True

    sigmas = [
        175.9, 110., 68.7, 54.3, 42.9, 34.0, 26.8, 21.2, 16.8, 13.3, 10.5,
        8.29, 6.55, 5.18, 4.1, 3.24, 2.56, 1.6, 1.0, 0.625, 0.39, 0.244,
        0.15, 0.1
    ]

    np.random.seed(999)

    for idx, sigma in enumerate(sigmas):
        # We make sure each sample is updated a certain number of times
        n_steps = int((SAMPLE_SIZE / (SGLD_WINDOW * BATCH_SIZE)) * N_STEPS)
        print("Number of SGLD steps {}".format(n_steps))

        # Bump down a model
        checkpoint_path0 = join(args["<checkpoint0>"], CHECKPOINTS[sigma],
                                "checkpoint_latest_ema.pth")
        model0.load_checkpoint(checkpoint_path0)
        checkpoint_path1 = join(args["<checkpoint1>"], CHECKPOINTS[sigma],
                                "checkpoint_latest_ema.pth")
        model1.load_checkpoint(checkpoint_path1)

        parmodel0 = torch.nn.DataParallel(model0)
        parmodel0.to(device)
        parmodel1 = torch.nn.DataParallel(model1)
        parmodel1.to(device)

        eta = .05 * (sigma ** 2)
        gamma = 15 * (1.0 / sigma) ** 2

        t0 = time.time()
        for i in range(n_steps):
            # need to get a good sampling of the beginning/end (boundary effects)
            # to understand this: think about how often we would update x[receptive_field] (first point)
            # if we only sampled U(receptive_field, x0.shape - receptive_field - SGLD_WINDOW)
            j = np.random.randint(receptive_field - SGLD_WINDOW,
                                  x0.shape[1] - receptive_field, BATCH_SIZE)
            j = np.maximum(j, receptive_field)
            j = np.minimum(j, x0.shape[1] - (SGLD_WINDOW + receptive_field))

            # Seed with noised up silence
            x0[0, :receptive_field] = torch.FloatTensor(
                np.random.normal(127, sigma, mixed[0, :receptive_field].shape)).to(device)
            x0[0, -receptive_field:] = torch.FloatTensor(
                np.random.normal(127, sigma, mixed[0, -receptive_field:].shape)).to(device)
            x1[0, :receptive_field] = torch.FloatTensor(
                np.random.normal(127, sigma, mixed[0, :receptive_field].shape)).to(device)
            x1[0, -receptive_field:] = torch.FloatTensor(
                np.random.normal(127, sigma, mixed[0, -receptive_field:].shape)).to(device)

            # Gather padded patches for each sampled window
            patches0 = []
            patches1 = []
            mixpatch = []
            for k in range(BATCH_SIZE):
                patches0.append(x0[:, j[k] - receptive_field:j[k] + SGLD_WINDOW + receptive_field])
                patches1.append(x1[:, j[k] - receptive_field:j[k] + SGLD_WINDOW + receptive_field])
                mixpatch.append(mixed[:, j[k] - receptive_field:j[k] - receptive_field + SGLD_WINDOW])

            patches0 = torch.stack(patches0, dim=0)
            patches1 = torch.stack(patches1, dim=0)
            mixpatch = torch.stack(mixpatch, dim=0)

            # Forward pass
            log_prob, prediction0 = parmodel0(patches0, sigma=sigma)
            log_prob0 = torch.sum(log_prob)
            grad0 = torch.autograd.grad(log_prob0, x0)[0]

            log_prob, prediction1 = parmodel1(patches1, sigma=sigma)
            log_prob1 = torch.sum(log_prob)
            grad1 = torch.autograd.grad(log_prob1, x1)[0]

            x0_update, x1_update = [], []
            for k in range(BATCH_SIZE):
                x0_update.append(eta * grad0[:, j[k]:j[k] + SGLD_WINDOW])
                x1_update.append(eta * grad1[:, j[k]:j[k] + SGLD_WINDOW])

            # Langevin step
            for k in range(BATCH_SIZE):
                epsilon0 = np.sqrt(2 * eta) * torch.normal(
                    0, 1, size=(1, SGLD_WINDOW), device=device)
                x0_update[k] += epsilon0

                epsilon1 = np.sqrt(2 * eta) * torch.normal(
                    0, 1, size=(1, SGLD_WINDOW), device=device)
                x1_update[k] += epsilon1

            # Reconstruction step
            for k in range(BATCH_SIZE):
                x0_update[k] -= eta * gamma * (
                    patches0[k][:, receptive_field:receptive_field + SGLD_WINDOW] +
                    patches1[k][:, receptive_field:receptive_field + SGLD_WINDOW] -
                    mixpatch[k])
                x1_update[k] -= eta * gamma * (
                    patches0[k][:, receptive_field:receptive_field + SGLD_WINDOW] +
                    patches1[k][:, receptive_field:receptive_field + SGLD_WINDOW] -
                    mixpatch[k])

            with torch.no_grad():
                for k in range(BATCH_SIZE):
                    x0[:, j[k]:j[k] + SGLD_WINDOW] += x0_update[k]
                    x1[:, j[k]:j[k] + SGLD_WINDOW] += x1_update[k]

            if (not i % 40) or (i == (n_steps - 1)):  # debugging
                print("--------------")
                print('sigma = {}'.format(sigma))
                print('eta = {}'.format(eta))
                print("i {}".format(i))
                print("Max sample {}".format(abs(x0).max()))
                print('Mean sample logpx: {}'.format(log_prob0 / (BATCH_SIZE * SGLD_WINDOW)))
                print('Mean sample logpy: {}'.format(log_prob1 / (BATCH_SIZE * SGLD_WINDOW)))
                print("Max gradient update: {}".format(eta * abs(grad0).max()))
                print("Reconstruction: {}".format(
                    abs(x0[:, receptive_field:-receptive_field] +
                        x1[:, receptive_field:-receptive_field] - mixed).mean()))
                print('Elapsed time = {}'.format(time.time() - t0))
                t0 = time.time()

        # Write out both separated sources for this noise level
        out0 = inv_linear_quantize(x0[0].detach().cpu().numpy(),
                                   hparams.quantize_channels - 1)
        out0 = np.clip(out0, -1, 1)
        sf.write(join(writing_dir, "out0_{}.wav".format(sigma)), out0, hparams.sample_rate)

        out1 = inv_linear_quantize(x1[0].detach().cpu().numpy(),
                                   hparams.quantize_channels - 1)
        out1 = np.clip(out1, -1, 1)
        sf.write(join(writing_dir, "out1_{}.wav".format(sigma)), out1, hparams.sample_rate)


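# Summary of the per-window update performed in main() above. For each sampled
# window, with step size eta = 0.05 * sigma^2, penalty weight gamma = 15 / sigma^2,
# mixture m, and z0, z1 ~ N(0, I), the two sources are moved by
#
#   x0 <- x0 + eta * d/dx0 log p_sigma(x0) + sqrt(2 * eta) * z0
#            - eta * gamma * (x0 + x1 - m)
#   x1 <- x1 + eta * d/dx1 log p_sigma(x1) + sqrt(2 * eta) * z1
#            - eta * gamma * (x0 + x1 - m)
#
# i.e. noise-annealed Langevin dynamics on each source under its own prior,
# plus a quadratic penalty tying the sum of the sources to the observed
# mixture. The sigma schedule anneals this from coarse to fine.

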
def _process_utterance(out_dir, index, wav_path, text, no_mel):
    # Load the audio to a numpy array:
    wav = audio.load_wav(wav_path)

    # Trim begin/end silences
    # NOTE: the threshold was chosen for clean signals
    wav, _ = librosa.effects.trim(wav, top_db=60, frame_length=2048, hop_length=512)

    if hparams.highpass_cutoff > 0.0:
        wav = audio.low_cut_filter(wav, hparams.sample_rate, hparams.highpass_cutoff)

    # Mu-law quantize
    if is_mulaw_quantize(hparams.input_type):
        # Trim silences in mu-law quantized domain
        silence_threshold = 0
        if silence_threshold > 0:
            # [0, quantize_channels)
            out = P.mulaw_quantize(wav, hparams.quantize_channels - 1)
            start, end = audio.start_and_end_indices(out, silence_threshold)
            wav = wav[start:end]
        constant_values = P.mulaw_quantize(0, hparams.quantize_channels - 1)
        out_dtype = np.int16
    elif is_linear_quantize(hparams.input_type):
        # Trim silences in linear quantized domain
        silence_threshold = 0
        if silence_threshold > 0:
            # [0, quantize_channels)
            out = linear_quantize(wav, hparams.quantize_channels - 1)
            start, end = audio.start_and_end_indices(out, silence_threshold)
            wav = wav[start:end]
        constant_values = linear_quantize(0, hparams.quantize_channels - 1)
        out_dtype = np.int16
    elif is_mulaw(hparams.input_type):
        # [-1, 1]
        constant_values = P.mulaw(0.0, hparams.quantize_channels - 1)
        out_dtype = np.float32
    else:
        # [-1, 1]
        constant_values = 0.0
        out_dtype = np.float32

    if hparams.global_gain_scale > 0:
        wav *= hparams.global_gain_scale

    if hparams.normalize_max_audio:
        wav /= abs(wav).max()

    # Compute a mel-scale spectrogram from the trimmed wav:
    # (N, D)
    if not no_mel:
        mel_spectrogram = audio.logmelspectrogram(wav).astype(np.float32).T

    # Time domain preprocessing
    if hparams.preprocess is not None and hparams.preprocess not in ["", "none"]:
        f = getattr(audio, hparams.preprocess)
        wav = f(wav)

    # Clip
    if np.abs(wav).max() > 1.0:
        print("Warning: abs max value exceeds 1.0: {}".format(np.abs(wav).max()))
        # ignore this sample
        # return ("dummy", "dummy", -1, "dummy")

    wav = np.clip(wav, -1.0, 1.0)

    # Set waveform target (out)
    if is_mulaw_quantize(hparams.input_type):
        out = P.mulaw_quantize(wav, hparams.quantize_channels - 1)
    elif is_linear_quantize(hparams.input_type):
        out = linear_quantize(wav, hparams.quantize_channels - 1)
    elif is_mulaw(hparams.input_type):
        out = P.mulaw(wav, hparams.quantize_channels - 1)
    else:
        out = wav

    # zero pad
    # this is needed to adjust time resolution between audio and mel-spectrogram
    if not no_mel:
        l, r = audio.pad_lr(out, hparams.fft_size, audio.get_hop_size())
        if l > 0 or r > 0:
            out = np.pad(out, (l, r), mode="constant", constant_values=constant_values)
        N = mel_spectrogram.shape[0]
        assert len(out) >= N * audio.get_hop_size()

        # time resolution adjustment
        # ensure length of raw audio is multiple of hop_size so that we can use
        # transposed convolution to upsample
        out = out[:N * audio.get_hop_size()]
        assert len(out) % audio.get_hop_size() == 0
    else:
        # No mel features: fall back to the raw sample count so the returned N
        # is still defined for the metadata tuple.
        N = len(out)

    # Write the spectrograms to disk:
    name = splitext(basename(wav_path))[0]
    audio_filename = '{}-{}-wave.npy'.format(name, index)
    np.save(os.path.join(out_dir, audio_filename),
            out.astype(out_dtype), allow_pickle=False)
    if not no_mel:
        mel_filename = '{}-{}-feats.npy'.format(name, index)
        np.save(os.path.join(out_dir, mel_filename),
                mel_spectrogram.astype(np.float32), allow_pickle=False)
    else:
        mel_filename = ""

    # Return a tuple describing this training example:
    return (audio_filename, mel_filename, N, text)


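# Worked example of the audio/mel alignment enforced in _process_utterance
# (illustrative numbers only; the real hop size comes from hparams/audio):
# with a hop size of 256 and N = 40 mel frames, the saved waveform is padded
# and then trimmed to exactly N * 256 = 10240 samples, so a transposed
# convolution with stride 256 can upsample the N conditioning frames onto the
# audio samples one-to-one.

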
def eval_model(global_step, writer, device, model, y, c, g, input_lengths,
               eval_dir, ema=None):
    if ema is not None:
        print("Using averaged model for evaluation")
        model = clone_as_averaged_model(device, model, ema)
        model.make_generation_fast_()

    model.eval()
    idx = np.random.randint(0, len(y))
    length = input_lengths[idx].data.cpu().item()

    # (T,)
    y_target = y[idx].view(-1).data.cpu().numpy()[:length]

    if c is not None:
        if hparams.upsample_conditional_features:
            c = c[idx, :, :length // audio.get_hop_size() + hparams.cin_pad * 2].unsqueeze(0)
        else:
            c = c[idx, :, :length].unsqueeze(0)
        assert c.dim() == 3
        print("Shape of local conditioning features: {}".format(c.size()))
    if g is not None:
        # TODO: test
        g = g[idx]
        print("Shape of global conditioning features: {}".format(g.size()))

    # Dummy silence
    if is_mulaw_quantize(hparams.input_type):
        initial_value = P.mulaw_quantize(0, hparams.quantize_channels - 1)
    elif is_linear_quantize(hparams.input_type):
        initial_value = linear_quantize(0, hparams.quantize_channels - 1)
    elif is_mulaw(hparams.input_type):
        initial_value = P.mulaw(0.0, hparams.quantize_channels)
    else:
        initial_value = 0.0

    # (C,)
    if (is_mulaw_quantize(hparams.input_type) or
            is_linear_quantize(hparams.input_type)) and not hparams.manual_scalar_input:
        initial_input = to_categorical(
            initial_value, num_classes=hparams.quantize_channels).astype(np.float32)
        initial_input = torch.from_numpy(initial_input).view(1, 1, hparams.quantize_channels)
    else:
        initial_input = torch.zeros(1, 1, 1).fill_(initial_value)
    initial_input = initial_input.to(device)

    # Run the model in fast eval mode
    with torch.no_grad():
        y_hat = model.incremental_forward(
            initial_input, c=c, g=g, T=length, softmax=True, quantize=True,
            tqdm=tqdm, log_scale_min=hparams.log_scale_min)

    if is_mulaw_quantize(hparams.input_type):
        y_hat = y_hat.max(1)[1].view(-1).long().cpu().data.numpy()
        y_hat = P.inv_mulaw_quantize(y_hat, hparams.quantize_channels - 1)
        y_target = P.inv_mulaw_quantize(y_target, hparams.quantize_channels - 1)
    elif is_linear_quantize(hparams.input_type):
        y_hat = y_hat.max(1)[1].view(-1).long().cpu().data.numpy()
        y_hat = inv_linear_quantize(y_hat, hparams.quantize_channels - 1)
        y_target = inv_linear_quantize(y_target, hparams.quantize_channels - 1)
    elif is_mulaw(hparams.input_type):
        y_hat = P.inv_mulaw(y_hat.view(-1).cpu().data.numpy(), hparams.quantize_channels)
        y_target = P.inv_mulaw(y_target, hparams.quantize_channels)
    else:
        y_hat = y_hat.view(-1).cpu().data.numpy()

    # Save audio
    # NOTE: librosa.output was removed in librosa 0.8.0, so this path requires
    # an older librosa (soundfile.write is the usual replacement).
    os.makedirs(eval_dir, exist_ok=True)
    path = join(eval_dir, "step{:09d}_predicted.wav".format(global_step))
    librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate)
    path = join(eval_dir, "step{:09d}_target.wav".format(global_step))
    librosa.output.write_wav(path, y_target, sr=hparams.sample_rate)

    # save figure
    path = join(eval_dir, "step{:09d}_waveplots.png".format(global_step))
    save_waveplot(path, y_hat, y_target)


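# `to_categorical` is defined/imported elsewhere in this repo (a keras-style
# one-hot helper). For reference only, an equivalent numpy sketch consistent
# with how it is called in eval_model and collate_fn; the name is hypothetical:
def _example_to_categorical(y, num_classes):
    """Illustrative one-hot encoder: (T,) integer labels -> (T, num_classes) float32."""
    y = np.asarray(y, dtype=np.int64).reshape(-1)
    out = np.zeros((y.shape[0], num_classes), dtype=np.float32)
    out[np.arange(y.shape[0]), y] = 1.0
    return out

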
def collate_fn(batch):
    """Create batch

    Args:
        batch(tuple): List of tuples
            - x[0] (ndarray,int) : list of (T,)
            - x[1] (ndarray,int) : list of (T, D)
            - x[2] (ndarray,int) : list of (1,), speaker id
    Returns:
        tuple: Tuple of batch
            - x (FloatTensor) : Network inputs (B, C, T)
            - y (LongTensor)  : Network targets (B, T, 1)
    """
    local_conditioning = len(batch[0]) >= 2 and hparams.cin_channels > 0
    global_conditioning = len(batch[0]) >= 3 and hparams.gin_channels > 0

    if hparams.max_time_sec is not None:
        max_time_steps = int(hparams.max_time_sec * hparams.sample_rate)
    elif hparams.max_time_steps is not None:
        max_time_steps = hparams.max_time_steps
    else:
        max_time_steps = None

    # Time resolution adjustment
    cin_pad = hparams.cin_pad
    if local_conditioning:
        new_batch = []
        for idx in range(len(batch)):
            x, c, g = batch[idx]
            if hparams.upsample_conditional_features:
                assert_ready_for_upsampling(x, c, cin_pad=0)
                if max_time_steps is not None:
                    max_steps = ensure_divisible(max_time_steps, audio.get_hop_size(), True)
                    if len(x) > max_steps:
                        max_time_frames = max_steps // audio.get_hop_size()
                        s = np.random.randint(cin_pad, len(c) - max_time_frames - cin_pad)
                        ts = s * audio.get_hop_size()
                        x = x[ts:ts + audio.get_hop_size() * max_time_frames]
                        c = c[s - cin_pad:s + max_time_frames + cin_pad, :]
                        assert_ready_for_upsampling(x, c, cin_pad=cin_pad)
            else:
                x, c = audio.adjust_time_resolution(x, c)
                if max_time_steps is not None and len(x) > max_time_steps:
                    s = np.random.randint(cin_pad, len(x) - max_time_steps - cin_pad)
                    x = x[s:s + max_time_steps]
                    c = c[s - cin_pad:s + max_time_steps + cin_pad, :]
                assert len(x) == len(c)
            new_batch.append((x, c, g))
        batch = new_batch
    else:
        new_batch = []
        for idx in range(len(batch)):
            x, c, g = batch[idx]
            x = audio.trim(x)
            if max_time_steps is not None and len(x) > max_time_steps:
                s = np.random.randint(0, len(x) - max_time_steps)
                if local_conditioning:
                    x, c = x[s:s + max_time_steps], c[s:s + max_time_steps, :]
                else:
                    x = x[s:s + max_time_steps]
            new_batch.append((x, c, g))
        batch = new_batch

    # Lengths
    input_lengths = [len(x[0]) for x in batch]
    max_input_len = max(input_lengths)

    # (B, T, C)
    # pad for time-axis
    if is_mulaw_quantize(hparams.input_type):
        padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels - 1)
        if hparams.manual_scalar_input:
            x_batch = np.array([
                _pad_2d(x[0].reshape(-1, 1), max_input_len, 0, padding_value)
                for x in batch
            ], dtype=np.float32)
        else:
            x_batch = np.array([
                _pad_2d(to_categorical(x[0], num_classes=hparams.quantize_channels),
                        max_input_len, 0, padding_value) for x in batch
            ], dtype=np.float32)
    elif is_linear_quantize(hparams.input_type):
        padding_value = linear_quantize(0, hparams.quantize_channels - 1)
        if hparams.manual_scalar_input:
            x_batch = np.array([
                _pad_2d(x[0].reshape(-1, 1), max_input_len, 0, padding_value)
                for x in batch
            ], dtype=np.float32)
        else:
            x_batch = np.array([
                _pad_2d(to_categorical(x[0], num_classes=hparams.quantize_channels),
                        max_input_len, 0, padding_value) for x in batch
            ], dtype=np.float32)
    else:
        x_batch = np.array(
            [_pad_2d(x[0].reshape(-1, 1), max_input_len) for x in batch],
            dtype=np.float32)
    assert len(x_batch.shape) == 3

    # (B, T)
    # NOTE: np.int is removed in recent NumPy; a concrete integer dtype is used instead
    if is_mulaw_quantize(hparams.input_type):
        padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels - 1)
        y_batch = np.array([
            _pad(x[0], max_input_len, constant_values=padding_value)
            for x in batch
        ], dtype=np.int64)
    elif is_linear_quantize(hparams.input_type):
        padding_value = linear_quantize(0, hparams.quantize_channels - 1)
        y_batch = np.array([
            _pad(x[0], max_input_len, constant_values=padding_value)
            for x in batch
        ], dtype=np.int64)
    else:
        y_batch = np.array([_pad(x[0], max_input_len) for x in batch],
                           dtype=np.float32)
    assert len(y_batch.shape) == 2

    # (B, T, D)
    if local_conditioning:
        max_len = max([len(x[1]) for x in batch])
        c_batch = np.array([_pad_2d(x[1], max_len) for x in batch],
                           dtype=np.float32)
        assert len(c_batch.shape) == 3
        # (B x C x T)
        c_batch = torch.FloatTensor(c_batch).transpose(1, 2).contiguous()
    else:
        c_batch = None

    if global_conditioning:
        g_batch = torch.LongTensor([x[2] for x in batch])
    else:
        g_batch = None

    # Convert to channel first i.e., (B, C, T)
    x_batch = torch.FloatTensor(x_batch).transpose(1, 2).contiguous()

    # Add extra axis
    if is_mulaw_quantize(hparams.input_type) or is_linear_quantize(hparams.input_type):
        y_batch = torch.LongTensor(y_batch).unsqueeze(-1).contiguous()
    else:
        y_batch = torch.FloatTensor(y_batch).unsqueeze(-1).contiguous()

    input_lengths = torch.LongTensor(input_lengths)

    return x_batch, y_batch, c_batch, g_batch, input_lengths


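# `_pad` and `_pad_2d` are defined elsewhere in this file. Illustrative versions
# consistent with how they are called in collate_fn (names prefixed to mark them
# as sketches, not the repo's definitions):
def _example_pad(x, max_len, constant_values=0):
    """Pad a 1-D array at the end up to max_len."""
    return np.pad(x, (0, max_len - len(x)), mode="constant",
                  constant_values=constant_values)


def _example_pad_2d(x, max_len, b_pad=0, constant_values=0):
    """Pad a (T, D) array along the time axis to max_len, with b_pad frames in front."""
    return np.pad(x, [(b_pad, max_len - len(x) - b_pad), (0, 0)],
                  mode="constant", constant_values=constant_values)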