def eval(epoch):
    """Run the separation model over the test split and track SDR / SDRi.

    Batches are pulled from ``prepare_data('once', 'test', ...)`` until the
    generator yields ``False``.  For every batch the model decodes the speaker
    labels with ``beam_sample`` and predicts one mask per decoded speaker.

    For batches ``0 .. 500 // config.batch_size`` the masked spectrograms are
    written to ``batch_output/`` and scored with ``bss_test.cal``; on the next
    batch the module-level ``best_SDR`` is raised to the running mean SDR if
    that mean is higher.

    Args:
        epoch: current epoch number (not used by this function).

    Note:
        The name shadows the builtin ``eval``; it is kept for existing callers.
        NOTE(review): ``bss_test.cal`` is assumed to return an ``(sdr, sdri)``
        pair, as it does where ``train`` unpacks it — confirm for this module.
    """
    global best_SDR, Var
    model.eval()
    print('\n\n测试的时候请设置config里的batch_size为1!!!please set the batch_size as 1')
    test_or_valid = 'test'
    print('Test or valid:', test_or_valid)
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX, config.MAX_MIX)
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    batch_idx = 0
    while True:
        print('-' * 30)
        eval_data = next(eval_data_gen)
        if eval_data == False:
            # The generator signals an exhausted split by yielding False.
            print('SDR_aver_eval_epoch:', SDR_SUM.mean())
            print('SDRi_aver_eval_epoch:', SDRi_SUM.mean())
            break

        src = Variable(torch.from_numpy(eval_data['mix_feas']))
        raw_tgt = [sorted(spk.keys()) for spk in eval_data['multi_spk_fea_list']]
        # Target spectrograms, ordered to match raw_tgt.
        feas_tgt = models.rank_feas(raw_tgt, eval_data['multi_spk_fea_list'])
        top_k = len(raw_tgt[0])
        # Placeholder target sequence of shape (top_k + 2, batch): beam search
        # decodes the labels itself, so its contents do not matter at test time.
        tgt = Variable(torch.ones(top_k + 2, config.batch_size))
        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() + mix_speech_len).unsqueeze(0)
        tgt_len = Variable(torch.LongTensor(
            [len(one_spk) for one_spk in eval_data['multi_spk_fea_list']])).unsqueeze(0)

        if config.WFM:
            # Per-speaker mask: squared target divided by the squared sum of
            # the targets over dim 1.
            tmp_size = feas_tgt.size()
            assert len(tmp_size) == 4
            feas_tgt_sum = torch.sum(feas_tgt, dim=1, keepdim=True)
            feas_tgt_sum_square = (feas_tgt_sum * feas_tgt_sum).expand(tmp_size)
            feas_tgt_square = feas_tgt * feas_tgt
            WFM_mask = feas_tgt_square / feas_tgt_sum_square

        if use_cuda:
            src = src.cuda()
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
            if config.WFM:
                WFM_mask = WFM_mask.cuda()
        # Transpose on every device: the inverse transpose after beam_sample
        # is unconditional, so doing this only on the CUDA path would hand
        # the CPU model the wrong layout.
        src = src.transpose(0, 1)

        if 1 and len(opt.gpus) > 1:
            samples, alignment, hiddens, predicted_masks = model.module.beam_sample(
                src, src_len, dict_spk2idx, tgt, beam_size=config.beam_size)
        else:
            samples, alignment, hiddens, predicted_masks = model.beam_sample(
                src, src_len, dict_spk2idx, tgt, beam_size=config.beam_size)

        # Expand the raw mixed features to topk_max channels (one copy of the
        # mixture per decoded speaker, along a new dim 1).
        src = src.transpose(0, 1)
        siz = src.size()
        assert len(siz) == 3
        # Decoded label count minus one trailing token (presumably <EOS>).
        topk_max = len(samples[0]) - 1
        x_input_map_multi = torch.unsqueeze(src, 1).expand(siz[0], topk_max, siz[1], siz[2])
        if config.WFM:
            feas_tgt = x_input_map_multi.data * WFM_mask

        if 0 and test_or_valid == 'valid':
            # Disabled: separation loss on the validation split.
            if 1 and len(opt.gpus) > 1:
                ss_loss = model.module.separation_loss(
                    x_input_map_multi, predicted_masks, feas_tgt)
            else:
                ss_loss = model.separation_loss(x_input_map_multi, predicted_masks, feas_tgt)
            print('loss for ss,this batch:', ss_loss.cpu().item())
            lera.log({
                'ss_loss_' + test_or_valid: ss_loss.cpu().item(),
            })
            del ss_loss, hiddens

        # `//` keeps the Python 2 integer semantics on Python 3 too; with `/`
        # the equality test in the elif below could never match a float.
        if batch_idx <= 500 // config.batch_size:
            # Only the first batches contribute to the SDR statistics.
            predicted_maps = predicted_masks * x_input_map_multi
            utils.bss_eval2(config, predicted_maps, eval_data['multi_spk_fea_list'],
                            raw_tgt, eval_data, dst='batch_output')
            del predicted_maps, predicted_masks, x_input_map_multi
            try:
                sdr_aver_batch, sdri_aver_batch = bss_test.cal('batch_output/')
            except AssertionError as wrong_info:
                print('Errors in calculating the SDR', wrong_info)
            else:
                # Append each metric to its own history; unpacking a single
                # concatenated array into two names fails from the 2nd batch on.
                SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
                SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
            print('SDR_aver_now:', SDR_SUM.mean())
            print('SDRi_aver_now:', SDRi_SUM.mean())
            lera.log({'SDR sample' + test_or_valid: SDR_SUM.mean()})
            lera.log({'SDRi sample' + test_or_valid: SDRi_SUM.mean()})
        elif batch_idx == 500 // config.batch_size + 1 and SDR_SUM.mean() > best_SDR:
            # Record the best SDR exactly once, right after the scored batches.
            print('Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean()))
            best_SDR = SDR_SUM.mean()

        batch_idx += 1
def train(epoch):
    """Train the model for one pass over the training split.

    For every batch this builds the padded speaker-index target sequence,
    runs the model to obtain speaker-label outputs and separation masks, and
    optimises ``sgm_loss + 5 * ss_loss``.  When ``updates > 10`` and
    ``updates % config.eval_interval`` is in 0..10, the predicted spectrograms
    are written to ``batch_output/`` and scored with ``bss_test.cal``.
    Progress is logged every 30 updates, and a checkpoint is saved whenever
    ``updates % config.save_interval == 1``.

    Args:
        epoch: current epoch number (used for logging and ``model.current_epoch``).

    Relies on module-level objects (``model``, ``optim``, ``scheduler``,
    ``config``, ``opt``, ``writer``, ``lera``, ...) and updates the globals
    listed in the ``global`` statement.
    """
    # `global` must precede the first assignment of these names (a
    # SyntaxError on Python 3, a SyntaxWarning on Python 2 otherwise).
    global e, updates, total_loss, start_time, report_total, report_correct, total_loss_sgm, total_loss_ss
    e = epoch
    model.train()
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])

    if config.schedule and scheduler.get_lr()[0] > 5e-5:
        scheduler.step()
        print("Decaying learning rate to %g" % scheduler.get_lr()[0])
        lera.log({
            'lr': scheduler.get_lr()[0],
        })

    if opt.model == 'gated':
        model.current_epoch = epoch

    train_data_gen = prepare_data('once', 'train')
    while True:
        print('\n')
        train_data = next(train_data_gen)
        if train_data == False:
            # The generator signals an exhausted split by yielding False.
            print('SDR_aver_epoch:', SDR_SUM.mean())
            print('SDRi_aver_epoch:', SDRi_SUM.mean())
            break

        src = Variable(torch.from_numpy(train_data['mix_feas']))
        raw_tgt = [sorted(spk.keys()) for spk in train_data['multi_spk_fea_list']]
        # Target spectrograms, ordered to match raw_tgt.
        feas_tgt = models.rank_feas(raw_tgt, train_data['multi_spk_fea_list'])

        # Target index sequence per sample: 0 (start token), the speaker ids,
        # then <EOS> padding up to tgt_max_len; transposed to (tgt_max_len, batch).
        tgt_max_len = config.MAX_MIX + 2  # with bos and eos.
        tgt = Variable(torch.from_numpy(np.array(
            [[0] + [dict_spk2idx[spk] for spk in spks]
             + (tgt_max_len - len(spks) - 1) * [dict_spk2idx['<EOS>']]
             for spks in raw_tgt],
            dtype=np.int_))).transpose(0, 1)
        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() + mix_speech_len).unsqueeze(0)
        tgt_len = Variable(torch.LongTensor(
            [len(one_spk) for one_spk in train_data['multi_spk_fea_list']])).unsqueeze(0)
        if use_cuda:
            src = src.cuda()
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
        # Transpose on every device: the inverse transpose after the forward
        # pass is unconditional.
        src = src.transpose(0, 1)

        model.zero_grad()
        # Flat indices of the (sample, slot) pairs holding a real speaker,
        # i.e. whose label (start token and last slot excluded) is not <EOS>.
        aim_list = (tgt[1:-1].transpose(0, 1).contiguous().view(-1)
                    != dict_spk2idx['<EOS>']).nonzero().squeeze()
        aim_list = aim_list.data.cpu().numpy()

        # outputs are the hidden outputs before the final classification layer.
        outputs, targets, multi_mask, gamma = model(src, src_len, tgt, tgt_len, dict_spk2idx)
        writer.add_histogram('global gamma', gamma, updates)

        if 1 and len(opt.gpus) > 1:
            sgm_loss, num_total, num_correct = model.module.compute_loss(
                outputs, targets, opt.memory)
        else:
            sgm_loss, num_total, num_correct = model.compute_loss(
                outputs, targets, opt.memory)
        print('loss for SGM,this batch:', sgm_loss.cpu().item())
        writer.add_scalars('scalar/loss', {'sgm_loss': sgm_loss.cpu().item()}, updates)

        # Expand the raw mixed features to topk_max channels, flatten
        # (batch, topk_max) and keep only the rows of real speakers.
        src = src.transpose(0, 1)
        siz = src.size()
        assert len(siz) == 3
        topk_max = config.MAX_MIX  # maximum possible number of speakers
        x_input_map_multi = torch.unsqueeze(src, 1).expand(
            siz[0], topk_max, siz[1], siz[2]).contiguous().view(-1, siz[1], siz[2])
        x_input_map_multi = x_input_map_multi[aim_list]
        multi_mask = multi_mask.transpose(0, 1)

        if 1 and len(opt.gpus) > 1:
            ss_loss = model.module.separation_loss(x_input_map_multi, multi_mask, feas_tgt)
        else:
            ss_loss = model.separation_loss(x_input_map_multi, multi_mask, feas_tgt)
        print('loss for SS,this batch:', ss_loss.cpu().item())
        writer.add_scalars('scalar/loss', {'ss_loss': ss_loss.cpu().item()}, updates)

        loss = sgm_loss + 5 * ss_loss
        loss.backward()
        total_loss_sgm += sgm_loss.cpu().item()
        total_loss_ss += ss_loss.cpu().item()
        lera.log({
            'sgm_loss': sgm_loss.cpu().item(),
            'ss_loss': ss_loss.cpu().item(),
            'loss:': loss.cpu().item(),
        })

        if updates > 10 and updates % config.eval_interval in range(11):
            predicted_maps = multi_mask * x_input_map_multi
            utils.bss_eval(config, predicted_maps, train_data['multi_spk_fea_list'],
                           raw_tgt, train_data, dst='batch_output')
            del predicted_maps, multi_mask, x_input_map_multi
            sdr_aver_batch, sdri_aver_batch = bss_test.cal('batch_output/')
            lera.log({'SDR sample': sdr_aver_batch})
            lera.log({'SDRi sample': sdri_aver_batch})
            writer.add_scalars('scalar/loss', {
                'SDR_sample': sdr_aver_batch,
                'SDRi_sample': sdri_aver_batch
            }, updates)
            SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
            SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
            print('SDR_aver_now:', SDR_SUM.mean())
            print('SDRi_aver_now:', SDRi_SUM.mean())

        total_loss += loss.cpu().item()
        report_correct += num_correct.cpu().item()
        report_total += num_total.cpu().item()
        optim.step()

        updates += 1
        if updates % 30 == 0:
            # float() avoids Python 2 integer division, which floors the
            # accuracy to 0 or 1 when both counters are ints.
            label_acc = float(report_correct) / report_total if report_total else 0.0
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss this batch: %6.3f,sgm loss: %6.6f,ss loss: %6.6f,label acc: %6.6f\n"
                % (time.time() - start_time, epoch, updates, loss / num_total,
                   total_loss_sgm / 30.0, total_loss_ss / 30.0, label_acc))
            lera.log({'label_acc': label_acc})
            writer.add_scalars('scalar/loss', {'label_acc': label_acc}, updates)
            total_loss_sgm, total_loss_ss = 0, 0

        if 0 and updates % config.eval_interval == 0 and epoch > 3:
            # Disabled.  Run at least a few epochs before evaluating, otherwise
            # the model has not learned anything yet and many problems arise.
            # NOTE(review): eval() as defined here returns None, so `score`
            # would not be indexable if this branch were re-enabled.
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss: %6.5f\n"
                % (time.time() - start_time, epoch, updates, total_loss / report_total))
            print('evaluating after %d updates...\r' % updates)
            original_bs = config.batch_size
            score = eval(epoch)  # batch_size becomes 1 during eval
            print('Orignal bs:', original_bs)
            config.batch_size = original_bs
            print('Now bs:', config.batch_size)
            for metric in config.metric:
                scores[metric].append(score[metric])
                lera.log({
                    'sgm_micro_f1': score[metric],
                })
                if metric == 'micro_f1' and score[metric] >= max(scores[metric]):
                    save_model(log_path + 'best_' + metric + '_checkpoint.pt')
                if metric == 'hamming_loss' and score[metric] <= min(scores[metric]):
                    save_model(log_path + 'best_' + metric + '_checkpoint.pt')
            model.train()
            total_loss = 0
            # Restart the timer; 0 would make later logs show seconds since 1970.
            start_time = time.time()
            report_total = 0
            report_correct = 0

        if updates % config.save_interval == 1:
            save_model(log_path + 'TDAAv3_{}.pt'.format(updates))
def train(epoch):
    """Train the model for one pass over the training split (variant with
    optional discriminator loss and MLMSE variance weighting).

    For every batch this builds the padded speaker-index target sequence,
    runs the model, and optimises ``sgm_loss + 5 * ss_loss`` (plus
    ``dis_loss`` when ``config.is_dis``).  When ``updates %
    config.eval_interval`` is in 0..10 the predicted spectrograms are written
    to ``batch_outputjaa/`` and scored with ``bss_test.cal``.  Progress is
    logged every 30 updates.

    Any ``RuntimeError`` or other ``Exception`` raised while processing a
    batch is printed and that batch is skipped (best-effort training).

    Args:
        epoch: current epoch number (used for logging and ``model.current_epoch``).

    Relies on module-level objects (``model``, ``optim``, ``scheduler``,
    ``config``, ``opt``, ``lera``, ...) and updates the globals listed in the
    ``global`` statement.
    """
    # `global` must precede the first assignment of these names.  It is a
    # compile-time declaration, so Var is global whether or not
    # config.MLMSE is set.
    global e, updates, total_loss, start_time, report_total, total_loss_sgm, total_loss_ss, Var
    e = epoch
    model.train()
    SDR_SUM = np.array([])

    if config.schedule:
        scheduler.step()
        print("Decaying learning rate to %g" % scheduler.get_lr()[0])
        if config.is_dis:
            scheduler_dis.step()
        lera.log({
            'lr': scheduler.get_lr()[0],
        })

    if opt.model == 'gated':
        model.current_epoch = epoch

    train_data_gen = prepare_data('once', 'train')
    while True:
        try:
            train_data = next(train_data_gen)
            if train_data == False:
                # The generator signals an exhausted split by yielding False.
                print('SDR_aver_epoch:', SDR_SUM.mean())
                break

            src = Variable(torch.from_numpy(train_data['mix_feas']))
            raw_tgt = [sorted(spk.keys()) for spk in train_data['multi_spk_fea_list']]
            # Target spectrograms, ordered to match raw_tgt.
            feas_tgt = models.rank_feas(raw_tgt, train_data['multi_spk_fea_list'])

            # Target index sequence per sample: 0 (start token), the speaker
            # ids, then <EOS> padding up to tgt_max_len; transposed to
            # (tgt_max_len, batch).
            tgt_max_len = config.MAX_MIX + 2  # with bos and eos.
            tgt = Variable(torch.from_numpy(np.array(
                [[0] + [dict_spk2idx[spk] for spk in spks]
                 + (tgt_max_len - len(spks) - 1) * [dict_spk2idx['<EOS>']]
                 for spks in raw_tgt],
                dtype=np.int_))).transpose(0, 1)
            src_len = Variable(
                torch.LongTensor(config.batch_size).zero_() + mix_speech_len).unsqueeze(0)
            tgt_len = Variable(torch.LongTensor(
                [len(one_spk) for one_spk in train_data['multi_spk_fea_list']])).unsqueeze(0)
            if use_cuda:
                src = src.cuda()
                tgt = tgt.cuda()
                src_len = src_len.cuda()
                tgt_len = tgt_len.cuda()
                feas_tgt = feas_tgt.cuda()
            # Transpose on every device: the inverse transpose after the
            # forward pass is unconditional.
            src = src.transpose(0, 1)

            model.zero_grad()
            # Flat indices of the (sample, slot) pairs holding a real speaker,
            # i.e. whose label (start token and last slot excluded) is not <EOS>.
            aim_list = (tgt[1:-1].transpose(0, 1).contiguous().view(-1)
                        != dict_spk2idx['<EOS>']).nonzero().squeeze()
            aim_list = aim_list.data.cpu().numpy()

            # outputs are the hidden outputs before the final classification layer.
            outputs, targets, multi_mask = model(src, src_len, tgt, tgt_len, dict_spk2idx)
            print('mask size:', multi_mask.size())

            if 1 and len(opt.gpus) > 1:
                sgm_loss, num_total, num_correct = model.module.compute_loss(
                    outputs, targets, opt.memory)
            else:
                sgm_loss, num_total, num_correct = model.compute_loss(
                    outputs, targets, opt.memory)
            print('loss for SGM,this batch:', sgm_loss.data[0] / num_total)

            # Expand the raw mixed features to topk_max channels, flatten
            # (batch, topk_max) and keep only the rows of real speakers.
            src = src.transpose(0, 1)
            siz = src.size()
            assert len(siz) == 3
            topk_max = config.MAX_MIX  # maximum possible number of speakers
            x_input_map_multi = torch.unsqueeze(src, 1).expand(
                siz[0], topk_max, siz[1], siz[2]).contiguous().view(-1, siz[1], siz[2])
            x_input_map_multi = x_input_map_multi[aim_list]
            multi_mask = multi_mask.transpose(0, 1)

            if 1 and len(opt.gpus) > 1:
                if config.MLMSE:
                    Var = model.module.update_var(x_input_map_multi, multi_mask, feas_tgt)
                    lera.log_image(u'Var weight',
                                   Var.data.cpu().numpy().reshape(
                                       config.speech_fre, config.speech_fre, 1).repeat(3, 2),
                                   clip=(-1, 1))
                    ss_loss = model.module.separation_loss(
                        x_input_map_multi, multi_mask, feas_tgt, Var)
                else:
                    ss_loss = model.module.separation_loss(
                        x_input_map_multi, multi_mask, feas_tgt)
            else:
                ss_loss = model.separation_loss(x_input_map_multi, multi_mask, feas_tgt)

            loss = sgm_loss + 5 * ss_loss
            if config.is_dis:
                # Additional loss term computed with the discriminator model.
                dis_loss = models.loss.dis_loss(config, topk_max, model_dis, x_input_map_multi,
                                                multi_mask, feas_tgt, func_dis)
                loss = loss + dis_loss
            loss.backward()
            total_loss_sgm += sgm_loss.data[0]
            total_loss_ss += ss_loss.data[0]
            lera.log({
                'sgm_loss': sgm_loss.data[0],
                'ss_loss': ss_loss.data[0],
                'loss:': loss.data[0],
            })

            if updates % config.eval_interval in range(11):
                predicted_maps = multi_mask * x_input_map_multi
                utils.bss_eval(config, predicted_maps, train_data['multi_spk_fea_list'],
                               raw_tgt, train_data, dst='batch_outputjaa')
                del predicted_maps, multi_mask, x_input_map_multi
                sdr_aver_batch = bss_test.cal('batch_outputjaa/')
                lera.log({'SDR sample': sdr_aver_batch})
                SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
                print('SDR_aver_now:', SDR_SUM.mean())

            total_loss += loss.data[0]
            report_total += num_total
            optim.step()
            if config.is_dis:
                optim_dis.step()

            updates += 1
            if updates % 30 == 0:
                logging(
                    "time: %6.3f, epoch: %3d, updates: %8d, train loss this batch: %6.3f,sgm loss: %6.6f,ss loss: %6.6f\n"
                    % (time.time() - start_time, epoch, updates, loss / num_total,
                       total_loss_sgm / 30.0, total_loss_ss / 30.0))
                total_loss_sgm, total_loss_ss = 0, 0

            if 0 or updates % config.eval_interval == 0 and epoch > 1:
                logging(
                    "time: %6.3f, epoch: %3d, updates: %8d, train loss: %6.5f\n"
                    % (time.time() - start_time, epoch, updates, total_loss / report_total))
                print('evaluating after %d updates...\r' % updates)
                # score = eval(epoch)
                # NOTE(review): with eval() commented out, `score` is only bound
                # if some module-level `score` exists; otherwise the loop below
                # raises NameError, which the generic handler swallows (and the
                # resets after the loop are skipped).
                for metric in config.metric:
                    scores[metric].append(score[metric])
                    lera.log({
                        'sgm_micro_f1': score[metric],
                    })
                    if metric == 'micro_f1' and score[metric] >= max(scores[metric]):
                        save_model(log_path + 'best_' + metric + '_checkpoint.pt')
                    if metric == 'hamming_loss' and score[metric] <= min(scores[metric]):
                        save_model(log_path + 'best_' + metric + '_checkpoint.pt')
                model.train()
                total_loss = 0
                # Restart the timer; 0 would make later logs show seconds since 1970.
                start_time = time.time()
                report_total = 0

        except RuntimeError as eeee:
            # Best effort: report and skip the failing batch.
            print('Erros here eeee: ', eeee)
            continue
        except Exception as dddd:
            print('\n\n\nRare errors: ', dddd)
            continue
# checkpoint
print('loading checkpoint...\n', opt.restore)
# Tensors that were saved on cuda:2 are remapped to cuda:0 when loading.
checkpoints = torch.load(opt.restore, map_location={'cuda:2': 'cuda:0'})

# cuda
# Use the GPU only when CUDA is available and at least one GPU id was given;
# otherwise fall back to the CPU instead of failing in set_device / gpus[0].
use_cuda = torch.cuda.is_available() and len(opt.gpus) > 0
if use_cuda:
    torch.cuda.set_device(opt.gpus[0])
    torch.cuda.manual_seed(opt.seed)
print(use_cuda)

# load the global statistics of the data
print('loading data...\n')
start_time = time.time()
spk_global_gen = prepare_data(mode='global', train_or_test='train')
global_para = next(spk_global_gen)  # global statistics of the dataset
print(global_para)

spk_all_list = global_para['all_spk']  # list of all speakers
dict_spk2idx = global_para['dict_spk_to_idx']
dict_idx2spk = global_para['dict_idx_to_spk']
speech_fre = global_para['num_fre']  # total number of frequency bins
total_frames = global_para['num_frames']  # speech length (frames)
spk_num_total = global_para['total_spk_num']  # total number of speakers
batch_total = global_para['total_batch_num']  # number of batches per epoch

config.speech_fre = speech_fre
mix_speech_len = total_frames
config.mix_speech_len = total_frames
num_labels = len(spk_all_list)
# checkpoint
print('loading checkpoint...\n', opt.restore)
checkpoints = torch.load(opt.restore)

# cuda
# Use the GPU only when CUDA is available and at least one GPU id was given;
# otherwise fall back to the CPU instead of failing in set_device / gpus[0].
use_cuda = torch.cuda.is_available() and len(opt.gpus) > 0
if use_cuda:
    torch.cuda.set_device(opt.gpus[0])
    torch.cuda.manual_seed(opt.seed)
print(use_cuda)

# data
print('loading data...\n')
start_time = time.time()
# Generator of global dataset statistics (also usable as a fake data source
# while developing the model).
spk_global_gen = prepare_data(mode='global', train_or_test='train')
global_para = next(spk_global_gen)
print(global_para)

spk_all_list = global_para['all_spk']  # list of all speakers
dict_spk2idx = global_para['dict_spk_to_idx']
dict_idx2spk = global_para['dict_idx_to_spk']
speech_fre = global_para['num_fre']  # number of frequency bins
total_frames = global_para['num_frames']  # speech length (frames)
spk_num_total = global_para['total_spk_num']  # total number of speakers
batch_total = global_para['total_batch_num']  # number of batches per epoch

config.speech_fre = speech_fre
mix_speech_len = total_frames
config.mix_speech_len = total_frames