def forward(self, y_hat, mu_q, scale_q, mask, sample_T=32):
    if hparams.output_type == 'Gaussian':
        # Closed-form KL between the teacher p and the student q (both Gaussian).
        mu_p, scale_p = y_hat[:, :1, :], torch.exp(y_hat[:, 1:, :])
        loss = torch.log(scale_p / scale_q) + \
            (scale_q ** 2 - scale_p ** 2 + (mu_q - mu_p) ** 2) / (2 * scale_p ** 2)
        # Symmetric variant, kept for reference:
        # loss += torch.log(scale_q / scale_p) + (scale_p ** 2 - scale_q ** 2 + (mu_q - mu_p) ** 2) / (2 * scale_q ** 2)
        # loss /= 2
        # Regularize the log-scale gap so the student's scale tracks the teacher's.
        loss += self.lambda_ * (torch.log(scale_p) - torch.log(scale_q)) ** 2
        kl_loss = torch.sum(loss[:, :, :-1] * mask.permute(0, 2, 1)) / mask.sum()
        return kl_loss
    elif hparams.output_type == "MOL":
        # Monte-Carlo estimate of the cross entropy H(p_s, p_t) over sample_T draws.
        h_pt_ps = 0
        for i in range(sample_T):
            # Sample z from a standard logistic via the inverse CDF.
            u = torch.zeros(mu_q.size()).uniform_(1e-5, 1 - 1e-5)
            if use_cuda:
                u = u.cuda()
            z = torch.log(u) - torch.log(1 - u)
            student_predict = mu_q + z * scale_q
            assert student_predict.requires_grad is True
            student_predict = student_predict.permute(0, 2, 1)
            teacher_log_p = discretized_mix_logistic_loss(
                y_hat[:, :, :-1], student_predict[:, 1:, :], reduce=False)
            h_pt_ps += torch.sum(teacher_log_p * mask) / mask.sum()
        # Student entropy: a logistic with scale s has H(p_s) = log(s) + 2 nats.
        a = scale_q.permute(0, 2, 1)
        h_ps = torch.sum((torch.log(a[:, 1:, :]) + 2) * mask) / mask.sum()
        # KL(p_s || p_t) = H(p_s, p_t) - H(p_s)
        cross_entropy = h_pt_ps / sample_T
        kl_loss = cross_entropy - h_ps
        return kl_loss
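# Sanity-check sketch (not part of the original module): the MOL branch above
# relies on the closed-form differential entropy of a logistic distribution
# with scale s, H = log(s) + 2 nats, which is where the `torch.log(a) + 2`
# term comes from. A minimal Monte-Carlo verification, assuming only numpy:
def _check_logistic_entropy(s=0.37, n=1000000):
    import numpy as np
    # Sample a standard logistic via the inverse CDF, as the training code does.
    u = np.random.uniform(1e-5, 1 - 1e-5, size=n)
    x = s * (np.log(u) - np.log(1 - u))
    # Log-density of Logistic(0, s): p(x) = exp(-x/s) / (s * (1 + exp(-x/s))^2)
    log_p = -x / s - np.log(s) - 2 * np.log1p(np.exp(-x / s))
    # The empirical entropy -E[log p] should approach log(s) + 2.
    return -log_p.mean(), np.log(s) + 2.0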
def __init__(self, hparams):
    super(DiscretizedMixturelogisticLoss, self).__init__()
    self.quantize_channels = hparams.quantize_channels
    self.log_scale_min = hparams.log_scale_min
    self.discretized_mix_logistic_loss = discretized_mix_logistic_loss(
        num_classes=hparams.quantize_channels,
        log_scale_min=hparams.log_scale_min,
        reduce=False)
    self.reduce_sum_op = P.ReduceSum()
    self.reduce_mean_op = P.ReduceMean()
def test_mixture():
    np.random.seed(1234)
    # Load a 16 kHz example waveform and treat it as targets of shape (B, T, 1).
    x, sr = librosa.load(pysptk.util.example_audio_file(), sr=None)
    assert sr == 16000
    T = len(x)
    x = x.reshape(1, T, 1)
    y = Variable(torch.from_numpy(x)).float()
    y_hat = Variable(torch.rand(1, 30, T)).float()
    print(y.shape, y_hat.shape)

    loss = discretized_mix_logistic_loss(y_hat, y)
    print(loss)

    loss = discretized_mix_logistic_loss(y_hat, y, reduce=False)
    print(loss.size(), y.size())
    assert loss.size() == y.size()

    y = sample_from_discretized_mix_logistic(y_hat)
    print(y.shape)
def forward(self, input, target, lengths=None, mask=None, max_len=None):
    if lengths is None and mask is None:
        raise RuntimeError("Should provide either lengths or mask")

    # (B, T, 1)
    if mask is None:
        mask = sequence_mask(lengths, max_len).unsqueeze(-1)

    # (B, T, 1) broadcast over the target shape
    mask_ = mask.expand_as(target)

    losses = discretized_mix_logistic_loss(
        input, target,
        num_classes=hparams.quantize_channels,
        log_scale_min=hparams.log_scale_min,
        reduce=False)
    assert losses.size() == target.size()
    return ((losses * mask_).sum()) / mask_.sum()
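# A minimal sketch of the `sequence_mask` helper assumed by the masked losses
# above (a hypothetical reimplementation using current PyTorch idioms; the
# repo's own version may differ in dtype/device handling):
def _sequence_mask_sketch(lengths, max_len=None):
    # lengths: (B,) tensor of valid timestep counts -> (B, T) float mask,
    # 1.0 for valid positions and 0.0 for padding.
    if max_len is None:
        max_len = int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)
    return (positions.unsqueeze(0) < lengths.unsqueeze(1)).float()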
def __train_step(phase, epoch, global_step, global_test_step,
                 teacher, student, optimizer, writer, x, y, c, g, input_lengths,
                 checkpoint_dir, eval_dir=None, do_eval=False, ema=None):
    sanity_check(teacher, c, g)
    sanity_check(student, c, g)

    # x : (B, C, T)
    # y : (B, T, 1)
    # c : (B, C, T)
    # g : (B,)
    train = (phase == "train")
    clip_thresh = hparams.clip_thresh
    if train:
        teacher.eval()  # the teacher is frozen; only the student is trained
        student.train()
        step = global_step
    else:
        student.eval()
        step = global_test_step

    # Learning rate schedule.
    # NOTE: the Parallel WaveNet paper uses a constant learning rate of 2e-4.
    current_lr = hparams.initial_learning_rate
    if train and hparams.lr_schedule is not None:
        lr_schedule_f = getattr(lrschedule, hparams.lr_schedule)
        current_lr = lr_schedule_f(
            hparams.initial_learning_rate, step, **hparams.lr_schedule_kwargs)
        if gpu_count > 1:
            for param_group in optimizer.module.param_groups:
                param_group['lr'] = current_lr
        else:
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr
    optimizer.zero_grad()

    # Prepare data
    x, y = Variable(x), Variable(y, requires_grad=False)
    c = Variable(c) if c is not None else None
    g = Variable(g) if g is not None else None
    input_lengths = Variable(input_lengths)
    if use_cuda:
        x, y = x.cuda(), y.cuda()
        input_lengths = input_lengths.cuda()
        c = c.cuda() if c is not None else None
        g = g.cuda() if g is not None else None

    # (B, T, 1)
    mask = sequence_mask(input_lengths, max_len=x.size(-1)).unsqueeze(-1)
    mask = mask[:, 1:, :]

    # Apply the student model (stacked IAF layers) to logistic noise z
    # and get back the predicted waveform together with mu and scale.
    u = Variable(torch.zeros(*x.size()).uniform_(1e-5, 1 - 1e-5),
                 requires_grad=False).cuda()
    z = torch.log(u) - torch.log(1 - u)
    predict, mu, scale = student(z, c=c, g=g, softmax=False)
    m, s = mu, scale

    # TODO: sample times, change to 300 or 400
    sample_T, kl_loss_sum = 16, 0
    power_loss_sum = 0

    # y_hat: (B, C, T); the teacher outputs a 10-component mixture of logistics.
    y_hat = teacher(predict, c=c, g=g)
    h_pt_ps = 0
    # TODO: add some constraint on scale; we want it to be small?
    for i in range(sample_T):
        # Draw z from a standard logistic via the inverse CDF:
        # https://en.wikipedia.org/wiki/Logistic_distribution
        u = Variable(torch.zeros(*x.size()).uniform_(1e-5, 1 - 1e-5),
                     requires_grad=False).cuda()
        z = torch.log(u) - torch.log(1 - u)
        student_predict = m + s * z  # predicted waveform sample
        # student_predict.clamp(-0.99, 0.99)
        student_predict = student_predict.permute(0, 2, 1)
        _, teacher_log_p = discretized_mix_logistic_loss(
            y_hat[:, :, :-1], student_predict[:, 1:, :], reduce=False)
        h_pt_ps += torch.sum(teacher_log_p * mask) / mask.sum()
        student_predict = student_predict.permute(0, 2, 1)
        # Multi-resolution power loss over five STFT configurations.
        power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=512, hop_length=128)
        power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=256, hop_length=64)
        power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=2048, hop_length=512)
        power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=1024, hop_length=256)
        power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=128, hop_length=32)

    # Student entropy term: a logistic with scale s has H(p_s) = log(s) + 2.
    a = s.permute(0, 2, 1)
    h_ps = torch.sum((torch.log(a[:, 1:, :]) + 2) * mask) / mask.sum()
    cross_entropy = h_pt_ps / sample_T
    kl_loss = cross_entropy - 2 * h_ps
    power_loss = power_loss_sum / (5 * sample_T)
    loss = kl_loss + power_loss

    if step > 0 and step % 20 == 0:
        print('power_loss={}, mean_scale={}, mean_mu={}, kl_loss={}, loss={}'.format(
            to_numpy(power_loss), np.mean(to_numpy(s)), np.mean(to_numpy(m)),
            to_numpy(kl_loss), to_numpy(loss)))

    if train and step > 0 and step % hparams.checkpoint_interval == 0:
        save_states(step, writer, y_hat=y_hat, y=y, y_student=predict,
                    input_lengths=input_lengths, mu=m, checkpoint_dir=checkpoint_dir)
        if step % (5 * hparams.checkpoint_interval) == 0:
            save_checkpoint(student, optimizer, step, checkpoint_dir, epoch)

    if do_eval and False:
        # NOTE: use train step (i.e., global_step) for filename
        eval_model(global_step, writer, student, y, c, g, input_lengths, eval_dir, ema)

    # Update
    if train:
        loss.backward()
        if clip_thresh > 0:
            grad_norm = torch.nn.utils.clip_grad_norm(student.parameters(), clip_thresh)
        if gpu_count > 1:
            optimizer.module.step()
        else:
            optimizer.step()
        # Update the exponential moving average of the student weights.
        if ema is not None:
            for name, param in student.named_parameters():
                if name in ema.shadow:
                    ema.update(name, param.data)

    # Logs
    writer.add_scalar("{} loss".format(phase), float(loss.data[0]), step)
    writer.add_scalar("{} h_ps".format(phase), float(h_ps.data[0]), step)
    writer.add_scalar("{} h_pt_ps".format(phase), float(cross_entropy.data[0]), step)
    writer.add_scalar("{} kl_loss".format(phase), float(kl_loss.data[0]), step)
    writer.add_scalar("{} power_loss".format(phase), float(power_loss.data[0]), step)
    if train and clip_thresh > 0:
        writer.add_scalar("gradient norm", grad_norm, step)

    return loss.data[0]
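# `get_power_loss_torch` is not shown in this file. A minimal sketch of an STFT
# magnitude (power) loss in its spirit, using the modern torch.stft API; the
# function name and the exact distance measure here are assumptions:
def _power_loss_sketch(y_student, y_ref, n_fft=512, hop_length=128):
    # y_student, y_ref: (B, 1, T) waveforms. Penalize the squared difference of
    # STFT magnitudes so the student matches the reference's spectral power.
    window = torch.hann_window(n_fft, device=y_student.device)
    s = torch.stft(y_student.squeeze(1), n_fft=n_fft, hop_length=hop_length,
                   window=window, return_complex=True).abs()
    r = torch.stft(y_ref.squeeze(1), n_fft=n_fft, hop_length=hop_length,
                   window=window, return_complex=True).abs()
    return torch.mean((s - r) ** 2)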
def __train_step(phase, epoch, global_step, global_test_step,
                 teacher, student, optimizer, writer, x, y, c, g, input_lengths,
                 checkpoint_dir, eval_dir=None, do_eval=False, ema=None):
    sanity_check(teacher, c, g)
    sanity_check(student, c, g)

    # x : (B, C, T)
    # y : (B, T, 1)
    # c : (B, C, T)
    # g : (B,)
    train = (phase == "train")
    clip_thresh = hparams.clip_thresh
    if train:
        teacher.eval()  # the teacher is frozen; only the student is trained
        student.train()
        step = global_step
    else:
        student.eval()
        step = global_test_step

    # NOTE: the Parallel WaveNet paper uses a constant learning rate of 2e-4,
    # so the learning-rate schedule is disabled in this variant.
    optimizer.zero_grad()
    cross_entropy = nn.CrossEntropyLoss()  # currently unused

    # Prepare data
    x, y = Variable(x), Variable(y, requires_grad=False)
    c = Variable(c) if c is not None else None
    g = Variable(g) if g is not None else None
    input_lengths = Variable(input_lengths)
    if use_cuda:
        x, y = x.cuda(), y.cuda()
        input_lengths = input_lengths.cuda()
        c = c.cuda() if c is not None else None
        g = g.cuda() if g is not None else None

    # (B, T, 1)
    mask = sequence_mask(input_lengths, max_len=x.size(-1)).unsqueeze(-1)
    mask = mask[:, 1:, :]

    # Apply the student model (stacked IAF layers) to logistic noise z
    # and get back mu and scale.
    z = Variable(
        torch.from_numpy(np.random.logistic(0, 1, size=x.size())).float()).cuda()
    mu, scale = student(z, c=c, g=g, softmax=False)
    m, s = mu, scale
    mu, scale = to_numpy(mu), to_numpy(scale)

    kl_loss, h_s = 0, 0
    _h_pt_ps = 0
    m = m.clamp(-0.999, 0.999)
    sample_T = 5
    kl_loss_sum = Variable(torch.FloatTensor(1).float(), requires_grad=True).cuda()
    power_loss_sum = 0
    for i in range(sample_T):
        z = np.random.logistic(0, 1, x.shape)
        student_predict = m + s * to_variable(z)  # predicted waveform sample
        student_predict = student_predict.clamp(-0.99, 0.99)
        # y_hat: (B, C, T); the teacher outputs a 10-component mixture of logistics.
        y_hat = teacher(student_predict, c=c, g=g)
        # Sample from the teacher distribution.
        teacher_predict = sample_from_discretized_mix_logistic(y_hat)
        student_predict = student_predict.permute(0, 2, 1)
        _, teacher_log_p = discretized_mix_logistic_loss(
            y_hat[:, :, :-1], student_predict[:, 1:, :], reduce=False)  # -log(p_t)
        h_pt_ps = torch.sum(teacher_log_p * mask) / mask.sum()
        student_predict = student_predict.permute(0, 2, 1)
        power_loss_sum += get_power_loss_torch(student_predict, x)
        a = s.permute(0, 2, 1)
        # Per-sample KL estimate: -log(p_t) minus the student entropy log(s) + 2.
        h_ps = torch.sum(
            (teacher_log_p - (torch.log(a[:, 1:, :]) + 2)) * mask) / mask.sum()
        kl_loss_sum += h_ps  # + h_pt_ps

    kl_loss = kl_loss_sum / (hparams.batch_size * sample_T)
    power_loss = power_loss_sum / (hparams.batch_size * sample_T)
    loss = kl_loss  # + power_loss

    rs = kl_loss.cpu().data.numpy()
    if np.isinf(rs):
        print('inf detected')
    else:
        print('power_loss={}, mean_scale={}, mean_mu={}, kl_loss={}, loss={}'.format(
            to_numpy(power_loss), np.mean(scale), np.mean(mu),
            to_numpy(kl_loss), to_numpy(loss)))

    if train and step > 0 and step % hparams.checkpoint_interval == 0:
        save_states(step, writer, y_hat, y, student_predict, input_lengths,
                    checkpoint_dir)
        if step % (5 * hparams.checkpoint_interval) == 0:
            save_checkpoint(student, optimizer, step, checkpoint_dir, epoch)

    if do_eval and False:
        # NOTE: use train step (i.e., global_step) for filename
        eval_model(global_step, writer, student, y, c, g, input_lengths, eval_dir, ema)

    # Update
    if train:
        loss.backward()
        if clip_thresh > 0:
            grad_norm = torch.nn.utils.clip_grad_norm(student.parameters(), clip_thresh)
        if gpu_count > 1:
            optimizer.module.step()
        else:
            optimizer.step()
        # Update the exponential moving average of the student weights.
        if ema is not None:
            for name, param in student.named_parameters():
                if name in ema.shadow:
                    ema.update(name, param.data)

    # Logs
    writer.add_scalar("{} loss".format(phase), float(loss.data[0]), step)
    writer.add_scalar("{} h_ps".format(phase), float(h_ps.data[0]), step)
    writer.add_scalar("{} h_pt_ps".format(phase), float(h_pt_ps.data[0]), step)
    writer.add_scalar("{} kl_loss".format(phase), float(kl_loss.data[0]), step)
    if train and clip_thresh > 0:
        writer.add_scalar("gradient norm", grad_norm, step)

    return loss.data[0]
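# The `ema` object used in both training steps exposes a `shadow` dict and an
# `update(name, value)` method. A minimal sketch of such an exponential moving
# average helper (hypothetical; the repo's own class may differ):
class _EMASketch(object):
    def __init__(self, decay=0.9999):
        self.decay = decay
        self.shadow = {}  # parameter name -> averaged copy of the tensor

    def register(self, name, value):
        self.shadow[name] = value.clone()

    def update(self, name, value):
        # shadow <- decay * shadow + (1 - decay) * value
        self.shadow[name] = self.decay * self.shadow[name] + (1.0 - self.decay) * value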