def greedy_decoding_and_AAS(self, inputs, targets, input_percentages,
                            target_sizes, mask, transcript_prob=0.001):
    inputs = _get_variable_volatile(inputs)
    N = inputs.size(0)

    # unflatten targets
    split_targets = []
    offset = 0
    for size in target_sizes:
        split_targets.append(targets[offset:offset + size])
        offset += size

    # step 1) Decoding to get wer & cer
    enhanced = self.G(inputs)
    prob = self.ASR(enhanced)
    prob = prob.transpose(0, 1)
    T = prob.size(0)
    sizes = input_percentages.mul_(int(T)).int()
    decoded_output, _ = self.decoder.decode(prob.data, sizes)
    target_strings = self.decoder.convert_to_strings(split_targets)

    we, ce, total_word, total_char = 0, 0, 0, 0
    for x in range(len(target_strings)):
        decoding, reference = decoded_output[x][0], target_strings[x][0]
        nChar = len(reference)
        nWord = len(reference.split())
        we_i = self.decoder.wer(decoding, reference)
        ce_i = self.decoder.cer(decoding, reference)
        we += we_i
        ce += ce_i
        total_word += nWord
        total_char += nChar
        if (random.uniform(0, 1) < transcript_prob):
            print('reference = ' + reference)
            print('decoding = ' + decoding)
            print('wer = ' + str(we_i / float(nWord)) + ', cer = ' + str(ce_i / float(nChar)))
    wer = we / total_word
    cer = ce / total_char  # character errors normalized by character count

    # step 2) get adversarial loss (for noisy data only)
    ae_ny = self.D(enhanced)
    l_adv_ny, nElement = self.diffLoss(ae_ny, enhanced, mask)  # normalized inside function
    l_adv_ny = l_adv_ny * self.config.w_adversarial

    # step 3) get CTC loss
    targets = _get_variable_volatile(targets, cuda=False)
    sizes = _get_variable_volatile(sizes, cuda=False)
    target_sizes = _get_variable_volatile(target_sizes, cuda=False)
    l_CTC = self.config.w_acoustic * self.CTCLoss(prob, targets, sizes, target_sizes) / N

    return l_CTC, l_adv_ny, nElement, wer, cer, total_word, total_char
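# A minimal, hypothetical sketch (not part of the class) of how the per-utterance
# decoder lengths used above are derived: each entry of input_percentages holds the
# fraction of the padded batch length that is real (unpadded) audio, so multiplying
# by the ASR output length T gives the number of valid output frames per utterance.
# Plain Python lists are used here for illustration; the training code does the same
# thing on tensors via input_percentages.mul_(int(T)).int().
def _example_output_sizes(input_percentages, T):
    """Return the number of valid ASR output frames for each utterance."""
    return [int(p * T) for p in input_percentages]

# e.g. with T = 100 output frames: _example_output_sizes([1.0, 0.5, 0.25], 100) -> [100, 50, 25]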
def greedy_decoding_and_CTCLoss(self, inputs, targets, input_percentages,
                                target_sizes, transcript_prob=0.001):
    inputs = _get_variable_volatile(inputs)
    N = inputs.size(0)

    # unflatten targets
    split_targets = []
    offset = 0
    for size in target_sizes:
        split_targets.append(targets[offset:offset + size])
        offset += size

    # step 1) Decoding to get wer & cer
    enhanced = self.G(inputs)
    prob = self.ASR(enhanced)
    prob = prob.transpose(0, 1)
    T = prob.size(0)
    sizes = input_percentages.mul_(int(T)).int()
    decoded_output, _ = self.decoder.decode(prob.data, sizes)
    target_strings = self.decoder.convert_to_strings(split_targets)

    we, ce, total_word, total_char = 0, 0, 0, 0
    for x in range(len(target_strings)):
        decoding, reference = decoded_output[x][0], target_strings[x][0]
        nChar = len(reference)
        nWord = len(reference.split())
        we_i = self.decoder.wer(decoding, reference)
        ce_i = self.decoder.cer(decoding, reference)
        we += we_i
        ce += ce_i
        total_word += nWord
        total_char += nChar
        if (random.uniform(0, 1) < transcript_prob):
            print('reference = ' + reference)
            print('decoding = ' + decoding)
            print('wer = ' + str(we_i / float(nWord)) + ', cer = ' + str(ce_i / float(nChar)))
    wer = we / total_word
    cer = ce / total_char  # character errors normalized by character count

    # step 2) get CTC loss
    targets = _get_variable_volatile(targets, cuda=False)
    sizes = _get_variable_volatile(sizes, cuda=False)
    target_sizes = _get_variable_volatile(target_sizes, cuda=False)
    loss = self.CTCLoss(prob, targets, sizes, target_sizes)
    loss = loss / N

    return loss, wer, cer, total_word, total_char
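# A minimal sketch (not part of the class) of the corpus-level WER/CER aggregation
# performed in the two decoding methods above: per-utterance edit distances are
# summed and divided by the total number of reference words (WER) or reference
# characters (CER), rather than averaging per-utterance rates. _edit_distance is a
# hypothetical stand-in for self.decoder.wer / self.decoder.cer.
def _edit_distance(hyp, ref):
    """Levenshtein distance between two token sequences (words or characters)."""
    d = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, 1):
        prev, d[0] = d[0], i
        for j, r in enumerate(ref, 1):
            prev, d[j] = d[j], min(d[j] + 1, d[j - 1] + 1, prev + (h != r))
    return d[len(ref)]

def _corpus_wer_cer(decodings, references):
    """Corpus-level WER and CER over paired hypothesis/reference strings."""
    we, ce, total_word, total_char = 0, 0, 0, 0
    for hyp, ref in zip(decodings, references):
        we += _edit_distance(hyp.split(), ref.split())
        ce += _edit_distance(list(hyp), list(ref))
        total_word += len(ref.split())
        total_char += len(ref)
    return we / float(total_word), ce / float(total_char)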
def train(self):
    # Setting
    optimizer_g = torch.optim.Adam(self.G.parameters(),
                                   lr=self.config.lr,
                                   betas=(self.beta1, self.beta2),
                                   amsgrad=True)

    for iter in trange(self.config.start_iter, self.config.max_iter):
        # Train
        data_list = self.data_loader.next(cl_ny='ny', type='train')
        inputs, cleans, mask = _get_variable_nograd(data_list[0]), \
            _get_variable_nograd(data_list[1]), _get_variable_nograd(data_list[2])

        # forward
        outputs = self.G(inputs)
        dce, nElement = self.diffLoss(outputs, cleans, mask)  # already normalized inside function

        # backward
        self.zero_grad_all()
        dce.backward()
        optimizer_g.step()

        # log
        #pdb.set_trace()
        if (iter + 1) % self.config.log_iter == 0:
            str_loss = "[{}/{}] (train) DCE: {:.7f}".format(iter, self.config.max_iter, dce.data[0])
            print(str_loss)
            self.logFile.write(str_loss + '\n')
            self.logFile.flush()

        if (iter + 1) % self.config.save_iter == 0:
            self.G.eval()

            # Measure performance on training subset
            self.dce_tr.reset()
            self.wer_tr.reset()
            self.cer_tr.reset()
            for _ in trange(0, len(self.data_loader.trsub_dl)):
                data_list = self.data_loader.next(cl_ny='ny', type='trsub')
                inputs, cleans, mask, targets, input_percentages, target_sizes = \
                    _get_variable_volatile(data_list[0]), _get_variable_volatile(data_list[1]), \
                    _get_variable_volatile(data_list[2]), data_list[3], data_list[4], data_list[5]
                outputs = self.G(inputs)
                dce, nElement = self.diffLoss(outputs, cleans, mask)  # already normalized inside function
                self.dce_tr.update(dce.data[0], nElement)

                # Greedy decoding
                wer, cer, nWord, nChar = self.greedy_decoding(inputs, targets, input_percentages, target_sizes)
                self.wer_tr.update(wer, nWord)
                self.cer_tr.update(cer, nChar)

            str_loss = "[{}/{}] (training subset) DCE: {:.7f}".format(
                iter, self.config.max_iter, self.dce_tr.avg)
            print(str_loss)
            self.logFile.write(str_loss + '\n')
            str_loss = "[{}/{}] (training subset) WER: {:.7f}, CER: {:.7f}".format(
                iter, self.config.max_iter, self.wer_tr.avg * 100, self.cer_tr.avg * 100)
            print(str_loss)
            self.logFile.write(str_loss + '\n')

            # Measure performance on validation data
            self.dce_val.reset()
            self.wer_val.reset()
            self.cer_val.reset()
            for _ in trange(0, len(self.data_loader.val_dl)):
                data_list = self.data_loader.next(cl_ny='ny', type='val')
                inputs, cleans, mask, targets, input_percentages, target_sizes = \
                    _get_variable_volatile(data_list[0]), _get_variable_volatile(data_list[1]), \
                    _get_variable_volatile(data_list[2]), data_list[3], data_list[4], data_list[5]
                outputs = self.G(inputs)
                dce, nElement = self.diffLoss(outputs, cleans, mask)  # already normalized inside function
                self.dce_val.update(dce.data[0], nElement)

                # Greedy decoding
                wer, cer, nWord, nChar = self.greedy_decoding(inputs, targets, input_percentages, target_sizes)
                self.wer_val.update(wer, nWord)
                self.cer_val.update(cer, nChar)

            str_loss = "[{}/{}] (validation) DCE: {:.7f}".format(
                iter, self.config.max_iter, self.dce_val.avg)
            print(str_loss)
            self.logFile.write(str_loss + '\n')
            str_loss = "[{}/{}] (validation) WER: {:.7f}, CER: {:.7f}".format(
                iter, self.config.max_iter, self.wer_val.avg * 100, self.cer_val.avg * 100)
            print(str_loss)
            self.logFile.write(str_loss + '\n')

            self.G.train()  # end of validation
            self.logFile.flush()

            # Save model
            if (len(self.savename_G) > 0):  # do not remove here
                if os.path.exists(self.savename_G):
                    os.remove(self.savename_G)  # remove previous model
            self.savename_G = '{}/G_{}.pth'.format(self.model_dir, iter)
            torch.save(self.G.state_dict(), self.savename_G)

            if (self.G.loss_stop > self.wer_val.avg):
                self.G.loss_stop = self.wer_val.avg
                savename_G_valmin_prev = '{}/G_valmin_{}.pth'.format(self.model_dir, self.valmin_iter)
                if os.path.exists(savename_G_valmin_prev):
                    os.remove(savename_G_valmin_prev)  # remove previous model
                print('save model for this checkpoint')
                savename_G_valmin = '{}/G_valmin_{}.pth'.format(self.model_dir, iter)
                copyfile(self.savename_G, savename_G_valmin)
                self.valmin_iter = iter
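# The meters used above (self.dce_tr, self.wer_val, ...) are assumed to expose a
# reset() / update(val, n) / avg interface for weighted running averages. A common
# implementation of that interface looks like the sketch below; this is an
# assumption for illustration, not necessarily the exact class used in this repo.
class AverageMeter(object):
    """Tracks a running weighted average: avg = sum(val_i * n_i) / sum(n_i)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count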
def train(self):
    # Setting
    optimizer_g = torch.optim.Adam(self.G.parameters(),
                                   lr=self.config.lr,
                                   betas=(self.beta1, self.beta2),
                                   amsgrad=True)
    optimizer_d = torch.optim.Adam(self.D.parameters(),
                                   lr=self.config.lr,
                                   betas=(self.beta1, self.beta2),
                                   amsgrad=True)

    for iter in trange(self.config.start_iter, self.config.max_iter):
        self.zero_grad_all()

        # Train
        # Noisy data
        data_list = self.data_loader.next(cl_ny='ny', type='train')
        #inputs, targets, input_percentages, target_sizes, mask = \
        #    _get_variable_nograd(data_list[0]), _get_variable_nograd(data_list[1], cuda=False), data_list[2], _get_variable_nograd(data_list[3], cuda=False), _get_variable_nograd(data_list[4])
        mixture, cleans, mask = \
            _get_variable_nograd(data_list[0]), _get_variable_nograd(data_list[1]), _get_variable_nograd(data_list[2])

        # forward generator
        enhanced = self.G(mixture)
        enhanced_D = enhanced.detach()

        # adversarial training: G-step
        ae_ny_G = self.D.forward_paired(enhanced, mixture)
        l_adv_ny_G, _ = self.diffLoss(ae_ny_G, enhanced, mask)  # normalized inside function
        l_adv_ny_G = l_adv_ny_G * self.config.w_adversarial
        l_adv_ny_G_data = l_adv_ny_G.data[0]
        l_adv_ny_G.backward(retain_graph=True)
        g_adv = self.get_gradient_norm(self.G)
        self.D.zero_grad()  # this makes no gradient for discriminator
        del l_adv_ny_G

        # adversarial training: D-step
        ae_ny_D = self.D.forward_paired(enhanced_D, mixture)
        l_adv_ny_D, _ = self.diffLoss(ae_ny_D, enhanced_D, mask)  # normalized inside function
        l_adv_ny_D = l_adv_ny_D * (-self.kt) * self.config.w_adversarial
        l_adv_ny_D.backward()
        del l_adv_ny_D

        # DCE loss
        dce, nElement = self.diffLoss(enhanced, cleans, mask)  # already normalized inside function
        dce_loss = dce.data[0]
        self.dce_tr_local.update(dce_loss, nElement)

        # Clean data
        ae_cl = self.D(cleans, mixture)
        l_adv_cl, _ = self.diffLoss(ae_cl, cleans, mask)  # normalized inside function
        l_adv_cl = self.config.w_adversarial * l_adv_cl
        l_adv_cl.backward()
        l_adv_cl_data = l_adv_cl.data[0]
        del l_adv_cl

        # update
        optimizer_g.step()
        optimizer_d.step()

        # Proportional Control Theory
        g_d_balance = self.gamma * l_adv_cl_data - l_adv_ny_G_data
        self.kt += self.lb * g_d_balance
        self.kt = max(min(1, self.kt), 0)
        conv_measure = l_adv_cl_data + abs(g_d_balance)

        # log
        #pdb.set_trace()
        if (iter + 1) % self.config.log_iter == 0:
            str_loss = "[{}/{}] (train) DCE: {:.7f}, ADV_cl: {:.7f}, ADV_ny: {:.7f}".format(
                iter, self.config.max_iter, self.dce_tr_local.avg, l_adv_cl_data, l_adv_ny_G_data)
            print(str_loss)
            self.logFile.write(str_loss + '\n')
            str_loss = "[{}/{}] (train) conv_measure: {:.4f}, kt: {:.4f} ".format(
                iter, self.config.max_iter, conv_measure, self.kt)
            print(str_loss)
            self.logFile.write(str_loss + '\n')
            self.logFile.flush()
            self.dce_tr_local.reset()

        if (iter + 1) % self.config.save_iter == 0:
            self.G.eval()

            # Measure performance on training subset
            self.dce_tr.reset()
            self.adv_ny_tr.reset()
            self.wer_tr.reset()
            self.cer_tr.reset()
            for _ in trange(0, len(self.data_loader.trsub_dl)):
                data_list = self.data_loader.next(cl_ny='ny', type='trsub')
                mixture, cleans, mask, targets, input_percentages, target_sizes = \
                    data_list[0], data_list[1], _get_variable_volatile(data_list[2]), \
                    data_list[3], data_list[4], data_list[5]
                dce, adv_ny, nElement, wer, cer, nWord, nChar = self.greedy_decoding_and_FSEGAN(
                    mixture, cleans, targets, input_percentages, target_sizes, mask)
                self.dce_tr.update(dce.data[0], nElement)
                self.adv_ny_tr.update(adv_ny.data[0], nElement)
                self.wer_tr.update(wer, nWord)
                self.cer_tr.update(cer, nChar)
                del dce, adv_ny

            str_loss = "[{}/{}] (training subset) DCE: {:.7f}, WER: {:.7f}, CER: {:.7f}".format(
                iter, self.config.max_iter, self.dce_tr.avg, self.wer_tr.avg * 100, self.cer_tr.avg * 100)
            print(str_loss)
            self.logFile.write(str_loss + '\n')

            # Measure performance on validation data
            self.dce_val.reset()
            self.adv_ny_val.reset()
            self.wer_val.reset()
            self.cer_val.reset()
            for _ in trange(0, len(self.data_loader.val_dl)):
                data_list = self.data_loader.next(cl_ny='ny', type='val')
                mixture, cleans, mask, targets, input_percentages, target_sizes = \
                    data_list[0], data_list[1], _get_variable_volatile(data_list[2]), \
                    data_list[3], data_list[4], data_list[5]
                dce, adv_ny, nElement, wer, cer, nWord, nChar = self.greedy_decoding_and_FSEGAN(
                    mixture, cleans, targets, input_percentages, target_sizes, mask)
                self.dce_val.update(dce.data[0], nElement)
                self.adv_ny_val.update(adv_ny.data[0], nElement)
                self.wer_val.update(wer, nWord)
                self.cer_val.update(cer, nChar)
                del dce, adv_ny

            str_loss = "[{}/{}] (validation) DCE: {:.7f}, WER: {:.7f}, CER: {:.7f}".format(
                iter, self.config.max_iter, self.dce_val.avg, self.wer_val.avg * 100, self.cer_val.avg * 100)
            print(str_loss)
            self.logFile.write(str_loss + '\n')
            self.logFile.flush()
            self.G.train()  # end of validation

            # Save model
            if (len(self.savename_G) > 0):  # do not remove here
                if os.path.exists(self.savename_G):
                    os.remove(self.savename_G)  # remove previous model
            self.savename_G = '{}/G_{}.pth'.format(self.model_dir, iter)
            torch.save(self.G.state_dict(), self.savename_G)

            if (self.G.loss_stop > self.wer_val.avg):
                self.G.loss_stop = self.wer_val.avg
                savename_G_valmin_prev = '{}/G_valmin_{}.pth'.format(self.model_dir, self.valmin_iter)
                if os.path.exists(savename_G_valmin_prev):
                    os.remove(savename_G_valmin_prev)  # remove previous model
                print('save model for this checkpoint')
                savename_G_valmin = '{}/G_valmin_{}.pth'.format(self.model_dir, iter)
                copyfile(self.savename_G, savename_G_valmin)
                self.valmin_iter = iter
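# A minimal sketch (not part of the class) of the "Proportional Control Theory"
# update used in the training loops above, i.e. the BEGAN-style balancing term k_t:
# the weight on the fake (enhanced) term of the discriminator loss is nudged so the
# clean/enhanced losses stay in the ratio set by gamma, and k_t is clipped to [0, 1].
# Variable names mirror the loop; lb is the proportional gain (self.lb).
def _update_kt(kt, lb, gamma, l_adv_cl, l_adv_ny):
    balance = gamma * l_adv_cl - l_adv_ny        # gamma * L_clean - L_enhanced
    kt = max(min(1.0, kt + lb * balance), 0.0)   # proportional step, clipped to [0, 1]
    conv_measure = l_adv_cl + abs(balance)       # convergence measure logged during training
    return kt, conv_measure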
                len(train_sampler), batch_time=batch_time, data_time=data_time, loss=losses))
        del loss
        del output

    start_iter = 0  # Reset start iteration for next epoch
    total_cer, total_wer = 0, 0
    model.eval()
    losses.reset()
    for i, (data) in tqdm(enumerate(test_loader), total=len(test_loader)):
        # load data
        input, target = data
        input = _get_variable_volatile(input, cuda=True)
        target = _get_variable_volatile(target, cuda=True)

        # Forward
        output = model(input)
        loss = criterion(output, target)
        loss = loss / input.size(0)  # average the loss by minibatch
        losses.update(loss.data[0], input.size(0))

        # TODO : measure accuracy
        valid_accuracy = 0  # TODO

    print('Validation Summary Epoch: [{0}]\t'
          'Average loss {loss:.3f}\t'
          'Average acc {acc:.3f}\t'.format(epoch + 1,
def train(self):
    # Setting
    optimizer_g = torch.optim.Adam(self.G.parameters(),
                                   lr=self.config.lr,
                                   betas=(self.beta1, self.beta2),
                                   amsgrad=True)
    optimizer_asr = torch.optim.Adam(self.ASR.parameters(),
                                     lr=self.config.lr,
                                     betas=(self.beta1, self.beta2),
                                     amsgrad=True)
    optimizer_d = torch.optim.Adam(self.D.parameters(),
                                   lr=self.config.lr,
                                   betas=(self.beta1, self.beta2),
                                   amsgrad=True)

    for iter in trange(self.config.start_iter, self.config.max_iter):
        self.zero_grad_all()

        # Train
        # Noisy data
        data_list = self.data_loader.next(cl_ny='ny', type='train')
        inputs, targets, input_percentages, target_sizes, mask = \
            _get_variable_nograd(data_list[0]), _get_variable_nograd(data_list[1], cuda=False), \
            data_list[2], _get_variable_nograd(data_list[3], cuda=False), _get_variable_nograd(data_list[4])
        N = inputs.size(0)

        # forward generator
        enhanced = self.G(inputs)
        enhanced_D = enhanced.detach()

        # adversarial training: G-step
        ae_ny_G = self.D(enhanced)
        l_adv_ny_G, _ = self.diffLoss(ae_ny_G, enhanced, mask)  # normalized inside function
        l_adv_ny_G = l_adv_ny_G * self.config.w_adversarial
        l_adv_ny_G_data = l_adv_ny_G.data[0]
        l_adv_ny_G.backward(retain_graph=True)
        g_adv = self.get_gradient_norm(self.G)
        self.D.zero_grad()  # this makes no gradient for discriminator
        del l_adv_ny_G

        # adversarial training: D-step
        ae_ny_D = self.D(enhanced_D)
        l_adv_ny_D, _ = self.diffLoss(ae_ny_D, enhanced_D, mask)  # normalized inside function
        l_adv_ny_D = l_adv_ny_D * (-self.kt) * self.config.w_adversarial
        #l_adv_ny_D_data = l_adv_ny_D.data[0]
        l_adv_ny_D.backward()
        del l_adv_ny_D

        # CTC loss
        prob = self.ASR(enhanced)
        prob = prob.transpose(0, 1)
        T = prob.size(0)
        sizes = _get_variable_nograd(input_percentages.mul_(int(T)).int(), cuda=False)
        l_CTC = self.config.w_acoustic * self.CTCLoss(prob, targets, sizes, target_sizes) / N
        self.ctc_tr_local.update(l_CTC.data[0], N)
        l_CTC.backward()
        g_ctc_adv = self.get_gradient_norm(self.G)
        del l_CTC

        # Clean data
        data_list_cl = self.data_loader.next(cl_ny='cl', type='train')
        inputs, mask = _get_variable_nograd(data_list_cl[0]), _get_variable_nograd(data_list_cl[4])
        ae_cl = self.D(inputs)
        l_adv_cl, _ = self.diffLoss(ae_cl, inputs, mask)  # normalized inside function
        l_adv_cl = self.config.w_adversarial * l_adv_cl
        l_adv_cl.backward()
        l_adv_cl_data = l_adv_cl.data[0]
        del l_adv_cl

        # update
        optimizer_g.step()
        optimizer_d.step()
        if (iter > self.config.allow_ASR_update_iter):
            optimizer_asr.step()

        # Proportional Control Theory
        g_d_balance = self.gamma * l_adv_cl_data - l_adv_ny_G_data
        self.kt += self.lb * g_d_balance
        self.kt = max(min(1, self.kt), 0)
        conv_measure = l_adv_cl_data + abs(g_d_balance)

        # log
        #pdb.set_trace()
        if (iter + 1) % self.config.log_iter == 0:
            str_loss = "[{}/{}] (train) CTC: {:.7f}, ADV_cl: {:.7f}, ADV_ny: {:.7f}".format(
                iter, self.config.max_iter, self.ctc_tr_local.avg, l_adv_cl_data, l_adv_ny_G_data)
            print(str_loss)
            self.logFile.write(str_loss + '\n')
            str_loss = "[{}/{}] (train) conv_measure: {:.4f}, kt: {:.4f} ".format(
                iter, self.config.max_iter, conv_measure, self.kt)
            print(str_loss)
            self.logFile.write(str_loss + '\n')
            str_loss = "[{}/{}] (train) gradient norm, adv: {:.4f}, adv + ctc : {:.4f}".format(
                iter, self.config.max_iter, g_adv.data[0], g_ctc_adv.data[0])
            print(str_loss)
            self.logFile.write(str_loss + '\n')
            self.logFile.flush()
            self.ctc_tr_local.reset()

        if (iter + 1) % self.config.save_iter == 0:
            self.G.eval()

            # Measure performance on training subset
            self.ctc_tr.reset()
            self.adv_ny_tr.reset()
            self.wer_tr.reset()
            self.cer_tr.reset()
            for _ in trange(0, len(self.data_loader.trsub_dl)):
                data_list = self.data_loader.next(cl_ny='ny', type='trsub')
                inputs, targets, input_percentages, target_sizes, mask = \
                    data_list[0], data_list[1], data_list[2], data_list[3], _get_variable_volatile(data_list[4])
                ctc, adv_ny, nElement, wer, cer, nWord, nChar = self.greedy_decoding_and_AAS(
                    inputs, targets, input_percentages, target_sizes, mask)
                N = inputs.size(0)
                self.ctc_tr.update(ctc.data[0], N)
                self.adv_ny_tr.update(adv_ny.data[0], nElement)
                self.wer_tr.update(wer, nWord)
                self.cer_tr.update(cer, nChar)
                del ctc, adv_ny

            str_loss = "[{}/{}] (training subset) CTC: {:.7f}, WER: {:.7f}, CER: {:.7f}".format(
                iter, self.config.max_iter, self.ctc_tr.avg, self.wer_tr.avg * 100, self.cer_tr.avg * 100)
            print(str_loss)
            self.logFile.write(str_loss + '\n')

            # Measure performance on validation data
            self.ctc_val.reset()
            self.adv_ny_val.reset()
            self.wer_val.reset()
            self.cer_val.reset()
            for _ in trange(0, len(self.data_loader.val_dl)):
                data_list = self.data_loader.next(cl_ny='ny', type='val')
                inputs, targets, input_percentages, target_sizes, mask = \
                    data_list[0], data_list[1], data_list[2], data_list[3], _get_variable_volatile(data_list[4])
                ctc, adv_ny, nElement, wer, cer, nWord, nChar = self.greedy_decoding_and_AAS(
                    inputs, targets, input_percentages, target_sizes, mask)
                N = inputs.size(0)
                self.ctc_val.update(ctc.data[0], N)
                self.adv_ny_val.update(adv_ny.data[0], nElement)
                self.wer_val.update(wer, nWord)
                self.cer_val.update(cer, nChar)
                del ctc, adv_ny

            str_loss = "[{}/{}] (validation) CTC: {:.7f}, WER: {:.7f}, CER: {:.7f}".format(
                iter, self.config.max_iter, self.ctc_val.avg, self.wer_val.avg * 100, self.cer_val.avg * 100)
            print(str_loss)
            self.logFile.write(str_loss + '\n')
            self.logFile.flush()
            self.G.train()  # end of validation

            # Save model
            if (len(self.savename_G) > 0):  # do not remove here
                if os.path.exists(self.savename_G):
                    os.remove(self.savename_G)  # remove previous model
            self.savename_G = '{}/G_{}.pth'.format(self.model_dir, iter)
            torch.save(self.G.state_dict(), self.savename_G)

            if (len(self.savename_ASR) > 0):
                if os.path.exists(self.savename_ASR):
                    os.remove(self.savename_ASR)
            self.savename_ASR = '{}/ASR_{}.pth'.format(self.model_dir, iter)
            torch.save(self.ASR.state_dict(), self.savename_ASR)

            if (self.G.loss_stop > self.wer_val.avg):
                self.G.loss_stop = self.wer_val.avg

                savename_G_valmin_prev = '{}/G_valmin_{}.pth'.format(self.model_dir, self.valmin_iter)
                if os.path.exists(savename_G_valmin_prev):
                    os.remove(savename_G_valmin_prev)  # remove previous model
                print('save model for this checkpoint')
                savename_G_valmin = '{}/G_valmin_{}.pth'.format(self.model_dir, iter)
                copyfile(self.savename_G, savename_G_valmin)

                savename_ASR_valmin_prev = '{}/ASR_valmin_{}.pth'.format(self.model_dir, self.valmin_iter)
                if os.path.exists(savename_ASR_valmin_prev):
                    os.remove(savename_ASR_valmin_prev)  # remove previous model
                print('save model for this checkpoint')
                savename_ASR_valmin = '{}/ASR_valmin_{}.pth'.format(self.model_dir, iter)
                copyfile(self.savename_ASR, savename_ASR_valmin)

                self.valmin_iter = iter
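# A minimal sketch (not part of the class) of the checkpoint policy used in the
# save blocks above: keep exactly one rolling "latest" checkpoint and one
# "best on validation WER" checkpoint per model, deleting the previous file of
# each kind before writing or copying the new one. The helper name and arguments
# are illustrative, not part of the repository's API.
def _rotate_checkpoint(prev_path, new_path, state_dict):
    """Delete the previous checkpoint of this kind, then save the new one."""
    if prev_path and os.path.exists(prev_path):
        os.remove(prev_path)
    torch.save(state_dict, new_path)
    return new_path

# usage sketch (mirrors the "Save model" block):
#   self.savename_G = _rotate_checkpoint(self.savename_G,
#                                        '{}/G_{}.pth'.format(self.model_dir, iter),
#                                        self.G.state_dict())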