def eval_model(mix_speech_class, dict_spk2idx, loss_func):
    """Run the multi-label speaker classifier over the validation set and log
    accuracy / recall / loss metrics.

    Args:
        mix_speech_class: trained classifier module; called on CUDA tensors
            built from each batch's 'mix_feas'.
        dict_spk2idx: speaker-name -> label-index mapping, passed through to
            multi_label_vector.
        loss_func: criterion applied to (prediction, multi-hot target).

    Side effects only (prints and lrs.send); returns None.
    """
    # NOTE(review): sets the attribute directly instead of calling .eval(),
    # so dropout/batchnorm submodules are not switched to eval mode — confirm intended.
    mix_speech_class.training = False
    print '#' * 40
    eval_data_gen = prepare_data('once', 'valid', 2, 2)
    acc_all, acc_line = [], []          # per-batch element-wise / sample-wise accuracies
    recall_rate_list = np.array([])     # per-batch recall, averaged at the end
    while True:
        eval_data = eval_data_gen.next()
        if eval_data == False:
            break  # generator exhausted for this epoch: stop evaluating
        # Forward pass over the mixed-speech features of this batch.
        mix_speech = mix_speech_class(Variable(torch.from_numpy(eval_data['mix_feas'])).cuda())
        # y_spk: target speaker indices per sample; y_map: multi-hot target matrix.
        y_spk, y_map = multi_label_vector(eval_data['multi_spk_fea_list'], dict_spk2idx)
        y_map = Variable(torch.from_numpy(y_map)).cuda()
        y_out_batch = mix_speech.data.cpu().numpy()
        acc1, acc2, all_num_batch, all_line_batch, recall_rate = count_multi_acc(y_out_batch, y_spk, alpha=-0.1, top_k_num=2)
        acc_all.append(acc1)
        acc_line.append(acc2)
        recall_rate_list = np.append(recall_rate_list, recall_rate)
        for i in range(config.BATCH_SIZE):
            # Besides the probabilities of the target speakers, also print the last few entries.
            print 'aim:{}-->{},predict:{}'.format(eval_data['multi_spk_fea_list'][i].keys(), y_spk[i], mix_speech.data.cpu().numpy()[i][y_spk[i]])
            print 'last 4 probility:{}'.format(mix_speech.data.cpu().numpy()[i][-5:])
        print '\nAcc for this batch: all elements({}) acc--{},all sample({}) acc--{} recall--{}'.format(all_num_batch, acc1, all_line_batch, acc2, recall_rate)
        loss = loss_func(mix_speech, y_map)
        # Auxiliary loss on the per-sample sums of predictions vs. targets.
        loss_sum = loss_func(mix_speech.sum(1), y_map.sum(1))
        lrs.send('eval_loss', loss.data[0])
        print 'loss this batch:', loss.data.cpu().numpy(), loss_sum.data.cpu().numpy()
        print 'time:', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        lrs.send('eval_sum_loss', loss_sum.data[0])
    # Dataset-level averages over all evaluated batches.
    print 'Acc for eval dataset: all elements acc--{},all sample acc--{}, recall_rate--{}'.format(np.mean(acc_all), np.mean(acc_line), np.mean(recall_rate_list))
    lrs.send('eval_recall_rate_aver', np.mean(recall_rate_list))
    print '#' * 40
def main(config):
    """Train/evaluate a NIMA (VGG16-based) aesthetic scorer with EMD loss.

    Args:
        config: namespace-like object carrying paths, batch sizes, learning
            rates, epoch counts and the train/test flags used below.

    Behavior: when config.train is set, runs SGD training with per-epoch
    validation, early stopping and best-checkpoint saving; when config.test
    is set, computes the predicted mean/std of the 10-bin score distribution
    for each test image.  Returns None.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # transforms.Scale is the legacy (pre-0.2 torchvision) name for Resize; kept as-is.
    train_transform = transforms.Compose([
        transforms.Scale(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()])
    # NOTE(review): val transform uses RandomCrop, so validation is stochastic — confirm intended.
    val_transform = transforms.Compose([
        transforms.Scale(256),
        transforms.RandomCrop(224),
        transforms.ToTensor()])

    trainset = AVADataset(csv_file=config.train_csv_file, root_dir=config.train_img_path, transform=train_transform)
    valset = AVADataset(csv_file=config.val_csv_file, root_dir=config.val_img_path, transform=val_transform)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=config.train_batch_size,
                                               shuffle=True, num_workers=config.num_workers)
    val_loader = torch.utils.data.DataLoader(valset, batch_size=config.val_batch_size,
                                             shuffle=False, num_workers=config.num_workers)

    base_model = models.vgg16(pretrained=True)
    model = NIMA(base_model)

    if config.warm_start:
        # Resume from the checkpoint of the given epoch.
        model.load_state_dict(torch.load(os.path.join(config.ckpt_path,
                                                      'epoch-%d.pkl' % config.warm_start_epoch)))
        print('Successfully loaded model epoch-%d.pkl' % config.warm_start_epoch)

    if config.multi_gpu:
        model.features = torch.nn.DataParallel(model.features, device_ids=config.gpu_ids)
        model = model.to(device)
    else:
        model = model.to(device)

    # Separate learning rates for the conv backbone and the classifier head.
    conv_base_lr = config.conv_base_lr
    dense_lr = config.dense_lr
    optimizer = optim.SGD([
        {'params': model.features.parameters(), 'lr': conv_base_lr},
        {'params': model.classifier.parameters(), 'lr': dense_lr}],
        momentum=0.9
    )

    # send hyperparams
    lrs.send({
        'title': 'EMD Loss',
        'train_batch_size': config.train_batch_size,
        'val_batch_size': config.val_batch_size,
        'optimizer': 'SGD',
        'conv_base_lr': config.conv_base_lr,
        'dense_lr': config.dense_lr,
        'momentum': 0.9
    })

    param_num = 0
    for param in model.parameters():
        param_num += int(np.prod(param.shape))
    print('Trainable params: %.2f million' % (param_num / 1e6))

    if config.train:
        # for early stopping
        count = 0
        init_val_loss = float('inf')
        train_losses = []
        val_losses = []
        for epoch in range(config.warm_start_epoch, config.epochs):
            lrs.send('epoch', epoch)
            batch_losses = []
            for i, data in enumerate(train_loader):
                images = data['image'].to(device)
                labels = data['annotations'].to(device).float()
                outputs = model(images)
                outputs = outputs.view(-1, 10, 1)
                optimizer.zero_grad()
                loss = emd_loss(labels, outputs)
                batch_losses.append(loss.item())
                loss.backward()
                optimizer.step()
                lrs.send('train_emd_loss', loss.item())
                # FIX: use loss.item() (as above) instead of the deprecated loss.data[0].
                print('Epoch: %d/%d | Step: %d/%d | Training EMD loss: %.4f'
                      % (epoch + 1, config.epochs, i + 1,
                         len(trainset) // config.train_batch_size + 1, loss.item()))

            avg_loss = sum(batch_losses) / (len(trainset) // config.train_batch_size + 1)
            train_losses.append(avg_loss)
            print('Epoch %d averaged training EMD loss: %.4f' % (epoch + 1, avg_loss))

            # exponential learning rate decay every 10 epochs
            if (epoch + 1) % 10 == 0:
                conv_base_lr = conv_base_lr * config.lr_decay_rate ** ((epoch + 1) / config.lr_decay_freq)
                dense_lr = dense_lr * config.lr_decay_rate ** ((epoch + 1) / config.lr_decay_freq)
                # Rebuild the optimizer with the decayed rates (momentum buffers are reset).
                optimizer = optim.SGD([
                    {'params': model.features.parameters(), 'lr': conv_base_lr},
                    {'params': model.classifier.parameters(), 'lr': dense_lr}],
                    momentum=0.9
                )
                # send decay hyperparams
                lrs.send({
                    'lr_decay_rate': config.lr_decay_rate,
                    'lr_decay_freq': config.lr_decay_freq,
                    'conv_base_lr': config.conv_base_lr,
                    'dense_lr': config.dense_lr
                })

            # do validation after each epoch
            batch_val_losses = []
            for data in val_loader:
                images = data['image'].to(device)
                labels = data['annotations'].to(device).float()
                with torch.no_grad():
                    outputs = model(images)
                outputs = outputs.view(-1, 10, 1)
                val_loss = emd_loss(labels, outputs)
                batch_val_losses.append(val_loss.item())
            avg_val_loss = sum(batch_val_losses) / (len(valset) // config.val_batch_size + 1)
            val_losses.append(avg_val_loss)
            lrs.send('val_emd_loss', avg_val_loss)
            print('Epoch %d completed. Averaged EMD loss on val set: %.4f.' % (epoch + 1, avg_val_loss))

            # Use early stopping to monitor training
            if avg_val_loss < init_val_loss:
                init_val_loss = avg_val_loss
                # save model weights if val loss decreases
                print('Saving model...')
                torch.save(model.state_dict(),
                           os.path.join(config.ckpt_path, 'epoch-%d.pkl' % (epoch + 1)))
                print('Done.\n')
                # reset count
                count = 0
            elif avg_val_loss >= init_val_loss:
                count += 1
                if count == config.early_stopping_patience:
                    print('Val EMD loss has not decreased in %d epochs. Training terminated.'
                          % config.early_stopping_patience)
                    break

        print('Training completed.')

        if config.save_fig:
            # plot train and val loss
            epochs = range(1, epoch + 2)
            plt.plot(epochs, train_losses, 'b-', label='train loss')
            plt.plot(epochs, val_losses, 'g-', label='val loss')
            plt.title('EMD loss')
            plt.legend()
            plt.savefig('./loss.png')

    if config.test:
        # compute mean and std of the predicted 10-bin score distribution
        test_transform = val_transform
        testset = AVADataset(csv_file=config.test_csv_file, root_dir=config.test_img_path,
                             transform=val_transform)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=config.test_batch_size,
                                                  shuffle=False, num_workers=config.num_workers)
        mean_preds = []
        std_preds = []
        for data in test_loader:
            image = data['image'].to(device)
            output = model(image)
            output = output.view(10, 1)
            predicted_mean, predicted_std = 0.0, 0.0
            # mean = sum_i i * p_i over score bins 1..10
            for i, elem in enumerate(output, 1):
                predicted_mean += i * elem
            # FIX: the original reused the stale loop index `i` here, so every
            # term used (10 - mean); the deviation must use this loop's own bin
            # index `j`.  (This accumulates the variance; no sqrt is applied —
            # kept as in the original.)
            for j, elem in enumerate(output, 1):
                predicted_std += elem * (j - predicted_mean) ** 2
            mean_preds.append(predicted_mean)
            std_preds.append(predicted_std)
def main():
    """Train the TDAA separation model with an adversarial discriminator.

    Builds the 3D mix-speech embedding network, speaker classifier, speaker
    embedding table, attention and adjust layers plus a discriminator, then
    alternates per batch between (1) a discriminator update and (2) a
    generator/mask update whose loss includes an MSE mask loss, a sum-to-one
    mask loss and a fool-the-discriminator term.  Periodically saves
    parameters and calls eval_bss.  Returns None.
    """
    print('go to model')
    print '*' * 80
    # Global pass over the dataset to collect speaker lists and feature geometry.
    spk_global_gen = prepare_data(mode='global', train_or_test='train')
    global_para = spk_global_gen.next()
    print global_para
    spk_all_list, dict_spk2idx, dict_idx2spk, mix_speech_len, speech_fre, total_frames, spk_num_total, batch_total = global_para
    del spk_global_gen
    num_labels = len(spk_all_list)

    print 'Begin to build the maim model for Multi_Modal Cocktail Problem.'
    # This part is to build the 3D mix speech embedding maps.
    mix_hidden_layer_3d = MIX_SPEECH(speech_fre, mix_speech_len).cuda()
    mix_speech_classifier = MIX_SPEECH_classifier(speech_fre, mix_speech_len, num_labels).cuda()
    mix_speech_multiEmbedding = SPEECH_EMBEDDING(num_labels, config.EMBEDDING_SIZE, spk_num_total + config.UNK_SPK_SUPP).cuda()
    print mix_hidden_layer_3d
    print mix_speech_classifier
    print mix_speech_multiEmbedding
    att_layer = ATTENTION(config.EMBEDDING_SIZE, 'dot').cuda()
    att_speech_layer = ATTENTION(config.EMBEDDING_SIZE, 'dot').cuda()
    # NOTE(review): adjust_layer is not moved to .cuda() here, unlike the others — confirm intended.
    adjust_layer = ADDJUST(2 * config.HIDDEN_UNITS, config.EMBEDDING_SIZE)
    dis_layer = Discriminator().cuda()
    print att_speech_layer
    print att_speech_layer.mode
    print adjust_layer
    print dis_layer
    lr_data = 0.0002
    # One Adam optimizer over all sub-modules (including the discriminator).
    optimizer = torch.optim.Adam([{'params': mix_hidden_layer_3d.parameters()},
                                  {'params': mix_speech_multiEmbedding.parameters()},
                                  {'params': mix_speech_classifier.parameters()},
                                  {'params': adjust_layer.parameters()},
                                  {'params': att_speech_layer.parameters()},
                                  {'params': dis_layer.parameters()},
                                  ], lr=lr_data)
    if 0 and config.Load_param:  # dead code: disabled by the literal 0
        mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_WSJ0_hidden3d_180'))
        mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_WSJ0_emblayer_180'))
        att_speech_layer.load_state_dict(torch.load('params/param_mix101_WSJ0_attlayer_180'))
        # NOTE(review): loads the *attlayer* checkpoint into adjust_layer — looks like a
        # copy-paste slip, currently harmless because this branch is disabled.
        adjust_layer.load_state_dict(torch.load('params/param_mix101_WSJ0_attlayer_180'))
    loss_func = torch.nn.MSELoss()        # the target label is NOT an one-hotted
    loss_multi_func = torch.nn.MSELoss()  # the target label is NOT an one-hotted
    # loss_multi_func = torch.nn.L1Loss() # the target label is NOT an one-hotted
    loss_dis_class = torch.nn.MSELoss()

    lrs.send({
        'title': 'TDAA classifier',
        'batch_size': config.BATCH_SIZE,
        'batch_total': batch_total,
        'epoch_size': config.EPOCH_SIZE,
        'loss func': loss_func.__str__(),
        'initial lr': lr_data
    })

    print '''Begin to calculate.'''
    for epoch_idx in range(config.MAX_EPOCH):
        # Halve the learning rate every 10 epochs, floored at 1e-7.
        if epoch_idx % 10 == 0:
            for ee in optimizer.param_groups:
                if ee['lr'] >= 1e-7:
                    ee['lr'] /= 2
                lr_data = ee['lr']
        lrs.send('lr', lr_data)
        if epoch_idx > 0:
            # NOTE(review): three arguments for two placeholders — SDR_SUM.mean() is
            # silently dropped by str.format here.
            print 'SDR_SUM (len:{}) for epoch {} : '.format(SDR_SUM.shape, epoch_idx - 1, SDR_SUM.mean())
        SDR_SUM = np.array([])
        train_data_gen = prepare_data('once', 'train')
        # train_data_gen=prepare_data('once','test')
        while 1 and True:
            train_data = train_data_gen.next()
            if train_data == False:
                break  # generator exhausted for this epoch: move on to the next epoch

            '''3D representation layer over the mixed speech (len, fre, Emb)'''
            # mix_tmp_hidden: [bs*T*hidden_units]
            mix_speech_hidden, mix_tmp_hidden = mix_hidden_layer_3d(Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # Video part temporarily disabled (s2/s3/s4 video data incomplete).

            '''Speech self-separation part'''
            mix_speech_output = mix_speech_classifier(Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # Ground-truth speaker names/vectors from the data.
            # y_spk_list=[one.keys() for one in train_data['multi_spk_fea_list']]
            y_spk_list = train_data['multi_spk_fea_list']
            y_spk_gtruth, y_map_gtruth = multi_label_vector(y_spk_list, dict_spk2idx)

            # During training, optionally substitute the ground-truth speaker
            # activation map for the classifier output.
            if config.Ground_truth:
                mix_speech_output = Variable(torch.from_numpy(y_map_gtruth)).cuda()
                if test_all_outputchannel:
                    # Set the mask to all ones to test every output channel.
                    mix_speech_output = Variable(torch.ones(config.BATCH_SIZE, num_labels,))
                    y_map_gtruth = np.ones([config.BATCH_SIZE, num_labels])

            top_k_mask_mixspeech = top_k_mask(mix_speech_output, alpha=0.5, top_k=num_labels)  # torch.Float
            top_k_mask_idx = [np.where(line == 1)[0] for line in top_k_mask_mixspeech.numpy()]
            # bs * num_labels (max speakers in the mixture) x embedding size
            mix_speech_multiEmbs = mix_speech_multiEmbedding(top_k_mask_mixspeech, top_k_mask_idx)
            mix_adjust = adjust_layer(mix_tmp_hidden, mix_speech_multiEmbs)
            mix_speech_multiEmbs = mix_adjust + mix_speech_multiEmbs

            # All samples in the batch must have the same number of active speakers.
            assert len(top_k_mask_idx[0]) == len(top_k_mask_idx[-1])
            top_k_num = len(top_k_mask_idx[0])

            # Attention between mix_speech_hidden [bs,len,fre,emb] and the
            # speaker embeddings [bs,num_labels,emb]: expand both to
            # bs*num_labels, apply ATT, then reshape back.
            mix_speech_hidden_5d = mix_speech_hidden.view(config.BATCH_SIZE, 1, mix_speech_len, speech_fre, config.EMBEDDING_SIZE)
            mix_speech_hidden_5d = mix_speech_hidden_5d.expand(config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre, config.EMBEDDING_SIZE).contiguous()
            mix_speech_hidden_5d_last = mix_speech_hidden_5d.view(-1, mix_speech_len, speech_fre, config.EMBEDDING_SIZE)
            att_multi_speech = att_speech_layer(mix_speech_hidden_5d_last, mix_speech_multiEmbs.view(-1, config.EMBEDDING_SIZE))
            print att_multi_speech.size()
            att_multi_speech = att_multi_speech.view(config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre)  # bs,num_labels,len,fre
            # print att_multi_speech.size()
            multi_mask = att_multi_speech
            # top_k_mask_mixspeech_multi=top_k_mask_mixspeech.view(config.BATCH_SIZE,top_k_num,1,1).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
            # multi_mask=multi_mask*Variable(top_k_mask_mixspeech_multi).cuda()

            x_input_map = Variable(torch.from_numpy(train_data['mix_feas'])).cuda()
            # print x_input_map.size()
            x_input_map_multi = x_input_map.view(config.BATCH_SIZE, 1, mix_speech_len, speech_fre).expand(config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre)
            # predict_multi_map=multi_mask*x_input_map_multi
            predict_multi_map = multi_mask * x_input_map_multi

            # Build the per-speaker target spectrogram stack, ordered by speaker index.
            y_multi_map = np.zeros([config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre], dtype=np.float32)
            batch_spk_multi_dict = train_data['multi_spk_fea_list']
            for idx, sample in enumerate(batch_spk_multi_dict):
                y_idx = sorted([dict_spk2idx[spk] for spk in sample.keys()])
                assert y_idx == list(top_k_mask_idx[idx])
                for jdx, oo in enumerate(y_idx):
                    y_multi_map[idx, jdx] = sample[dict_idx2spk[oo]]
            y_multi_map = Variable(torch.from_numpy(y_multi_map)).cuda()

            loss_multi_speech = loss_multi_func(predict_multi_map, y_multi_map)

            # ---- Discriminator update (real vs. generated spectrograms) ----
            score_true = dis_layer(y_multi_map)
            score_false = dis_layer(predict_multi_map)
            acc_true = torch.sum(score_true > 0.5).data.cpu().numpy() / float(score_true.size()[0])
            acc_false = torch.sum(score_false < 0.5).data.cpu().numpy() / float(score_true.size()[0])
            acc_dis = (acc_false + acc_true) / 2
            print 'acc for dis:(ture,false,aver)', acc_true, acc_false, acc_dis

            loss_dis_true = loss_dis_class(score_true, Variable(torch.ones(config.BATCH_SIZE * top_k_num, 1)).cuda())
            loss_dis_false = loss_dis_class(score_false, Variable(torch.zeros(config.BATCH_SIZE * top_k_num, 1)).cuda())
            loss_dis = loss_dis_true + loss_dis_false
            print 'loss for dis:(ture,false)', loss_dis_true.data.cpu().numpy(), loss_dis_false.data.cpu().numpy()
            optimizer.zero_grad()                   # clear gradients for next train
            # retain_graph so the generator loss below can backprop through the same graph
            loss_dis.backward(retain_graph=True)    # backpropagation, compute gradients
            optimizer.step()                        # apply gradients

            # ---- Generator / mask update ----
            # Sum-to-one loss over channels: encourages the masks to partition the mixture.
            y_sum_map = Variable(torch.ones(config.BATCH_SIZE, mix_speech_len, speech_fre)).cuda()
            predict_sum_map = torch.sum(multi_mask, 1)
            loss_multi_sum_speech = loss_multi_func(predict_sum_map, y_sum_map)
            # loss_multi_speech=loss_multi_speech  # TODO: compare the sum-to-one variant; plain MSE already works well.
            print 'loss 1, losssum : ', loss_multi_speech.data.cpu().numpy(), loss_multi_sum_speech.data.cpu().numpy()
            lrs.send('loss mask:', loss_multi_speech.data.cpu()[0])
            lrs.send('loss sum:', loss_multi_sum_speech.data.cpu()[0])
            loss_multi_speech = loss_multi_speech + 0.5 * loss_multi_sum_speech
            print 'training multi-abs norm this batch:', torch.abs(y_multi_map - predict_multi_map).norm().data.cpu().numpy()
            print 'loss:', loss_multi_speech.data.cpu().numpy()
            # Adversarial term: push generated spectrograms to score as real (target = ones).
            loss_dis_false = loss_dis_class(score_false, Variable(torch.ones(config.BATCH_SIZE * top_k_num, 1)).cuda())
            loss_multi_speech = loss_multi_speech + loss_dis_false
            optimizer.zero_grad()           # clear gradients for next train
            loss_multi_speech.backward()    # backpropagation, compute gradients
            optimizer.step()                # apply gradients

        # Save checkpoints every 5 epochs once warmed up.
        if 1 and epoch_idx >= 10 and epoch_idx % 5 == 0:
            # torch.save(mix_speech_multiEmbedding.state_dict(),'params/param_mixalignag_{}_emblayer_{}'.format(config.DATASET,epoch_idx))
            # torch.save(mix_hidden_layer_3d.state_dict(),'params/param_mixalignag_{}_hidden3d_{}'.format(config.DATASET,epoch_idx))
            # torch.save(att_speech_layer.state_dict(),'params/param_mixalignag_{}_attlayer_{}'.format(config.DATASET,epoch_idx))
            torch.save(mix_speech_multiEmbedding.state_dict(), 'params/param_mix{}ag_{}_emblayer_{}'.format(att_speech_layer.mode, config.DATASET, epoch_idx))
            torch.save(mix_hidden_layer_3d.state_dict(), 'params/param_mix{}ag_{}_hidden3d_{}'.format(att_speech_layer.mode, config.DATASET, epoch_idx))
            torch.save(att_speech_layer.state_dict(), 'params/param_mix{}ag_{}_attlayer_{}'.format(att_speech_layer.mode, config.DATASET, epoch_idx))
            torch.save(adjust_layer.state_dict(), 'params/param_mix{}ag_{}_adjlayer_{}'.format(att_speech_layer.mode, config.DATASET, epoch_idx))
            torch.save(dis_layer.state_dict(), 'params/param_mix{}ag_{}_dislayer_{}'.format(att_speech_layer.mode, config.DATASET, epoch_idx))
        # Run BSS evaluation every 3 epochs.
        if 1 and epoch_idx % 3 == 0:
            eval_bss(mix_hidden_layer_3d, adjust_layer, mix_speech_classifier, mix_speech_multiEmbedding, att_speech_layer,
                     loss_multi_func, dict_spk2idx, dict_idx2spk, num_labels, mix_speech_len, speech_fre)
def eval_bss(mix_hidden_layer_3d, adjust_layer, mix_speech_classifier, mix_speech_multiEmbedding, att_speech_layer,
             loss_multi_func, dict_spk2idx, dict_idx2spk, num_labels, mix_speech_len, speech_fre):
    """Evaluate the separation pipeline on the validation set and report SDR.

    Mirrors the training-time forward pass (hidden 3D maps -> speaker mask ->
    per-speaker attention masks -> predicted spectrograms), computes the mask
    and sum-to-one losses, writes separated audio via bss_eval and accumulates
    SDR via bss_test.cal.  Side effects only (prints, files, lrs.send);
    returns None.
    """
    for i in [mix_speech_multiEmbedding, adjust_layer, mix_speech_classifier, mix_hidden_layer_3d, att_speech_layer]:
        # NOTE(review): sets .training directly rather than calling .eval() — dropout/
        # batchnorm children keep their previous mode; confirm intended.
        i.training = False
    print '#' * 40
    eval_data_gen = prepare_data('once', 'valid')
    SDR_SUM = np.array([])
    while True:
        print '\n\n'
        eval_data = eval_data_gen.next()
        if eval_data == False:
            break  # generator exhausted for this epoch: stop evaluating

        '''3D representation layer over the mixed speech (len, fre, Emb)'''
        mix_speech_hidden, mix_tmp_hidden = mix_hidden_layer_3d(Variable(torch.from_numpy(eval_data['mix_feas'])).cuda())
        # Video part temporarily disabled (s2/s3/s4 video data incomplete).

        '''Speech self-separation part'''
        mix_speech_output = mix_speech_classifier(Variable(torch.from_numpy(eval_data['mix_feas'])).cuda())
        y_spk_list = eval_data['multi_spk_fea_list']
        y_spk_gtruth, y_map_gtruth = multi_label_vector(y_spk_list, dict_spk2idx)

        # Optionally use the ground-truth speaker activation map instead of the classifier.
        if config.Ground_truth:
            mix_speech_output = Variable(torch.from_numpy(y_map_gtruth)).cuda()
            if test_all_outputchannel:
                # Set the mask to all ones to test every output channel.
                mix_speech_output = Variable(torch.ones(config.BATCH_SIZE, num_labels,))
                y_map_gtruth = np.ones([config.BATCH_SIZE, num_labels])

        top_k_mask_mixspeech = top_k_mask(mix_speech_output, alpha=0.5, top_k=num_labels)  # torch.Float
        top_k_mask_idx = [np.where(line == 1)[0] for line in top_k_mask_mixspeech.numpy()]
        # bs * num_labels (max speakers in the mixture) x embedding size
        mix_speech_multiEmbs = mix_speech_multiEmbedding(top_k_mask_mixspeech, top_k_mask_idx)
        mix_adjust = adjust_layer(mix_tmp_hidden, mix_speech_multiEmbs)
        mix_speech_multiEmbs = mix_adjust + mix_speech_multiEmbs

        assert len(top_k_mask_idx[0]) == len(top_k_mask_idx[-1])
        top_k_num = len(top_k_mask_idx[0])

        # Attention between mix_speech_hidden [bs,len,fre,emb] and the speaker
        # embeddings [bs,num_labels,emb]: expand to bs*num_labels, apply ATT, reshape back.
        mix_speech_hidden_5d = mix_speech_hidden.view(config.BATCH_SIZE, 1, mix_speech_len, speech_fre, config.EMBEDDING_SIZE)
        mix_speech_hidden_5d = mix_speech_hidden_5d.expand(config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre, config.EMBEDDING_SIZE).contiguous()
        mix_speech_hidden_5d_last = mix_speech_hidden_5d.view(-1, mix_speech_len, speech_fre, config.EMBEDDING_SIZE)
        # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
        # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'dot').cuda()
        att_multi_speech = att_speech_layer(mix_speech_hidden_5d_last, mix_speech_multiEmbs.view(-1, config.EMBEDDING_SIZE))
        # print att_multi_speech.size()
        att_multi_speech = att_multi_speech.view(config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre)  # bs,num_labels,len,fre
        # print att_multi_speech.size()
        multi_mask = att_multi_speech
        # top_k_mask_mixspeech_multi=top_k_mask_mixspeech.view(config.BATCH_SIZE,top_k_num,1,1).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
        # multi_mask=multi_mask*Variable(top_k_mask_mixspeech_multi).cuda()

        x_input_map = Variable(torch.from_numpy(eval_data['mix_feas'])).cuda()
        # print x_input_map.size()
        x_input_map_multi = x_input_map.view(config.BATCH_SIZE, 1, mix_speech_len, speech_fre).expand(config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre)
        # predict_multi_map=multi_mask*x_input_map_multi
        predict_multi_map = multi_mask * x_input_map_multi

        # Build the per-speaker target spectrogram stack, ordered by speaker index.
        y_multi_map = np.zeros([config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre], dtype=np.float32)
        batch_spk_multi_dict = eval_data['multi_spk_fea_list']
        for idx, sample in enumerate(batch_spk_multi_dict):
            y_idx = sorted([dict_spk2idx[spk] for spk in sample.keys()])
            assert y_idx == list(top_k_mask_idx[idx])
            for jdx, oo in enumerate(y_idx):
                y_multi_map[idx, jdx] = sample[dict_idx2spk[oo]]
        y_multi_map = Variable(torch.from_numpy(y_multi_map)).cuda()

        loss_multi_speech = loss_multi_func(predict_multi_map, y_multi_map)

        # Sum-to-one loss over channels, as in training.
        y_sum_map = Variable(torch.ones(config.BATCH_SIZE, mix_speech_len, speech_fre)).cuda()
        predict_sum_map = torch.sum(multi_mask, 1)
        loss_multi_sum_speech = loss_multi_func(predict_sum_map, y_sum_map)
        # loss_multi_speech=loss_multi_speech  # TODO: compare the sum-to-one variant; plain MSE already works well.
        print 'loss 1 eval, losssum eval : ', loss_multi_speech.data.cpu().numpy(), loss_multi_sum_speech.data.cpu().numpy()
        lrs.send('loss mask eval:', loss_multi_speech.data.cpu()[0])
        lrs.send('loss sum eval:', loss_multi_sum_speech.data.cpu()[0])
        loss_multi_speech = loss_multi_speech + 0.5 * loss_multi_sum_speech
        print 'evaling multi-abs norm this eval batch:', torch.abs(y_multi_map - predict_multi_map).norm().data.cpu().numpy()
        print 'loss:', loss_multi_speech.data.cpu().numpy()

        # Write separated signals to batch_output/ and score them (2 = presumably #speakers).
        bss_eval(predict_multi_map, y_multi_map, top_k_mask_idx, dict_idx2spk, eval_data)
        SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output/', 2))
        print 'SDR_aver_now:', SDR_SUM.mean()

    SDR_aver = SDR_SUM.mean()
    print 'SDR_SUM (len:{}) for epoch eval : '.format(SDR_SUM.shape)
    lrs.send('SDR eval aver', SDR_aver)
    print '#' * 40
def reset_grad():
    """Zero the gradients of all tensors in the module-level `params` list.

    Pre-0.4 PyTorch idiom: replaces each existing .grad with a fresh
    zero-filled Variable of the same shape instead of calling zero_grad().
    """
    for p in params:
        if p.grad is not None:
            data = p.grad.data
            p.grad = Variable(data.new().resize_as_(data).zero_())


# ---- Vanilla GAN training script (module level) ----
# Separate Adam optimizers for generator and discriminator parameters.
G_solver = optim.Adam(G_params, lr=1e-3)
D_solver = optim.Adam(D_params, lr=1e-3)
# Fixed BCE targets for real (1) and fake (0) batches.
ones_label = Variable(torch.ones(mb_size, 1))
zeros_label = Variable(torch.zeros(mb_size, 1))

lrs.send({'title': 'Vanilla GAN'})

for it in range(100000):
    # Sample data: latent noise z and a real MNIST minibatch X.
    z = Variable(torch.randn(mb_size, Z_dim))
    X, _ = mnist.train.next_batch(mb_size)
    X = Variable(torch.from_numpy(X))

    # Discriminator forward-loss (real and generated samples).
    G_sample = G(z)
    D_real = D(X)
    D_fake = D(G_sample)

    D_loss_real = nn.binary_cross_entropy(D_real, ones_label)
    D_loss_fake = nn.binary_cross_entropy(D_fake, zeros_label)
    D_loss = D_loss_real + D_loss_fake
    # NOTE(review): no backward/step is visible in this chunk — the remainder of
    # the update appears to live outside the visible source.
# ---- Basic classifier training script (module level) ----
optimizer = torch.optim.SGD(net.parameters(), lr=0.02)
# Print every parameter tensor of the network (debug dump).
cc = net.parameters()
for i in cc:
    print i
loss_func = torch.nn.CrossEntropyLoss()  # the target label is NOT an one-hotted
# loss_func = torch.nn.CrossEntropyLoss(torch.FloatTensor([99,1]))  # weighted update; helps on imbalanced data
# the target label is NOT an one-hotted
plt.ion()  # something about plotting
tt = time.time()

'''lrs testing and usage'''
lrs.send({
    'title': 'Basic classifier',
    'batch_size': None,
    'epochs': 100,
    'optimizer': 'SGD',
    'lr': 0.02,
    'momentum': 'Initailization'
})

for t in range(100):
    # lrs.send('epoch',t)
    out = net(x)  # input x and predict based on x
    # print out
    loss = loss_func(out, y)  # must be (1. nn output, 2. target), the target label is NOT one-hotted
    print '\nttt:', t
    print 'loss:', loss.data[0]
    lrs.send('train_loss', loss.data[0])
    # NOTE(review): no zero_grad/backward/step is visible in this chunk; as shown,
    # the loop only evaluates and logs the loss.
def main(config):
    """Set up a NIMA model on a ResNet-101 backbone and, when config.test is
    set, run GradCAM-visualised inference over the test set, saving per-image
    CAM overlays plus predictions/labels to .npz and .pkl files.

    Args:
        config: namespace-like object with dataset paths, loader sizes,
            learning rates, checkpoint settings and the `test` flag.

    Returns None; all results are written to disk.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # transforms.Scale is the legacy torchvision name for Resize; kept as-is.
    train_transform = transforms.Compose([
        transforms.Scale(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])

    val_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor()])

    test_transform = transforms.Compose([transforms.ToTensor()])

    trainset = AVADataset(csv_file=config.train_csv_file, root_dir=config.train_img_path, transform=train_transform)
    valset = AVADataset(csv_file=config.val_csv_file, root_dir=config.val_img_path, transform=val_transform)

    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=config.train_batch_size,
        shuffle=True, num_workers=config.num_workers)
    val_loader = torch.utils.data.DataLoader(valset,
                                             batch_size=config.val_batch_size,
                                             shuffle=False,
                                             num_workers=config.num_workers)

    # Alternative backbones kept for reference:
    # base_model = models.vgg16(pretrained=True)
    # base_model = models.resnet18(pretrained=True)
    base_model = models.resnet101(pretrained=True, progress=False)
    # base_model = models.inception_v3(pretrained=True)
    model = NIMA(base_model)
    # model = NIMA()

    if config.warm_start:
        # Resume from the checkpoint of the given epoch.
        model.load_state_dict(
            torch.load(
                os.path.join(config.ckpt_path,
                             'epoch-%d.pkl' % config.warm_start_epoch)))
        print('Successfully loaded model epoch-%d.pkl' % config.warm_start_epoch)

    if config.multi_gpu:
        model.features = torch.nn.DataParallel(model.features, device_ids=config.gpu_ids)
        model = model.to(device)
    else:
        model = model.to(device)

    # Separate learning rates for backbone and head (optimizer unused in the
    # visible test path; retained for the training path elsewhere).
    conv_base_lr = config.conv_base_lr
    dense_lr = config.dense_lr
    optimizer = optim.SGD([{
        'params': model.features.parameters(),
        'lr': conv_base_lr
    }, {
        'params': model.classifier.parameters(),
        'lr': dense_lr
    }],
        momentum=0.9)
    # optimizer = optim.Adam( model.parameters(), lr = conv_base_lr, betas=(0.9,0.999))

    # Loss functions
    # criterion = torch.nn.L1Loss()
    criterion = torch.nn.CrossEntropyLoss()  # NOTE(review): not used in the code visible here

    # send hyperparams
    lrs.send({
        'title': 'EMD Loss',
        'train_batch_size': config.train_batch_size,
        'val_batch_size': config.val_batch_size,
        'optimizer': 'SGD',
        'conv_base_lr': config.conv_base_lr,
        'dense_lr': config.dense_lr,
        'momentum': 0.9
    })

    param_num = 0
    for param in model.parameters():
        param_num += int(np.prod(param.shape))
    print('Trainable params: %.2f million' % (param_num / 1e6))

    if config.test:
        # start.record()
        print('Testing')
        # Always reload the warm-start checkpoint for testing.
        model.load_state_dict(
            torch.load(
                os.path.join(config.ckpt_path,
                             'epoch-%d.pkl' % config.warm_start_epoch)))
        target_layer = model.features

        # compute mean score
        test_transform = test_transform  # val_transform
        # NOTE(review): testset is built with val_transform although test_transform
        # is selected just above — confirm which transform is intended.
        testset = AVADataset(csv_file=config.test_csv_file, root_dir=config.test_img_path, transform=val_transform)
        test_loader = torch.utils.data.DataLoader(
            testset, batch_size=config.test_batch_size,
            shuffle=False, num_workers=config.num_workers)

        ypreds = []
        ylabels = []
        im_ids = []
        # std_preds = []
        count = 0
        gradcam = GradCAM(model, target_layer)

        for data in test_loader:
            im_id = data['img_id']
            im_name = os.path.split(im_id[0])
            myname = os.path.splitext(im_name[1])
            image = data['image'].to(device)

            # GradCAM heat-map for this image, saved next to the CWD as <name>.jpg.
            mask, _ = gradcam(image)
            heatmap, result = visualize_cam(mask, image)
            im = transforms.ToPILImage()(result)
            im.save(myname[0] + ".jpg")

            labels = data['annotations'].to(device).long()
            output = model(image)
            output = output.view(-1, 2)

            bpred = output.to(torch.device("cpu"))
            cpred = bpred.data.numpy()
            blabel = labels.to(torch.device("cpu"))
            clabel = blabel.data.numpy()

            # predicted_mean, predicted_std = 0.0, 0.0
            # for i, elem in enumerate(output, 1):
            #     predicted_mean += i * elem
            # for j, elem in enumerate(output, 1):
            #     predicted_std += elem * (i - predicted_mean) ** 2
            ypreds.append(cpred)
            ylabels.append(clabel)
            im_name = os.path.split(im_id[0])
            im_ids.append(im_name[1])
            count = count + 1

        # Persist predictions and labels for offline analysis.
        np.savez('Test_results_16.npz', Label=ylabels, Predict=ypreds)
        df = pd.DataFrame(data={'Label': ylabels, "Predict": ypreds})
        print(df.dtypes)
        df.to_pickle("./Test_results_19_resnet.pkl")
def main():
    """Train the multi-label speaker classifier (TDAA front-end) and evaluate
    it once per epoch via eval_model.

    Builds MIX_SPEECH_classifier from global dataset statistics, optionally
    warm-starts from a checkpoint (currently disabled), then loops over
    epochs with a halving learning-rate schedule.  Returns None.
    """
    print('go to model')
    print '*' * 80
    # Global pass over the dataset to collect speaker lists and feature geometry.
    spk_global_gen = prepare_data(mode='global', train_or_test='train')
    global_para = spk_global_gen.next()
    print global_para
    spk_all_list, dict_spk2idx, dict_idx2spk, mix_speech_len, speech_fre, total_frames, spk_num_total, batch_total = global_para
    del spk_global_gen
    num_labels = len(spk_all_list)
    config.EPOCH_SIZE = batch_total
    # Shapes in order: mix_speechs.shape, mix_feas.shape, aim_fea.shape, aim_spkid.shape, query.shape
    # Example: (5, 17040) (5, 134, 129) (5, 134, 129) (5,) (5, 32, 400, 300, 3)

    print 'Begin to build the maim model for Multi_Modal Cocktail Problem.'
    mix_speech_class = MIX_SPEECH_classifier(speech_fre, mix_speech_len, num_labels).cuda()
    print mix_speech_class

    if 0 and config.Load_param:  # dead code: disabled by the literal 0
        # History of checkpoint names; only the last assignment takes effect.
        # para_name='param_speech_WSJ0_multilabel_epoch42'
        # para_name='param_speech_WSJ0_multilabel_epoch249'
        # para_name='param_speech_123_WSJ0_multilabel_epoch75'
        # para_name='param_speech_123_WSJ0_multilabel_epoch24'
        para_name = 'param_speech_123onezero_WSJ0_multilabel_epoch75'  # top-3 recall 80%
        para_name = 'param_speech_123onezeroag_WSJ0_multilabel_epoch80'  # 83.6
        para_name = 'param_speech_123onezeroag1_WSJ0_multilabel_epoch45'
        para_name = 'param_speech_123onezeroag2_WSJ0_multilabel_epoch40'
        para_name = 'param_speech_123onezeroag4_WSJ0_multilabel_epoch75'
        para_name = 'param_speech_123onezeroag3_WSJ0_multilabel_epoch40'
        para_name = 'param_speech_123onezeroag4_WSJ0_multilabel_epoch20'
        para_name = 'param_speech_4lstm_multilabelloss30map_epoch440'
        para_name = 'param_speech_123onezeroag5dropout_WSJ0_multilabel_epoch20'
        # mix_speech_class.load_state_dict(torch.load('params/param_speech_multilabel_epoch249'))
        mix_speech_class.load_state_dict(torch.load('params/{}'.format(para_name)))
        print 'Load Success:', para_name

    lr_data = 0.0001
    optimizer = torch.optim.Adam([{'params': mix_speech_class.parameters()},
                                  # {'params':query_video_layer.lstm_layer.parameters()},
                                  # {'params':query_video_layer.dense.parameters()},
                                  # {'params':query_video_layer.Linear.parameters()},
                                  # {'params':att_layer.parameters()},
                                  # ], lr=0.02,momentum=0.9)
                                  ], lr=lr_data)
    loss_func = torch.nn.KLDivLoss()  # the target label is NOT an one-hotted
    # loss_func = torch.nn.MultiLabelSoftMarginLoss()  # the target label is NOT an one-hotted
    # loss_func = torch.nn.MSELoss()                   # the target label is NOT an one-hotted
    # loss_func = torch.nn.CrossEntropyLoss()          # the target label is NOT an one-hotted
    # loss_func = torch.nn.MultiLabelMarginLoss()      # the target label is NOT an one-hotted
    # loss_func = torch.nn.L1Loss()                    # the target label is NOT an one-hotted

    lrs.send({
        'title': 'TDAA classifier',
        'batch_size': config.BATCH_SIZE,
        'batch_total': batch_total,
        'epoch_size': config.EPOCH_SIZE,
        'loss func': loss_func.__str__(),
        'initial lr': lr_data
    })

    print '''Begin to calculate.'''
    for epoch_idx in range(config.MAX_EPOCH):
        # Halve the learning rate every 10 epochs.
        if epoch_idx % 10 == 0:
            for ee in optimizer.param_groups:
                ee['lr'] /= 2
                lr_data = ee['lr']
        acc_all, acc_line = 0, 0
        if epoch_idx > 0:
            print 'recal_rate this epoch {}: {}'.format(epoch_idx, recall_rate_list.mean())
        recall_rate_list = np.array([])
        train_data_gen = prepare_data('once', 'train')
        for batch_idx in range(config.EPOCH_SIZE):
            # NOTE(review): this unconditional `continue` skips the entire training
            # body below, so each epoch only runs the final evaluation. It looks
            # like a leftover debug switch — confirm before removing.
            continue
            print '*' * 40, epoch_idx, batch_idx, '*' * 40
            train_data = train_data_gen.next()
            if train_data == False:
                break  # generator exhausted for this epoch: move on to the next epoch
            mix_speech = mix_speech_class(Variable(torch.from_numpy(train_data['mix_feas'])).cuda())

            y_spk, y_map = multi_label_vector(train_data['multi_spk_fea_list'], dict_spk2idx)
            y_map = Variable(torch.from_numpy(y_map)).cuda()
            y_out_batch = mix_speech.data.cpu().numpy()
            acc1, acc2, all_num_batch, all_line_batch, recall_rate = count_multi_acc(y_out_batch, y_spk, alpha=-0.1, top_k_num=2)
            acc_all += acc1
            acc_line += acc2
            recall_rate_list = np.append(recall_rate_list, recall_rate)

            # print 'training abs norm this batch:',torch.abs(y_map-predict_map).norm().data.cpu().numpy()
            for i in range(config.BATCH_SIZE):
                # Besides the probabilities of the target speakers, also print the last few entries.
                print 'aim:{}-->{},predict:{}'.format(train_data['multi_spk_fea_list'][i].keys(), y_spk[i], mix_speech.data.cpu().numpy()[i][y_spk[i]])
                print 'last 4 probility:{}'.format(mix_speech.data.cpu().numpy()[i][-5:])
            print '\nAcc for this batch: all elements({}) acc--{},all sample({}) acc--{} recall--{}'.format(all_num_batch, acc1, all_line_batch, acc2, recall_rate)
            lrs.send('recall_rate', recall_rate)

            # continue
            # if epoch_idx==0 and batch_idx<50:
            #     loss=loss_func(mix_speech,100*y_map)
            # else:
            #     loss=loss_func(mix_speech,y_map)
            # loss=loss_func(mix_speech,30*y_map)
            loss = loss_func(mix_speech, y_map)
            loss_sum = loss_func(mix_speech.sum(1), y_map.sum(1))
            lrs.send('train_loss', loss.data[0])
            print 'loss this batch:', loss.data.cpu().numpy(), loss_sum.data.cpu().numpy()
            print 'time:', time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            # continue
            # loss=loss+0.2*loss_sum
            optimizer.zero_grad()   # clear gradients for next train
            loss.backward()         # backpropagation, compute gradients
            optimizer.step()        # apply gradients
            lrs.send('sum_loss', loss_sum.data[0])
            lrs.send('lr', lr_data)

        # Save parameters every other epoch once warmed up; best-effort only.
        if config.Save_param and epoch_idx >= 10 and epoch_idx % 2 == 0:
            try:
                torch.save(mix_speech_class.state_dict(), 'params/param_speech_123onezeroag5dropout_{}_multilabel_epoch{}'.format(config.DATASET, epoch_idx))
            except:
                print '\n\nSave paras failed ~! \n\n\n'

        # Print the Params history , that it proves well.
        # print 'Parameter history:'
        # for pa_gen in [{'params':mix_hidden_layer_3d.parameters()},
        #                {'params':query_video_layer.lstm_layer.parameters()},
        #                {'params':query_video_layer.dense.parameters()},
        #                {'params':query_video_layer.Linear.parameters()},
        #                {'params':att_layer.parameters()},
        #                ]:
        #     print pa_gen['params'].next()

        print 'Acc for this epoch: all elements acc--{},all sample acc--{}'.format(acc_all / config.EPOCH_SIZE, acc_line / config.EPOCH_SIZE)
        print '\nBegin to evaluate.'
        eval_model(mix_speech_class, dict_spk2idx, loss_func)
def main(): print('go to model') print '*' * 80 spk_global_gen=prepare_data(mode='global',train_or_test='train') #写一个假的数据生成,可以用来写模型先 global_para=spk_global_gen.next() print global_para spk_all_list,dict_spk2idx,dict_idx2spk,mix_speech_len,speech_fre,total_frames,spk_num_total,batch_total=global_para del spk_global_gen num_labels=len(spk_all_list) print 'Begin to build the maim model for Multi_Modal Cocktail Problem.' # This part is to build the 3D mix speech embedding maps. mix_hidden_layer_3d=MIX_SPEECH(speech_fre,mix_speech_len).cuda() mix_speech_classifier=MIX_SPEECH_classifier(speech_fre,mix_speech_len,num_labels).cuda() mix_speech_multiEmbedding=SPEECH_EMBEDDING(num_labels,config.EMBEDDING_SIZE,spk_num_total+config.UNK_SPK_SUPP).cuda() print mix_hidden_layer_3d print mix_speech_classifier print mix_speech_multiEmbedding att_layer=ATTENTION(config.EMBEDDING_SIZE,'dot').cuda() att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'dot').cuda() adjust_layer=ADDJUST(2*config.HIDDEN_UNITS,config.EMBEDDING_SIZE) dis_layer=Discriminator().cuda() print att_speech_layer print att_speech_layer.mode print adjust_layer print dis_layer lr_data=0.0002 optimizer = torch.optim.Adam([{'params':mix_hidden_layer_3d.parameters()}, {'params':mix_speech_multiEmbedding.parameters()}, {'params':mix_speech_classifier.parameters()}, {'params':adjust_layer.parameters()}, {'params':att_speech_layer.parameters()}, {'params':dis_layer.parameters()}, ], lr=lr_data) if 1 and config.Load_param: class_dict=torch.load('params/param_speech_2mix3lstm_best',map_location={'cuda:3':'cuda:0'}) for key in class_dict.keys(): if 'cnn' in key: class_dict.pop(key) mix_speech_classifier.load_state_dict(class_dict) # 底下四个是TDAA-basic最强版本 # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mixdotadjust4lstmdot_WSJ0_hidden3d_125',map_location={'cuda:1':'cuda:0'})) # 
mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mixdotadjust4lstmdot_WSJ0_emblayer_125',map_location={'cuda:1':'cuda:0'})) # att_speech_layer.load_state_dict(torch.load('params/param_mixdotadjust4lstmdot_WSJ0_attlayer_125',map_location={'cuda:1':'cuda:0'})) # adjust_layer.load_state_dict(torch.load('params/param_mixdotadjust4lstmdot_WSJ0_adjlayer_125',map_location={'cuda:1':'cuda:0'})) #加入dis-ss的结果 # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mixdotadjust4lstmdotdis_3434.13436424_hidden3d_395',map_location={'cuda:2':'cuda:0'})) # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mixdotadjust4lstmdotdis_3434.13436424_emblayer_395',map_location={'cuda:2':'cuda:0'})) # att_speech_layer.load_state_dict(torch.load('params/param_mixdotadjust4lstmdotdis_3434.13436424_attlayer_395',map_location={'cuda:2':'cuda:0'})) # adjust_layer.load_state_dict(torch.load('params/param_mixdotadjust4lstmdotdis_3434.13436424_adjlayer_395',map_location={'cuda:2':'cuda:0'})) #加入dis-sp的结果 mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mixdotadjust4lstmdotdissp_33401_hidden3d_185',map_location={'cuda:1':'cuda:0'})) mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mixdotadjust4lstmdotdissp_33401_emblayer_185',map_location={'cuda:1':'cuda:0'})) att_speech_layer.load_state_dict(torch.load('params/param_mixdotadjust4lstmdotdissp_33401_attlayer_185',map_location={'cuda:1':'cuda:0'})) adjust_layer.load_state_dict(torch.load('params/param_mixdotadjust4lstmdotdissp_33401_adjlayer_185',map_location={'cuda:1':'cuda:0'})) loss_func = torch.nn.MSELoss() # the target label is NOT an one-hotted loss_multi_func = torch.nn.MSELoss() # the target label is NOT an one-hotted # loss_multi_func = torch.nn.L1Loss() # the target label is NOT an one-hotted loss_dis_class=torch.nn.MSELoss() lrs.send({ 'title': 'TDAA classifier', 'batch_size':config.BATCH_SIZE, 'batch_total':batch_total, 'epoch_size':config.EPOCH_SIZE, 'loss 
func':loss_func.__str__(), 'initial lr':lr_data }) print '''Begin to calculate.''' for epoch_idx in range(1): if epoch_idx%10==0: for ee in optimizer.param_groups: if ee['lr']>=1e-7: ee['lr']/=2 lr_data=ee['lr'] lrs.send('lr',lr_data) if epoch_idx>0: print 'SDR_SUM (len:{}) for epoch {} : '.format(SDR_SUM.shape,epoch_idx-1,SDR_SUM.mean()) SDR_SUM=np.array([]) eval_data_gen=prepare_data('once','valid') # eval_data_gen=prepare_data('once','test') while 1 and True: print '\n' eval_data=eval_data_gen.next() if eval_data==False: break #如果这个epoch的生成器没有数据了,直接进入下一个epoch now_data=eval_data['mix_feas'] top_k_num=3 # while True: candidates=[] predict_multi_map=np.zeros([config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre],dtype=np.float32) for ____ in range(3): print 'Recu step:',____ out_this_step,spk_this_step=model_step_output(now_data,mix_speech_classifier,mix_hidden_layer_3d,\ mix_speech_multiEmbedding,adjust_layer,att_speech_layer,\ dict_spk2idx,dict_idx2spk,num_labels,mix_speech_len,speech_fre) out_this_step=out_this_step[0].data.cpu().numpy() predict_multi_map[0,____]=out_this_step now_data=now_data-out_this_step candidates.append(spk_this_step) y_multi_map=np.zeros([config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre],dtype=np.float32) batch_spk_multi_dict=eval_data['multi_spk_fea_list'] if test_mode: for iiii in range(config.BATCH_SIZE): y_multi_map[iiii]=np.array(batch_spk_multi_dict[iiii].values()) y_multi_map= Variable(torch.from_numpy(y_multi_map)).cuda() if 0: #这个是只利用推断出来的spk,回去做分离 print 'Recu only for spks.' top_mask=torch.zeros(num_labels) for jjj in candidates: top_mask[int(jjj[0])]=1 top_mask=top_mask.view(1,num_labels) try: ccc=eval_bss(top_mask,eval_data,mix_hidden_layer_3d,adjust_layer, mix_speech_classifier, mix_speech_multiEmbedding, att_speech_layer, loss_multi_func, dict_spk2idx, dict_idx2spk, num_labels, mix_speech_len, speech_fre) SDR_SUM = np.append(SDR_SUM, ccc) except: pass else: print 'Recu for spks and maps.' 
predict_multi_map=Variable(torch.from_numpy(predict_multi_map)).cuda() try: bss_eval(predict_multi_map,y_multi_map,2,dict_idx2spk,eval_data) SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output2/', 2)) except: pass if SDR_SUM[-3:].mean()>8: raw_input() print 'SDR_aver_now:',SDR_SUM.mean() print 'SDR_SUM (len:{}) for epoch eval : '.format(SDR_SUM.shape) print '#'*40
def eval_bss(candidates,eval_data,mix_hidden_layer_3d,adjust_layer,mix_speech_classifier,mix_speech_multiEmbedding,att_speech_layer, loss_multi_func,dict_spk2idx,dict_idx2spk,num_labels,mix_speech_len,speech_fre):
    """Run one BSS evaluation pass over `eval_data` using the speaker mask `candidates`.

    Builds attention masks for the selected speakers, forms predicted
    per-speaker spectrogram maps, prints MSE/sum losses, writes separation
    output via bss_eval(), and returns the metrics from bss_test.cal().

    NOTE(review): `candidates` appears to be a torch FloatTensor 0/1 mask of
    shape (batch, num_labels) -- inferred from the .numpy()/np.where usage
    below; confirm against callers.
    """
    for i in [mix_speech_multiEmbedding,adjust_layer,mix_speech_classifier,mix_hidden_layer_3d,att_speech_layer]:
        # NOTE(review): sets a project-specific `evaling` flag to False even
        # though this is an evaluation pass -- looks inverted; confirm semantics.
        i.evaling=False
    fea_now=eval_data['mix_feas']
    # The `while True` below executes exactly once: every path reaches the
    # unconditional `return` at the bottom of the body.
    while True:
        '''混合语音len,fre,Emb 3D表示层'''
        # 3-D (len, fre, emb) representation layer for the mixed speech.
        mix_speech_hidden,mix_tmp_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(fea_now)).cuda())
        # The video part is disabled for now because the video data for s2/s3/s4 is incomplete.
        '''Speech self Sepration 语音自分离部分'''
        # Speech self-separation part.
        # mix_speech_output=mix_speech_classifier(Variable(torch.from_numpy(fea_now)).cuda())

        if test_mode:
            # NOTE(review): rebinding the `num_labels` parameter here is
            # intentional-looking but surprising; `alpha0` is set on both
            # branches yet never used below (the top_k_mask call is commented out).
            num_labels=2
            alpha0=-0.5
        else:
            alpha0=0.5
        # top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=alpha0,top_k=num_labels) # torch.Float type
        top_k_mask_mixspeech=candidates # torch.Float type
        # Per-batch-row indices of the speakers switched on in the mask.
        top_k_mask_idx=[np.where(line==1)[0] for line in top_k_mask_mixspeech.numpy()]
        print 'Predict spk list:',print_spk_name(dict_idx2spk,top_k_mask_idx)
        mix_speech_multiEmbs=mix_speech_multiEmbedding(top_k_mask_mixspeech,top_k_mask_idx) # bs*num_labels (max mixed-speaker count) x embedding size
        # Adjust the speaker embeddings with the mixture's hidden state.
        mix_adjust=adjust_layer(mix_tmp_hidden,mix_speech_multiEmbs)
        mix_speech_multiEmbs=mix_adjust+mix_speech_multiEmbs
        # All batch rows must select the same number of speakers.
        assert len(top_k_mask_idx[0])==len(top_k_mask_idx[-1])
        top_k_num=len(top_k_mask_idx[0])

        # Attention between mix_speech_hidden [bs,len,fre,emb] and
        # mix_mulEmbedding [bs,num_labels,EMB]: expand the former to
        # bs*num_labels rows, flatten the latter likewise, run ATT, then
        # reshape the result back.
        mix_speech_hidden_5d=mix_speech_hidden.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
        mix_speech_hidden_5d=mix_speech_hidden_5d.expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre,config.EMBEDDING_SIZE).contiguous()
        mix_speech_hidden_5d_last=mix_speech_hidden_5d.view(-1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
        # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
        # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'dot').cuda()
        att_multi_speech=att_speech_layer(mix_speech_hidden_5d_last,mix_speech_multiEmbs.view(-1,config.EMBEDDING_SIZE))
        # print att_multi_speech.size()
        att_multi_speech=att_multi_speech.view(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre) # bs,num_labels,len,fre
        # print att_multi_speech.size()
        multi_mask=att_multi_speech
        # top_k_mask_mixspeech_multi=top_k_mask_mixspeech.view(config.BATCH_SIZE,top_k_num,1,1).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
        # multi_mask=multi_mask*Variable(top_k_mask_mixspeech_multi).cuda()

        # Apply the per-speaker masks to the (broadcast) input spectrogram.
        x_input_map=Variable(torch.from_numpy(fea_now)).cuda()
        # print x_input_map.size()
        x_input_map_multi=x_input_map.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
        # predict_multi_map=multi_mask*x_input_map_multi
        predict_multi_map=multi_mask*x_input_map_multi

        # Assemble ground-truth per-speaker maps, ordered by speaker index.
        y_multi_map=np.zeros([config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre],dtype=np.float32)
        batch_spk_multi_dict=eval_data['multi_spk_fea_list']
        if test_mode:
            for iiii in range(config.BATCH_SIZE):
                y_multi_map[iiii]=np.array(batch_spk_multi_dict[iiii].values())
        else:
            for idx,sample in enumerate(batch_spk_multi_dict):
                y_idx=sorted([dict_spk2idx[spk] for spk in sample.keys()])
                # The predicted speaker set must match the ground truth here
                # (the inner `if not test_mode` is redundant on this branch).
                if not test_mode:
                    assert y_idx==list(top_k_mask_idx[idx])
                for jdx,oo in enumerate(y_idx):
                    y_multi_map[idx,jdx]=sample[dict_idx2spk[oo]]
        y_multi_map= Variable(torch.from_numpy(y_multi_map)).cuda()

        loss_multi_speech=loss_multi_func(predict_multi_map,y_multi_map)

        # Loss term asking the channel masks to sum to one; should encourage
        # more distinct channels.
        y_sum_map=Variable(torch.ones(config.BATCH_SIZE,mix_speech_len,speech_fre)).cuda()
        predict_sum_map=torch.sum(multi_mask,1)
        loss_multi_sum_speech=loss_multi_func(predict_sum_map,y_sum_map)
        # loss_multi_speech=loss_multi_speech # TODO: compare the sum-to-one variant later; plain MSE already works quite well.
        print 'loss 1 eval, losssum eval : ',loss_multi_speech.data.cpu().numpy(),loss_multi_sum_speech.data.cpu().numpy()
        lrs.send('loss mask eval:',loss_multi_speech.data.cpu()[0])
        lrs.send('loss sum eval:',loss_multi_sum_speech.data.cpu()[0])
        loss_multi_speech=loss_multi_speech+0.5*loss_multi_sum_speech
        print 'evaling multi-abs norm this eval batch:',torch.abs(y_multi_map-predict_multi_map).norm().data.cpu().numpy()
        print 'loss:',loss_multi_speech.data.cpu().numpy()
        # Write separated outputs, then score them with the external BSS tool.
        bss_eval(predict_multi_map,y_multi_map,top_k_mask_idx,dict_idx2spk,eval_data)
        return bss_test.cal('batch_output2/',2)