def save_states(global_step, writer, y_hat, student_hat, y, input_lengths, checkpoint_dir=None):
    """Dump teacher, student and target audio for one random batch item.

    A random index into the batch is drawn, the teacher output ``y_hat`` is
    turned into a waveform (arg-max for mu-law-quantized input, otherwise a
    sample from the predicted distribution), everything past the item's
    length is zeroed, and three WAV files are written to
    ``<checkpoint_dir>/audio``.

    Args:
        global_step: Step number, used in the log line and the file names.
        writer: Not used by this function; kept for signature compatibility.
        y_hat: Teacher network output for the batch.
        student_hat: Student output for the batch.
        y: Target batch.
        input_lengths: Per-item lengths; the selected item's value is the
            masking boundary.
        checkpoint_dir: Base output directory. It is passed straight to
            ``join``, so a real path is required despite the ``None`` default.
    """
    print("Save intermediate states at step {}".format(global_step))
    idx = np.random.randint(0, len(y_hat))
    length = input_lengths[idx].data.cpu().item()

    # (B, C, T)
    if y_hat.dim() == 4:
        y_hat = y_hat.squeeze(-1)

    # (T,)
    # Select and flatten the student sample up front so that BOTH input-type
    # branches mask and write a 1-D NumPy waveform. Previously the mu-law
    # quantize branch left ``student_hat`` as the full batch tensor.
    student_hat = student_hat[idx].view(-1).data.cpu().numpy()

    if is_mulaw_quantize(hparams.input_type):
        # (B, T)
        y_hat = F.softmax(y_hat, dim=1).max(1)[1]
        # (T,)
        y_hat = y_hat[idx].data.cpu().long().numpy()
        y = y[idx].view(-1).data.cpu().long().numpy()
        y_hat = P.inv_mulaw_quantize(y_hat, hparams.quantize_channels)
        y = P.inv_mulaw_quantize(y, hparams.quantize_channels)
        # NOTE(review): the student samples are written as-is in this mode;
        # confirm whether they also need mu-law decoding here.
    else:
        # (B, T)
        if hparams.use_gaussian:
            y_hat = y_hat.transpose(1, 2)
            y_hat = sample_from_gaussian(y_hat, log_scale_min=hparams.log_scale_min)
        else:
            y_hat = sample_from_discretized_mix_logistic(
                y_hat, log_scale_min=hparams.log_scale_min)
        # (T,)
        y_hat = y_hat[idx].view(-1).data.cpu().numpy()
        y = y[idx].view(-1).data.cpu().numpy()

        if is_mulaw(hparams.input_type):
            y_hat = P.inv_mulaw(y_hat, hparams.quantize_channels)
            y = P.inv_mulaw(y, hparams.quantize_channels)
            student_hat = P.inv_mulaw(student_hat, hparams.quantize_channels)

    # Mask by length
    y_hat[length:] = 0
    y[length:] = 0
    student_hat[length:] = 0

    # Save audio
    audio_dir = join(checkpoint_dir, "audio")
    os.makedirs(audio_dir, exist_ok=True)
    path = join(audio_dir, "step{:09d}_teacher.wav".format(global_step))
    librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate)
    path = join(audio_dir, "step{:09d}_student.wav".format(global_step))
    librosa.output.write_wav(path, student_hat, sr=hparams.sample_rate)
    path = join(audio_dir, "step{:09d}_target.wav".format(global_step))
    librosa.output.write_wav(path, y, sr=hparams.sample_rate)
def save_states(global_step, writer, y_hat, y, input_lengths, checkpoint_dir=None):
    """Write predicted and target audio for one randomly chosen batch item.

    The network output ``y_hat`` is converted to a waveform (arg-max over the
    softmax for mu-law-quantized input, otherwise a sample from the configured
    output distribution), both waveforms are zeroed past the item's length,
    and two WAV files are written under
    ``<checkpoint_dir>/intermediate/audio`` with ``sf.write``.

    Args:
        global_step: Step number, used in the log line and the file names.
        writer: Not used by this function.
        y_hat: Network output for the batch.
        y: Target batch.
        input_lengths: Per-item lengths; the chosen item's value is the
            masking boundary.
        checkpoint_dir: Base output directory (passed to ``join``).
    """
    print("Save intermediate states at step {}".format(global_step))
    idx = np.random.randint(0, len(y_hat))
    length = input_lengths[idx].data.cpu().item()

    # Drop a trailing singleton axis: 4-D output -> (B, C, T)
    if y_hat.dim() == 4:
        y_hat = y_hat.squeeze(-1)

    if is_mulaw_quantize(wavenet_hparams.input_type):
        # Most likely class per time step: (B, T)
        class_ids = F.softmax(y_hat, dim=1).max(1)[1]
        # Selected item only: (T,)
        y_hat = class_ids[idx].data.cpu().long().numpy()
        y = y[idx].view(-1).data.cpu().long().numpy()
        y_hat = P.inv_mulaw_quantize(y_hat, wavenet_hparams.quantize_channels - 1)
        y = P.inv_mulaw_quantize(y, wavenet_hparams.quantize_channels - 1)
    else:
        # Draw a waveform from the predicted distribution: (B, T)
        if wavenet_hparams.output_distribution == "Logistic":
            sampled = sample_from_discretized_mix_logistic(
                y_hat, log_scale_min=wavenet_hparams.log_scale_min)
        elif wavenet_hparams.output_distribution == "Normal":
            sampled = sample_from_mix_gaussian(
                y_hat, log_scale_min=wavenet_hparams.log_scale_min)
        else:
            assert False
        # Selected item only: (T,)
        y_hat = sampled[idx].view(-1).data.cpu().numpy()
        y = y[idx].view(-1).data.cpu().numpy()

        if is_mulaw(wavenet_hparams.input_type):
            y_hat = P.inv_mulaw(y_hat, wavenet_hparams.quantize_channels)
            y = P.inv_mulaw(y, wavenet_hparams.quantize_channels)

    # Silence everything past the valid length
    for waveform in (y_hat, y):
        waveform[length:] = 0

    # Save audio (the original kept the librosa.output.write_wav calls
    # commented out next to these sf.write calls)
    audio_dir = join(checkpoint_dir, "intermediate", "audio")
    os.makedirs(audio_dir, exist_ok=True)
    outputs = (
        ("step{:09d}_predicted.wav", y_hat),
        ("step{:09d}_target.wav", y),
    )
    for name_template, waveform in outputs:
        path = join(audio_dir, name_template.format(global_step))
        sf.write(path, waveform, samplerate=wavenet_hparams.sample_rate)
def test_mixture():
    """Smoke-test the discretized mix-logistic loss and sampler on real audio.

    Loads the pysptk example recording as the target, pairs it with a random
    30-channel parameter tensor, and checks that the unreduced loss has the
    same size as the target.
    """
    np.random.seed(1234)

    waveform, sample_rate = librosa.load(pysptk.util.example_audio_file(), sr=None)
    assert sample_rate == 16000

    n_samples = len(waveform)
    # Target is laid out as (1, T, 1)
    target = Variable(torch.from_numpy(waveform.reshape(1, n_samples, 1))).float()
    # Random distribution parameters laid out as (1, 30, T)
    params = Variable(torch.rand(1, 30, n_samples)).float()
    print(target.shape, params.shape)

    reduced_loss = discretized_mix_logistic_loss(params, target)
    print(reduced_loss)

    elementwise_loss = discretized_mix_logistic_loss(params, target, reduce=False)
    print(elementwise_loss.size(), target.size())
    assert elementwise_loss.size() == target.size()

    sampled = sample_from_discretized_mix_logistic(params)
    print(sampled.shape)
def save_states(global_step, writer, y_hat, y, y_student, input_lengths, mu=None,
                checkpoint_dir=None):
    """Save teacher/target/student audio and a waveform plot for one batch item.

    A random batch index is drawn. The teacher output is converted to a
    waveform (arg-max for mu-law-quantized input, otherwise a sample from the
    discretized mix-logistic), and teacher and target are zeroed past the
    item's length. WAV files are written every 1000 steps and a waveform
    plot every 200 steps, all under ``<checkpoint_dir>/audio``.

    :param global_step: step number used for the schedule and file names
    :param writer: not used by this function
    :param y_hat: parameters output by the teacher (y_hat is the teacher result)
    :param y: target
    :param y_student: student output
    :param input_lengths: per-item lengths; the chosen item's value is the
        masking boundary
    :param mu: student mu; may be ``None``
    :param checkpoint_dir: base output directory (passed to ``join``)
    :return: None
    """
    print("Save intermediate states at step {}".format(global_step))
    idx = np.random.randint(0, len(y_hat))
    length = input_lengths[idx].data.cpu().numpy()
    if mu is not None:
        # Only convert when a mu was actually supplied; ``None`` is the default
        # and is forwarded unchanged to the plot helper.
        # NOTE(review): confirm save_waveplot accepts student_mu=None.
        mu = to_numpy(mu[idx])

    # (B, C, T)
    if y_hat.dim() == 4:
        y_hat = y_hat.squeeze(-1)

    if is_mulaw_quantize(hparams.input_type):
        # (B, T)
        y_hat = F.softmax(y_hat, dim=1).max(1)[1]
        # (T,)
        y_hat = y_hat[idx].data.cpu().long().numpy()
        y = y[idx].view(-1).data.cpu().long().numpy()
        y_hat = P.inv_mulaw_quantize(y_hat, hparams.quantize_channels)
        y = P.inv_mulaw_quantize(y, hparams.quantize_channels)
    else:
        # (B, T)
        y_hat = sample_from_discretized_mix_logistic(
            y_hat, log_scale_min=hparams.log_scale_min)
        # (T,)
        y_hat = y_hat[idx].view(-1).data.cpu().numpy()
        y = y[idx].view(-1).data.cpu().numpy()

        if is_mulaw(hparams.input_type):
            y_hat = P.inv_mulaw(y_hat, hparams.quantize_channels)
            y = P.inv_mulaw(y, hparams.quantize_channels)

    # Mask by length (the student waveform is intentionally left unmasked,
    # as in the original)
    y_hat[length:] = 0
    y[length:] = 0

    y_student = y_student.data.cpu().numpy()
    y_student = y_student[idx].reshape(y_student.shape[-1])

    audio_dir = join(checkpoint_dir, "audio")
    save_audio = global_step % 1000 == 0
    save_plot = global_step % 200 == 0

    # Both outputs land in audio_dir, so create it whenever either one is
    # due. Previously it was only created on the 1000-step branch, so the
    # 200-step plot could target a directory that did not exist yet.
    if save_audio or save_plot:
        os.makedirs(audio_dir, exist_ok=True)

    # Save audio
    if save_audio:
        path = join(audio_dir, "step{:09d}_teacher.wav".format(global_step))
        librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate)
        path = join(audio_dir, "step{:09d}_target.wav".format(global_step))
        librosa.output.write_wav(path, y, sr=hparams.sample_rate)
        path = join(audio_dir, "step{:09d}_student.wav".format(global_step))
        librosa.output.write_wav(path, y_student, sr=hparams.sample_rate)

    # TODO save every 200 step,
    if save_plot:
        path = join(audio_dir, "wave_step{:09d}.png".format(global_step))
        save_waveplot(path, y_student=y_student, y_target=y, y_teacher=y_hat,
                      student_mu=mu)
def incremental_forward(self, initial_input=None, c=None, g=None,
                        T=100, test_inputs=None,
                        tqdm=lambda x: x, softmax=True, quantize=True,
                        log_scale_min=-50.0):
    """Incremental forward step

    Due to linearized convolutions, inputs of shape (B x C x T) are reshaped
    to (B x T x C) internally and fed to the network for each time step.
    Input of each time step will be of shape (B x 1 x C).

    Args:
        initial_input (Tensor): Initial decoder input, (B x C x 1)
        c (Tensor): Local conditioning features, shape (B x C' x T)
        g (Tensor): Global conditioning features, shape (B x C'' or B x C''x 1)
        T (int): Number of time steps to generate.
        test_inputs (Tensor): Teacher forcing inputs (for debugging)
        tqdm (lambda) : tqdm
        softmax (bool) : Whether applies softmax or not
        quantize (bool): Whether quantize softmax output before feeding the
          network output to input for the next time step. TODO: rename
        log_scale_min (float): Log scale minimum value.

    Returns:
        Tensor: Generated one-hot encoded samples. B x C x T
          or scalar vector B x 1 x T
    """
    self.clear_buffer()
    B = 1

    # Note: shape should be **(B x T x C)**, not (B x C x T) opposed to
    # batch forward due to linearized convolution
    if test_inputs is not None:
        if self.scalar_input:
            if test_inputs.size(1) == 1:
                test_inputs = test_inputs.transpose(1, 2).contiguous()
        else:
            if test_inputs.size(1) == self.out_channels:
                test_inputs = test_inputs.transpose(1, 2).contiguous()

        B = test_inputs.size(0)
        if T is None:
            T = test_inputs.size(1)
        else:
            T = max(T, test_inputs.size(1))
    # cast to int in case of numpy.int64...
    T = int(T)

    # Local conditioning, when given, decides the batch size. Resolve it
    # BEFORE global conditioning so that g is reshaped and expanded with the
    # final B (previously g could be processed with the default B = 1 while
    # the rest of the loop used c's batch size).
    if c is not None:
        B = c.shape[0]

    # Global conditioning
    g_btc = None
    if g is not None:
        if self.embed_speakers is not None:
            g = self.embed_speakers(g.view(B, -1))
            # (B x gin_channels, 1)
            g = g.transpose(1, 2)
            assert g.dim() == 3
        g_btc = _expand_global_features(B, T, g, bct=False)

    # Local conditioning
    if c is not None:
        if self.upsample_net is not None:
            c = self.upsample_net(c)
            assert c.size(-1) == T
        # Bring time to axis 1 when it is currently last: (B, C', T) -> (B, T, C')
        if c.size(-1) == T:
            c = c.transpose(1, 2).contiguous()

    outputs = []
    if initial_input is None:
        if self.scalar_input:
            initial_input = torch.zeros(B, 1, 1)
        else:
            initial_input = torch.zeros(B, 1, self.out_channels)
            initial_input[:, :, 127] = 1  # TODO: is this ok?
        # https://github.com/pytorch/pytorch/issues/584#issuecomment-275169567
        if next(self.parameters()).is_cuda:
            initial_input = initial_input.cuda()
    else:
        if initial_input.size(1) == self.out_channels:
            initial_input = initial_input.transpose(1, 2).contiguous()

    current_input = initial_input

    for t in tqdm(range(T)):
        # Teacher-forced input while available, otherwise feed back the
        # previous output (the initial input is used at t == 0).
        if test_inputs is not None and t < test_inputs.size(1):
            current_input = test_inputs[:, t, :].unsqueeze(1)
        else:
            if t > 0:
                current_input = outputs[-1]

        # Conditioning features for single time step
        ct = None if c is None else c[:, t, :].unsqueeze(1)
        gt = None if g is None else g_btc[:, t, :].unsqueeze(1)

        x = current_input
        x = self.first_conv.incremental_forward(x)
        skips = 0
        for f in self.conv_layers:
            x, h = f.incremental_forward(x, ct, gt)
            skips += h
        skips *= math.sqrt(1.0 / len(self.conv_layers))
        x = skips
        for f in self.last_conv_layers:
            # Layers without an incremental implementation are applied directly.
            try:
                x = f.incremental_forward(x)
            except AttributeError:
                x = f(x)

        # Generate next input by sampling
        if self.scalar_input:
            if self.output_distribution == "Logistic":
                x = sample_from_discretized_mix_logistic(
                    x.view(B, -1, 1), log_scale_min=log_scale_min)
            elif self.output_distribution == "Normal":
                x = sample_from_mix_gaussian(
                    x.view(B, -1, 1), log_scale_min=log_scale_min)
            else:
                assert False, "unsupported output_distribution: {}".format(
                    self.output_distribution)
        else:
            x = F.softmax(x.view(B, -1), dim=1) if softmax else x.view(B, -1)
            if quantize:
                dist = torch.distributions.OneHotCategorical(x)
                x = dist.sample()
        outputs += [x.data]

    # T x B x C
    outputs = torch.stack(outputs)
    # B x C x T
    outputs = outputs.transpose(0, 1).transpose(1, 2).contiguous()

    self.clear_buffer()
    return outputs
def _scalar_value(value):
    """Return a single-element (or 0-dim) tensor/Variable as a Python float."""
    return float(value.data.view(-1)[0])


def __train_step(phase, epoch, global_step, global_test_step,
                 teacher, student, optimizer, writer,
                 x, y, c, g, input_lengths,
                 checkpoint_dir, eval_dir=None, do_eval=False, ema=None):
    """Run one student-distillation step against a fixed teacher.

    The student maps logistic noise to ``(mu, scale)``. Five waveforms are
    sampled as ``mu + scale * noise``, scored by the teacher, and the masked
    per-sample loss is averaged into the KL-style training loss. In the
    ``"train"`` phase the student is updated (with optional gradient clipping
    and EMA), and intermediate states and checkpoints are saved on the
    ``hparams.checkpoint_interval`` schedule.

    Args:
        phase: ``"train"`` enables optimisation; any other value only evaluates.
        epoch: Current epoch, forwarded to ``save_checkpoint``.
        global_step: Training step counter (used when ``phase == "train"``).
        global_test_step: Step counter used for the other phases.
        teacher: Teacher model; always run in eval mode.
        student: Student model being trained.
        optimizer: Optimizer for the student (``optimizer.module`` is stepped
            when ``gpu_count > 1``).
        writer: Summary writer receiving the scalar logs.
        x: Input batch, (B, C, T).
        y: Target batch, (B, T, 1).
        c: Local conditioning, (B, C, T), or ``None``.
        g: Global conditioning, (B,), or ``None``.
        input_lengths: Per-item lengths used to build the loss mask.
        checkpoint_dir: Directory for intermediate states and checkpoints.
        eval_dir: Output directory for evaluation (currently unused because
            the evaluation branch is disabled).
        do_eval: Evaluation request flag (the branch is disabled).
        ema: Optional exponential-moving-average helper with ``shadow`` and
            ``update``.

    Returns:
        The loss of this step as a Python float.
    """
    sanity_check(teacher, c, g)
    sanity_check(student, c, g)

    # x : (B, C, T)
    # y : (B, T, 1)
    # c : (B, C, T)
    # g : (B,)
    train = (phase == "train")
    clip_thresh = hparams.clip_thresh

    # The teacher is never trained here, so keep it in eval mode in every phase.
    teacher.eval()
    if train:
        student.train()
        step = global_step
    else:
        student.eval()
        step = global_test_step

    # No learning-rate schedule is applied in this step: per the original
    # note, parallel wavenet uses a constant learning rate (0.0002).
    optimizer.zero_grad()

    # Prepare data
    x, y = Variable(x), Variable(y, requires_grad=False)
    c = Variable(c) if c is not None else None
    g = Variable(g) if g is not None else None
    input_lengths = Variable(input_lengths)
    if use_cuda:
        x, y = x.cuda(), y.cuda()
        input_lengths = input_lengths.cuda()
        c = c.cuda() if c is not None else None
        g = g.cuda() if g is not None else None

    # (B, T, 1); the first step is dropped to line up with the one-step
    # shifted teacher prediction below.
    mask = sequence_mask(input_lengths, max_len=x.size(-1)).unsqueeze(-1)
    mask = mask[:, 1:, :]

    # Apply the student model with stacked IAF layers and return mu, scale
    z = Variable(torch.from_numpy(np.random.logistic(0, 1, size=x.size())).float())
    if use_cuda:
        z = z.cuda()
    mu, scale = student(z, c=c, g=g, softmax=False)
    m, s = mu, scale
    # NumPy copies are only used for the progress print below.
    mu, scale = to_numpy(mu), to_numpy(scale)
    m = m.clamp(-0.999, 0.999)

    sample_T = 5
    # Start both accumulators from an exact zero. The KL accumulator used to
    # start from torch.FloatTensor(1), i.e. uninitialized memory.
    kl_loss_sum = 0
    power_loss_sum = 0

    # log(s) + 2 is the entropy of a logistic with scale s; it does not depend
    # on the drawn noise, so compute it once. Laid out as (B, T-1, 1) to match
    # the shifted teacher term.
    student_entropy = torch.log(s.permute(0, 2, 1)[:, 1:, :]) + 2

    for _ in range(sample_T):
        z = np.random.logistic(0, 1, x.shape)
        student_predict = m + s * to_variable(z)  # predicted wave
        student_predict = student_predict.clamp(-0.99, 0.99)
        # y_hat: (B x C x T), teacher mixture-of-logistics parameters
        y_hat = teacher(student_predict, c=c, g=g)

        student_predict = student_predict.permute(0, 2, 1)
        # -log(Pt) of the student sample under the teacher, one step shifted
        _, teacher_log_p = discretized_mix_logistic_loss(
            y_hat[:, :, :-1], student_predict[:, 1:, :], reduce=False)
        h_pt_ps = torch.sum(teacher_log_p * mask) / mask.sum()
        student_predict = student_predict.permute(0, 2, 1)

        power_loss_sum += get_power_loss_torch(student_predict, x)

        h_ps = torch.sum((teacher_log_p - student_entropy) * mask) / mask.sum()
        kl_loss_sum = kl_loss_sum + h_ps

    kl_loss = kl_loss_sum / (hparams.batch_size * sample_T)
    power_loss = power_loss_sum / (hparams.batch_size * sample_T)
    loss = kl_loss  # + power_loss

    rs = kl_loss.cpu().data.numpy()
    # np.isinf yields a boolean mask; the old ``rs == np.isinf(rs)`` compared
    # the loss value with that mask and therefore never fired on inf.
    if np.any(np.isinf(rs)):
        print('inf detected')
    else:
        print('power_loss={}, mean_scale={}, mean_mu={},kl_loss={},loss={}'.format(
            to_numpy(power_loss), np.mean(scale), np.mean(mu),
            to_numpy(kl_loss), to_numpy(loss)))

    if train and step > 0 and step % hparams.checkpoint_interval == 0:
        # checkpoint_dir is passed by keyword so it cannot land in an optional
        # positional slot (e.g. ``mu``) of save_states.
        save_states(step, writer, y_hat, y, student_predict, input_lengths,
                    checkpoint_dir=checkpoint_dir)
        if step % (5 * hparams.checkpoint_interval) == 0:
            save_checkpoint(student, optimizer, step, checkpoint_dir, epoch)

    # Evaluation is deliberately disabled by the ``and False``.
    if do_eval and False:
        # NOTE: use train step (i.e., global_step) for filename
        eval_model(global_step, writer, student, y, c, g, input_lengths, eval_dir, ema)

    # Update
    if train:
        loss.backward()
        if clip_thresh > 0:
            grad_norm = torch.nn.utils.clip_grad_norm(student.parameters(), clip_thresh)
        if gpu_count > 1:
            optimizer.module.step()
        else:
            optimizer.step()
        # update moving average
        if ema is not None:
            for name, param in student.named_parameters():
                if name in ema.shadow:
                    ema.update(name, param.data)

    # Logs
    loss_value = _scalar_value(loss)
    writer.add_scalar("{} loss".format(phase), loss_value, step)
    writer.add_scalar("{} _hps".format(phase), _scalar_value(h_ps), step)
    writer.add_scalar("{} h_pt_ps".format(phase), _scalar_value(h_pt_ps), step)
    writer.add_scalar("{} kl_loss".format(phase), _scalar_value(kl_loss), step)
    if train:
        if clip_thresh > 0:
            writer.add_scalar("gradient norm", grad_norm, step)

    return loss_value