Example #1
def main():
    print('go to model')
    print '*' * 80

    spk_global_gen = prepare_data(mode='global',
                                  train_or_test='train')  # a fake data generator, handy for building the model first
    global_para = spk_global_gen.next()
    print global_para
    spk_all_list, dict_spk2idx, dict_idx2spk, mix_speech_len, speech_fre, total_frames, spk_num_total, batch_total = global_para
    del spk_global_gen
    num_labels = len(spk_all_list)

    print 'Begin to build the main model for the Multi_Modal Cocktail Problem.'

    # This part is to build the 3D mix speech embedding maps.
    mix_hidden_layer_3d = MIX_SPEECH(speech_fre, mix_speech_len).cuda()
    mix_speech_classifier = MIX_SPEECH_classifier(speech_fre, mix_speech_len,
                                                  num_labels).cuda()
    mix_speech_multiEmbedding = SPEECH_EMBEDDING(
        num_labels, config.EMBEDDING_SIZE,
        spk_num_total + config.UNK_SPK_SUPP).cuda()
    print mix_hidden_layer_3d
    print mix_speech_classifier
    print mix_speech_multiEmbedding
    att_speech_layer = ATTENTION(config.EMBEDDING_SIZE, 'dot').cuda()
    adjust_layer = ADDJUST(2 * config.HIDDEN_UNITS, config.EMBEDDING_SIZE)
    print att_speech_layer
    print att_speech_layer.mode
    print adjust_layer
    lr_data = 0.0002
    optimizer = torch.optim.Adam([
        {
            'params': mix_hidden_layer_3d.parameters()
        },
        {
            'params': mix_speech_multiEmbedding.parameters()
        },
        {
            'params': mix_speech_classifier.parameters()
        },
        {
            'params': adjust_layer.parameters()
        },
        {
            'params': att_speech_layer.parameters()
        },
    ],
                                 lr=lr_data)
    if 0 and config.Load_param:
        mix_hidden_layer_3d.load_state_dict(
            torch.load('params/param_mix101_WSJ0_hidden3d_180'))
        mix_speech_multiEmbedding.load_state_dict(
            torch.load('params/param_mix101_WSJ0_emblayer_180'))
        att_speech_layer.load_state_dict(
            torch.load('params/param_mix101_WSJ0_attlayer_180'))
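        # NOTE: the line below loads adjust_layer from the same '_attlayer_' checkpoint
        # as att_speech_layer above; presumably a dedicated adjust-layer checkpoint was intended.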
        adjust_layer.load_state_dict(
            torch.load('params/param_mix101_WSJ0_attlayer_180'))
    loss_func = torch.nn.MSELoss()  # the target label is NOT one-hot encoded
    loss_multi_func = torch.nn.MSELoss()  # the target label is NOT one-hot encoded
    # loss_multi_func = torch.nn.L1Loss()  # the target label is NOT one-hot encoded
    loss_query_class = torch.nn.CrossEntropyLoss()

    print '''Begin to calculate.'''
    for epoch_idx in range(config.MAX_EPOCH):
        if 0 and epoch_idx % 50 == 0:
            for ee in optimizer.param_groups:
                if ee['lr'] >= 5e-6:
                    ee['lr'] /= 2
                lr_data = ee['lr']
            print 'now lr is :', lr_data
        if epoch_idx > 0:
            print 'SDR_SUM (len:{}) for epoch {} : {}'.format(
                SDR_SUM.shape, epoch_idx - 1, SDR_SUM.mean())
        SDR_SUM = np.array([])
        train_data_gen = prepare_data('once', 'train')
        # train_data_gen=prepare_data('once','test')
        batch_idx = 0
        while 1 and True:
            print '*' * 30, epoch_idx, batch_idx, '*' * 30
            train_data = train_data_gen.next()
            if train_data == False:
                break  # if the generator has no more data this epoch, go straight to the next epoch
            '''Mixed-speech 3D representation layer: len, fre, Emb'''
            mix_speech_hidden, mix_tmp_hidden = mix_hidden_layer_3d(
                Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # mix_tmp_hidden:[bs*T*hidden_units]
            # temporarily disable the video part, since the video data for s2/s3/s4 is incomplete for now
            '''Speech self-separation part'''
            mix_speech_output = mix_speech_classifier(
                Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # get the ground-truth speaker names and vectors from the data
            # y_spk_list=[one.keys() for one in train_data['multi_spk_fea_list']]
            y_spk_list = train_data['multi_spk_fea_list']
            y_spk_gtruth, y_map_gtruth = multi_label_vector(
                y_spk_list, dict_spk2idx)
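            # y_map_gtruth is presumably a multi-hot [bs, num_labels] speaker map;
            # when config.Ground_truth is set it replaces the classifier output below.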
            # if the ground-truth separation result is used as the supervision signal during training
            if config.Ground_truth:
                mix_speech_output = Variable(
                    torch.from_numpy(y_map_gtruth)).cuda()
                if test_all_outputchannel:  # set the input mask to all ones, useful for testing output on all channels
                    mix_speech_output = Variable(
                        torch.ones(
                            config.BATCH_SIZE,
                            num_labels,
                        ))
                    y_map_gtruth = np.ones([config.BATCH_SIZE, num_labels])

            top_k_mask_mixspeech = top_k_mask(mix_speech_output,
                                              alpha=0.5,
                                              top_k=num_labels)  # a torch.FloatTensor
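            # top_k_mask_mixspeech is a 0/1 FloatTensor over the speaker channels ([bs, num_labels]);
            # top_k_mask_idx below collects the selected speaker indices for each sample.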
            top_k_mask_idx = [
                np.where(line == 1)[0]
                for line in top_k_mask_mixspeech.numpy()
            ]
            mix_speech_multiEmbs = mix_speech_multiEmbedding(
                top_k_mask_mixspeech,
                top_k_mask_idx)  # bs * num_labels (max number of mixed speakers) x embedding size
            if config.is_SelfTune:
                mix_adjust = adjust_layer(mix_tmp_hidden, mix_speech_multiEmbs)
                mix_speech_multiEmbs = mix_adjust + mix_speech_multiEmbs

            assert len(top_k_mask_idx[0]) == len(top_k_mask_idx[-1])
            top_k_num = len(top_k_mask_idx[0])

            # Need the attention between mix_speech_hidden [bs,len,fre,emb] and mix_mulEmbedding [bs,num_labels,EMB]:
            # expand the former (and the latter) to bs*num_labels,..., run them through the ATT function, then reshape back.
            mix_speech_hidden_5d = mix_speech_hidden.view(
                config.BATCH_SIZE, 1, mix_speech_len, speech_fre,
                config.EMBEDDING_SIZE)
            mix_speech_hidden_5d = mix_speech_hidden_5d.expand(
                config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre,
                config.EMBEDDING_SIZE).contiguous()
            mix_speech_hidden_5d_last = mix_speech_hidden_5d.view(
                -1, mix_speech_len, speech_fre, config.EMBEDDING_SIZE)
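            # mix_speech_hidden_5d_last has shape (bs*top_k_num, len, fre, emb), so each selected
            # speaker embedding can attend over the whole spectrogram embedding map.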
            if not config.is_ComlexMask:
                mix_speech_multiEmbs = mix_speech_multiEmbs.view(
                    -1, config.EMBEDDING_SIZE)
            else:
                mix_speech_multiEmbs = mix_speech_multiEmbs.view(
                    -1, 2 * config.EMBEDDING_SIZE)
            att_multi_speech = att_speech_layer(mix_speech_hidden_5d_last,
                                                mix_speech_multiEmbs)
            print att_multi_speech.size()

            if not config.is_ComlexMask:
                att_multi_speech = att_multi_speech.view(
                    config.BATCH_SIZE, top_k_num, mix_speech_len,
                    speech_fre)  # bs, num_labels, len, fre
            else:
                att_multi_speech = att_multi_speech.view(
                    config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre,
                    2)  # bs, num_labels, len, fre, 2
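                # Presumably the inverse of the compressed complex-ratio-mask mapping
                # cRM = cRM_k * (1 - exp(-cRM_C * M)) / (1 + exp(-cRM_C * M)),
                # recovering the uncompressed mask M from the network output.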
                att_multi_speech = -1 / cRM_C * torch.log(
                    (cRM_k - att_multi_speech) / (cRM_k + att_multi_speech))

            multi_mask = att_multi_speech
            if not config.is_ComlexMask:
                x_input_map = Variable(torch.from_numpy(
                    train_data['mix_feas'])).cuda()
                # print x_input_map.size()
                x_input_map_multi = x_input_map.view(
                    config.BATCH_SIZE, 1, mix_speech_len,
                    speech_fre).expand(config.BATCH_SIZE, top_k_num,
                                       mix_speech_len, speech_fre)
                # predict_multi_map=multi_mask*x_input_map_multi
                predict_multi_map = multi_mask * x_input_map_multi

                y_multi_map = np.zeros(
                    [config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre],
                    dtype=np.float32)
                batch_spk_multi_dict = train_data['multi_spk_fea_list']
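                # Fill each selected channel with that speaker's ground-truth features,
                # keeping the same sorted speaker order as top_k_mask_idx (checked by the assert).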
                for idx, sample in enumerate(batch_spk_multi_dict):
                    y_idx = sorted(
                        [dict_spk2idx[spk] for spk in sample.keys()])
                    assert y_idx == list(top_k_mask_idx[idx])
                    for jdx, oo in enumerate(y_idx):
                        y_multi_map[idx, jdx] = sample[dict_idx2spk[oo]]
                y_multi_map = Variable(torch.from_numpy(y_multi_map)).cuda()

                loss_multi_speech = loss_multi_func(predict_multi_map,
                                                    y_multi_map)

                # loss term that pushes the per-channel masks to sum to 1, which should bring out more differences between channels
                y_sum_map = Variable(
                    torch.ones(config.BATCH_SIZE, mix_speech_len,
                               speech_fre)).cuda()
                predict_sum_map = torch.sum(multi_mask, 1)
                loss_multi_sum_speech = loss_multi_func(
                    predict_sum_map, y_sum_map)
                # loss_multi_speech=loss_multi_speech  # TODO: compare the effect of the sum-to-one term later; plain MSE already works quite well for now.
                print 'loss 1, losssum : ', loss_multi_speech.data.cpu().numpy(
                ), loss_multi_sum_speech.data.cpu().numpy()
                loss_multi_speech = loss_multi_speech + 0.5 * loss_multi_sum_speech
                print 'training multi-abs norm this batch:', torch.abs(
                    y_multi_map - predict_multi_map).norm().data.cpu().numpy()
                print 'loss:', loss_multi_speech.data.cpu().numpy()

            else:
                x_input_map = Variable(torch.from_numpy(
                    train_data['mix_mag'])).cuda()  # bs,len,fre,2
                x_input_map_multi = x_input_map.view(
                    config.BATCH_SIZE, 1, mix_speech_len, speech_fre,
                    2).expand(config.BATCH_SIZE, top_k_num, mix_speech_len,
                              speech_fre, 2)

                multi_mask_real = multi_mask[:, :, :, :, 0]
                multi_mask_fake = multi_mask[:, :, :, :, 1]
                x_input_map_real = x_input_map_multi[:, :, :, :, 0]
                x_input_map_fake = x_input_map_multi[:, :, :, :, 1]
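                # Complex mask applied as a complex multiplication:
                # (Mr + i*Mi)(Xr + i*Xi) = (Mr*Xr - Mi*Xi) + i*(Mr*Xi + Mi*Xr)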
                predict_map_real = multi_mask_real * x_input_map_real - multi_mask_fake * x_input_map_fake
                predict_map_fake = multi_mask_real * x_input_map_fake + multi_mask_fake * x_input_map_real

                y_multi_map = np.zeros([
                    config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre, 2
                ],
                                       dtype=np.float32)
                batch_spk_multi_dict = train_data['multi_spk_fea_list']
                for idx, sample in enumerate(batch_spk_multi_dict):
                    y_idx = sorted(
                        [dict_spk2idx[spk] for spk in sample.keys()])
                    assert y_idx == list(top_k_mask_idx[idx])
                    for jdx, oo in enumerate(y_idx):
                        y_multi_map[idx, jdx] = sample[dict_idx2spk[oo]]
                y_multi_map = Variable(torch.from_numpy(y_multi_map)).cuda()
                y_multi_map_real = y_multi_map[:, :, :, :, 0]
                y_multi_map_fake = y_multi_map[:, :, :, :, 1]

                loss_multi_speech_real = loss_multi_func(
                    predict_map_real, y_multi_map_real)
                loss_multi_speech_fake = loss_multi_func(
                    predict_map_fake, y_multi_map_fake)
                loss_multi_speech = loss_multi_speech_fake + loss_multi_speech_real

                print 'loss_real:', loss_multi_speech_real.data.cpu().numpy()
                print 'loss_fake:', loss_multi_speech_fake.data.cpu().numpy()
                print 'loss:', loss_multi_speech.data.cpu().numpy()

            optimizer.zero_grad()  # clear gradients for next train
            loss_multi_speech.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients
            batch_idx += 1

        if 1 and epoch_idx >= 10 and epoch_idx % 5 == 0:
            # torch.save(mix_speech_multiEmbedding.state_dict(),'params/param_mixalignag_{}_emblayer_{}'.format(config.DATASET,epoch_idx))
            # torch.save(mix_hidden_layer_3d.state_dict(),'params/param_mixalignag_{}_hidden3d_{}'.format(config.DATASET,epoch_idx))
            # torch.save(att_speech_layer.state_dict(),'params/param_mixalignag_{}_attlayer_{}'.format(config.DATASET,epoch_idx))
            torch.save(
                mix_speech_multiEmbedding.state_dict(),
                'params/param_mix{}ag_{}_emblayer_{}'.format(
                    att_speech_layer.mode, config.DATASET, epoch_idx))
            torch.save(
                mix_hidden_layer_3d.state_dict(),
                'params/param_mix{}ag_{}_hidden3d_{}'.format(
                    att_speech_layer.mode, config.DATASET, epoch_idx))
            torch.save(
                att_speech_layer.state_dict(),
                'params/param_mix{}ag_{}_attlayer_{}'.format(
                    att_speech_layer.mode, config.DATASET, epoch_idx))
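            # NOTE: adjust_layer is saved below with the same '_attlayer_' filename pattern as
            # att_speech_layer above, so it overwrites that checkpoint; presumably an
            # '_adjustlayer_' pattern was intended.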
            torch.save(
                adjust_layer.state_dict(),
                'params/param_mix{}ag_{}_attlayer_{}'.format(
                    att_speech_layer.mode, config.DATASET, epoch_idx))
        if 1 and epoch_idx % 3 == 0:
            eval_bss(mix_hidden_layer_3d, adjust_layer, mix_speech_classifier,
                     mix_speech_multiEmbedding, att_speech_layer,
                     loss_multi_func, dict_spk2idx, dict_idx2spk, num_labels,
                     mix_speech_len, speech_fre)
Example #2
def main():
    print('go to model')
    print '*' * 80

    spk_global_gen = prepare_data(mode='global',
                                  train_or_test='train')  # a fake data generator, handy for building the model first
    global_para = spk_global_gen.next()
    print global_para
    spk_all_list, dict_spk2idx, dict_idx2spk, mix_speech_len, speech_fre, total_frames, spk_num_total = global_para
    del spk_global_gen
    num_labels = len(spk_all_list)

    # data_generator=prepare_data('once','train')
    # data_generator=prepare_data_fake(train_or_test='train',num_labels=num_labels) #写一个假的数据生成,可以用来写模型先

    # the order here is mix_speechs.shape, mix_feas.shape, aim_fea.shape, aim_spkid.shape, query.shape
    # an example: (5, 17040) (5, 134, 129) (5, 134, 129) (5,) (5, 32, 400, 300, 3)
    # datasize=prepare_datasize(data_generator)
    # mix_speech_len,speech_fre,total_frames,spk_num_total,video_size=datasize
    print 'Begin to build the main model for the Multi_Modal Cocktail Problem.'
    # data=data_generator.next()

    # This part is to build the 3D mix speech embedding maps.
    mix_hidden_layer_3d = MIX_SPEECH(speech_fre, mix_speech_len).cuda()
    mix_speech_classifier = MIX_SPEECH_classifier(speech_fre, mix_speech_len,
                                                  num_labels).cuda()
    mix_speech_multiEmbedding = SPEECH_EMBEDDING(
        num_labels, config.EMBEDDING_SIZE,
        spk_num_total + config.UNK_SPK_SUPP).cuda()
    print mix_hidden_layer_3d
    print mix_speech_classifier
    # mix_speech_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(data[1])).cuda())

    # mix_speech_output=mix_speech_classifier(Variable(torch.from_numpy(data[1])).cuda())
    # trick: with alpha=0 this picks the top_k; with a very large top_k it picks everything above alpha
    # top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=config.ALPHA,top_k=config.MAX_MIX)
    # top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=config.ALPHA,top_k=3)
    # print top_k_mask_mixspeech
    # mix_speech_multiEmbs=mix_speech_multiEmbedding(top_k_mask_mixspeech) # bs*num_labels (max number of mixed speakers) x embedding size
    # mix_speech_multiEmbs=mix_speech_multiEmbedding(Variable(torch.from_numpy(top_k_mask_mixspeech),requires_grad=False).cuda()) # bs*num_labels (max number of mixed speakers) x embedding size

    # Need the attention between mix_speech_hidden [bs,len,fre,emb] and mix_mulEmbedding [bs,num_labels,EMB]:
    # expand the former (and the latter) to bs*num_labels,..., run them through the ATT function, then reshape back.
    # mix_speech_hidden_5d=mix_speech_hidden.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
    # mix_speech_hidden_5d=mix_speech_hidden_5d.expand(config.BATCH_SIZE,num_labels,mix_speech_len,speech_fre,config.EMBEDDING_SIZE).contiguous()
    # mix_speech_hidden_5d=mix_speech_hidden_5d.view(-1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
    # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
    # att_multi_speech=att_speech_layer(mix_speech_hidden_5d,mix_speech_multiEmbs.view(-1,config.EMBEDDING_SIZE))
    # print att_multi_speech.size()
    # att_multi_speech=att_multi_speech.view(config.BATCH_SIZE,num_labels,mix_speech_len,speech_fre,-1)
    # print att_multi_speech.size()

    # This part is to conduct the video inputs.
    # query_video_layer=VIDEO_QUERY(total_frames,config.VideoSize,spk_num_total).cuda()
    query_video_layer = None
    # print query_video_layer
    # query_video_output,xx=query_video_layer(Variable(torch.from_numpy(data[4])))

    # This part is to conduct the memory.
    # hidden_size=(config.HIDDEN_UNITS)
    hidden_size = (config.EMBEDDING_SIZE)
    # x=torch.arange(0,24).view(2,3,4)
    # y=torch.ones([2,4])
    att_layer = ATTENTION(config.EMBEDDING_SIZE, 'align').cuda()
    att_speech_layer = ATTENTION(config.EMBEDDING_SIZE, 'align').cuda()
    # att=ATTENTION(4,'align')
    # mask=att(x,y)#bs*max_len

    # del data_generator
    # del data

    optimizer = torch.optim.Adam(
        [
            {
                'params': mix_hidden_layer_3d.parameters()
            },
            {
                'params': mix_speech_multiEmbedding.parameters()
            },
            {
                'params': mix_speech_classifier.parameters()
            },
            # {'params':query_video_layer.lstm_layer.parameters()},
            # {'params':query_video_layer.dense.parameters()},
            # {'params':query_video_layer.Linear.parameters()},
            {
                'params': att_layer.parameters()
            },
            {
                'params': att_speech_layer.parameters()
            },
            # ], lr=0.02,momentum=0.9)
        ],
        lr=0.00005)
    if 1 and config.Load_param:
        # query_video_layer.load_state_dict(torch.load('param_video_layer_19'))
        # mix_speech_classifier.load_state_dict(torch.load('params/param_speech_123onezeroag3_WSJ0_multilabel_epoch40'))
        mix_speech_classifier.load_state_dict(
            torch.load(
                'params/param_speech_123onezeroag4_WSJ0_multilabel_epoch70'))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_WSJ0_hidden3d_180'))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_WSJ0_emblayer_180'))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix101_WSJ0_attlayer_180'))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_dbag1nosum_WSJ0_hidden3d_350',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_dbag1nosum_WSJ0_emblayer_350',map_location={'cuda:1':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix101_dbag1nosum_WSJ0_attlayer_350',map_location={'cuda:1':'cuda:0'}))

        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix2or3_db_WSJ0_hidden3d_560',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix2or3_db_WSJ0_emblayer_560',map_location={'cuda:1':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix2or3_db_WSJ0_attlayer_560',map_location={'cuda:1':'cuda:0'}))

        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_dbag2sum_WSJ0_hidden3d_460',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_dbag2sum_WSJ0_emblayer_460',map_location={'cuda:1':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix101_dbag2sum_WSJ0_attlayer_460',map_location={'cuda:1':'cuda:0'}))

        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_dbdropout_WSJ0_hidden3d_370',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_dbdropout_WSJ0_emblayer_370',map_location={'cuda:1':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix101_dbdropout_WSJ0_attlayer_370',map_location={'cuda:1':'cuda:0'}))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_4db_WSJ0_hidden3d_110',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_4db_WSJ0_emblayer_110',map_location={'cuda:1':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix101_4db_WSJ0_attlayer_110',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_classifier.load_state_dict(torch.load('params/param_speech_4lstm_multilabelloss30map_epoch440'))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix101_dbdropoutag_WSJ0_attlayer_220',map_location={'cuda:1':'cuda:0'}))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_dbdropoutag_WSJ0_hidden3d_220',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_dbdropoutag_WSJ0_emblayer_220',map_location={'cuda:1':'cuda:0'}))

        # att_speech_layer.load_state_dict(torch.load('params/param_mix2_db2dropout_WSJ0_attlayer_495',map_location={'cuda:1':'cuda:0'}))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix2_db2dropout_WSJ0_hidden3d_495',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix2_db2dropout_WSJ0_emblayer_495',map_location={'cuda:1':'cuda:0'}))
        att_speech_layer.load_state_dict(
            torch.load('params/param_mix2_db2dropout_WSJ0_attlayer_90',
                       map_location={'cuda:1': 'cuda:0'}))
        mix_hidden_layer_3d.load_state_dict(
            torch.load('params/param_mix2_db2dropout_WSJ0_hidden3d_90',
                       map_location={'cuda:1': 'cuda:0'}))
        mix_speech_multiEmbedding.load_state_dict(
            torch.load('params/param_mix2_db2dropout_WSJ0_emblayer_90',
                       map_location={'cuda:1': 'cuda:0'}))

        # att_speech_layer.load_state_dict(torch.load('params/param_mix1to3_dbdropoutag1_WSJ0_attlayer_430',map_location={'cuda:1':'cuda:0'}))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix1to3_dbdropoutag1_WSJ0_hidden3d_430',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix1to3_dbdropoutag1_WSJ0_emblayer_430',map_location={'cuda:1':'cuda:0'}))
    loss_func = torch.nn.MSELoss()  # the target label is NOT one-hot encoded
    loss_multi_func = torch.nn.MSELoss()  # the target label is NOT one-hot encoded
    # loss_multi_func = torch.nn.L1Loss()  # the target label is NOT one-hot encoded
    loss_query_class = torch.nn.CrossEntropyLoss()

    print '''Begin to calculate.'''
    for epoch_idx in range(config.MAX_EPOCH):
        if epoch_idx > 0:
            print 'SDR_SUM (len:{}) for epoch {} : {}'.format(
                SDR_SUM.shape, epoch_idx - 1, SDR_SUM.mean())
        SDR_SUM = np.array([])
        # print_memory_state(memory.memory)
        print 'SDR_SUM for epoch {}:{}'.format(epoch_idx - 1, SDR_SUM.mean())
        for batch_idx in range(config.EPOCH_SIZE):
            print '*' * 40, epoch_idx, batch_idx, '*' * 40
            train_data_gen = prepare_data('once', 'train')
            # train_data_gen=prepare_data('once','test')
            # train_data_gen=prepare_data('once','eval_test')
            train_data = train_data_gen.next()
            '''Mixed-speech 3D representation layer: len, fre, Emb'''
            mix_speech_hidden = mix_hidden_layer_3d(
                Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # temporarily disable the video part, since the video data for s2/s3/s4 is incomplete for now
            '''Speech self-separation part'''
            mix_speech_output = mix_speech_classifier(
                Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # get the ground-truth speaker names and vectors from the data
            y_spk_list = [
                one.keys() for one in train_data['multi_spk_fea_list']
            ]
            y_spk_list = train_data['multi_spk_fea_list']
            y_spk_gtruth, y_map_gtruth = multi_label_vector(
                y_spk_list, dict_spk2idx)
            # if the ground-truth separation result is used as the supervision signal during training
            if 1 and config.Ground_truth:
                mix_speech_output = Variable(
                    torch.from_numpy(y_map_gtruth)).cuda()
                if 0 and test_all_outputchannel:  # set the input mask to all ones, useful for testing output on all channels
                    mix_speech_output = Variable(
                        torch.ones(
                            config.BATCH_SIZE,
                            num_labels,
                        ))
                    y_map_gtruth = np.ones([config.BATCH_SIZE, num_labels])

            max_num_labels = 2
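            # with alpha=-0.5, below the classifier's output range, top_k_mask presumably always
            # keeps exactly max_num_labels channels per sample instead of thresholding.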
            top_k_mask_mixspeech, top_k_sort_index = top_k_mask(
                mix_speech_output, alpha=-0.5,
                top_k=max_num_labels)  # a torch.FloatTensor
            top_k_mask_idx = [
                np.where(line == 1)[0]
                for line in top_k_mask_mixspeech.numpy()
            ]
            mix_speech_multiEmbs = mix_speech_multiEmbedding(
                top_k_mask_mixspeech,
                top_k_mask_idx)  # bs * num_labels (max number of mixed speakers) x embedding size

            assert len(top_k_mask_idx[0]) == len(top_k_mask_idx[-1])
            top_k_num = len(top_k_mask_idx[0])

            # Need the attention between mix_speech_hidden [bs,len,fre,emb] and mix_mulEmbedding [bs,num_labels,EMB]:
            # expand the former (and the latter) to bs*num_labels,..., run them through the ATT function, then reshape back.
            mix_speech_hidden_5d = mix_speech_hidden.view(
                config.BATCH_SIZE, 1, mix_speech_len, speech_fre,
                config.EMBEDDING_SIZE)
            mix_speech_hidden_5d = mix_speech_hidden_5d.expand(
                config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre,
                config.EMBEDDING_SIZE).contiguous()
            mix_speech_hidden_5d_last = mix_speech_hidden_5d.view(
                -1, mix_speech_len, speech_fre, config.EMBEDDING_SIZE)
            # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
            att_speech_layer = ATTENTION(config.EMBEDDING_SIZE, 'dot').cuda()
            att_multi_speech = att_speech_layer(
                mix_speech_hidden_5d_last,
                mix_speech_multiEmbs.view(-1, config.EMBEDDING_SIZE))
            # print att_multi_speech.size()
            att_multi_speech = att_multi_speech.view(
                config.BATCH_SIZE, top_k_num, mix_speech_len,
                speech_fre)  # bs, num_labels, len, fre
            # print att_multi_speech.size()
            multi_mask = att_multi_speech
            # top_k_mask_mixspeech_multi=top_k_mask_mixspeech.view(config.BATCH_SIZE,top_k_num,1,1).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
            # multi_mask=multi_mask*Variable(top_k_mask_mixspeech_multi).cuda()

            x_input_map = Variable(torch.from_numpy(
                train_data['mix_feas'])).cuda()
            # print x_input_map.size()
            x_input_map_multi = x_input_map.view(
                config.BATCH_SIZE, 1, mix_speech_len,
                speech_fre).expand(config.BATCH_SIZE, top_k_num,
                                   mix_speech_len, speech_fre)
            # predict_multi_map=multi_mask*x_input_map_multi
            predict_multi_map = multi_mask * x_input_map_multi

            bss_eval_fromGenMap(multi_mask, x_input_map, top_k_mask_mixspeech,
                                dict_idx2spk, train_data, top_k_sort_index)
            SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output/', 2))
            print 'SDR_SUM (len:{}) for epoch {} : {}'.format(
                SDR_SUM.shape, epoch_idx, SDR_SUM.mean())
Example #3
def eval_bss(mix_hidden_layer_3d, adjust_layer, mix_speech_classifier,
             mix_speech_multiEmbedding, att_speech_layer, loss_multi_func,
             dict_spk2idx, dict_idx2spk, num_labels, mix_speech_len,
             speech_fre):
    for i in [
            mix_speech_multiEmbedding, adjust_layer, mix_speech_classifier,
            mix_hidden_layer_3d, att_speech_layer
    ]:
        i.training = False
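    # Note: this sets only each top-level module's .training flag; unlike .eval(), it does not
    # propagate to submodules, so their dropout/batch-norm behavior is unchanged.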
    print '#' * 40
    eval_data_gen = prepare_data('once', 'valid')
    SDR_SUM = np.array([])
    while True:
        print '\n\n'
        eval_data = eval_data_gen.next()
        if eval_data == False:
            break  # if the generator has no more data this epoch, go straight to the next epoch
        '''Mixed-speech 3D representation layer: len, fre, Emb'''
        mix_speech_hidden, mix_tmp_hidden = mix_hidden_layer_3d(
            Variable(torch.from_numpy(eval_data['mix_feas'])).cuda())
        # mix_tmp_hidden:[bs*T*hidden_units]
        # temporarily disable the video part, since the video data for s2/s3/s4 is incomplete for now
        '''Speech self-separation part'''
        mix_speech_output = mix_speech_classifier(
            Variable(torch.from_numpy(eval_data['mix_feas'])).cuda())
        # get the ground-truth speaker names and vectors from the data
        # y_spk_list=[one.keys() for one in eval_data['multi_spk_fea_list']]
        y_spk_list = eval_data['multi_spk_fea_list']
        y_spk_gtruth, y_map_gtruth = multi_label_vector(
            y_spk_list, dict_spk2idx)
        # if the ground-truth separation result is used as the supervision signal during training
        if config.Ground_truth:
            mix_speech_output = Variable(torch.from_numpy(y_map_gtruth)).cuda()
            if test_all_outputchannel:  # set the input mask to all ones, useful for testing output on all channels
                mix_speech_output = Variable(
                    torch.ones(
                        config.BATCH_SIZE,
                        num_labels,
                    ))
                y_map_gtruth = np.ones([config.BATCH_SIZE, num_labels])

        top_k_mask_mixspeech = top_k_mask(mix_speech_output,
                                          alpha=0.5,
                                          top_k=num_labels)  # a torch.FloatTensor
        top_k_mask_idx = [
            np.where(line == 1)[0] for line in top_k_mask_mixspeech.numpy()
        ]
        mix_speech_multiEmbs = mix_speech_multiEmbedding(
            top_k_mask_mixspeech,
            top_k_mask_idx)  # bs * num_labels (max number of mixed speakers) x embedding size
        if config.is_SelfTune:
            mix_adjust = adjust_layer(mix_tmp_hidden, mix_speech_multiEmbs)
            mix_speech_multiEmbs = mix_adjust + mix_speech_multiEmbs

        assert len(top_k_mask_idx[0]) == len(top_k_mask_idx[-1])
        top_k_num = len(top_k_mask_idx[0])

        # Need the attention between mix_speech_hidden [bs,len,fre,emb] and mix_mulEmbedding [bs,num_labels,EMB]:
        # expand the former (and the latter) to bs*num_labels,..., run them through the ATT function, then reshape back.
        mix_speech_hidden_5d = mix_speech_hidden.view(config.BATCH_SIZE, 1,
                                                      mix_speech_len,
                                                      speech_fre,
                                                      config.EMBEDDING_SIZE)
        mix_speech_hidden_5d = mix_speech_hidden_5d.expand(
            config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre,
            config.EMBEDDING_SIZE).contiguous()
        mix_speech_hidden_5d_last = mix_speech_hidden_5d.view(
            -1, mix_speech_len, speech_fre, config.EMBEDDING_SIZE)
        if not config.is_ComlexMask:
            mix_speech_multiEmbs = mix_speech_multiEmbs.view(
                -1, config.EMBEDDING_SIZE)
        else:
            mix_speech_multiEmbs = mix_speech_multiEmbs.view(
                -1, 2 * config.EMBEDDING_SIZE)
        att_multi_speech = att_speech_layer(mix_speech_hidden_5d_last,
                                            mix_speech_multiEmbs)
        print att_multi_speech.size()

        if not config.is_ComlexMask:
            att_multi_speech = att_multi_speech.view(
                config.BATCH_SIZE, top_k_num, mix_speech_len,
                speech_fre)  # bs, num_labels, len, fre
        else:
            att_multi_speech = att_multi_speech.view(
                config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre,
                2)  # bs, num_labels, len, fre, 2
            att_multi_speech = -1 / cRM_C * torch.log(
                (cRM_k - att_multi_speech) / (cRM_k + att_multi_speech))

        multi_mask = att_multi_speech
        if not config.is_ComlexMask:
            x_input_map = Variable(torch.from_numpy(
                eval_data['mix_feas'])).cuda()
            # print x_input_map.size()
            x_input_map_multi = x_input_map.view(
                config.BATCH_SIZE, 1, mix_speech_len,
                speech_fre).expand(config.BATCH_SIZE, top_k_num,
                                   mix_speech_len, speech_fre)
            # predict_multi_map=multi_mask*x_input_map_multi
            predict_multi_map = multi_mask * x_input_map_multi

            y_multi_map = np.zeros(
                [config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre],
                dtype=np.float32)
            batch_spk_multi_dict = eval_data['multi_spk_fea_list']
            for idx, sample in enumerate(batch_spk_multi_dict):
                y_idx = sorted([dict_spk2idx[spk] for spk in sample.keys()])
                assert y_idx == list(top_k_mask_idx[idx])
                for jdx, oo in enumerate(y_idx):
                    y_multi_map[idx, jdx] = sample[dict_idx2spk[oo]]
            y_multi_map = Variable(torch.from_numpy(y_multi_map)).cuda()

            loss_multi_speech = loss_multi_func(predict_multi_map, y_multi_map)

            # loss term that pushes the per-channel masks to sum to 1, which should bring out more differences between channels
            y_sum_map = Variable(
                torch.ones(config.BATCH_SIZE, mix_speech_len,
                           speech_fre)).cuda()
            predict_sum_map = torch.sum(multi_mask, 1)
            loss_multi_sum_speech = loss_multi_func(predict_sum_map, y_sum_map)
            # loss_multi_speech=loss_multi_speech  # TODO: compare the effect of the sum-to-one term later; plain MSE already works quite well for now.
            print 'loss 1, losssum : ', loss_multi_speech.data.cpu().numpy(
            ), loss_multi_sum_speech.data.cpu().numpy()
            loss_multi_speech = loss_multi_speech + 0.5 * loss_multi_sum_speech
            print 'training multi-abs norm this batch:', torch.abs(
                y_multi_map - predict_multi_map).norm().data.cpu().numpy()
            print 'loss:', loss_multi_speech.data.cpu().numpy()
            bss_eval(predict_multi_map, y_multi_map, top_k_mask_idx,
                     dict_idx2spk, eval_data)

        else:
            x_input_map = Variable(torch.from_numpy(
                eval_data['mix_mag'])).cuda()  # bs,len,fre,2
            x_input_map_multi = x_input_map.view(
                config.BATCH_SIZE, 1, mix_speech_len, speech_fre,
                2).expand(config.BATCH_SIZE, top_k_num, mix_speech_len,
                          speech_fre, 2)

            multi_mask_real = multi_mask[:, :, :, :, 0]
            multi_mask_fake = multi_mask[:, :, :, :, 1]
            x_input_map_real = x_input_map_multi[:, :, :, :, 0]
            x_input_map_fake = x_input_map_multi[:, :, :, :, 1]
            predict_map_real = multi_mask_real * x_input_map_real - multi_mask_fake * x_input_map_fake
            predict_map_fake = multi_mask_real * x_input_map_fake + multi_mask_fake * x_input_map_real

            y_multi_map = np.zeros(
                [config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre, 2],
                dtype=np.float32)
            batch_spk_multi_dict = eval_data['multi_spk_fea_list']
            for idx, sample in enumerate(batch_spk_multi_dict):
                y_idx = sorted([dict_spk2idx[spk] for spk in sample.keys()])
                assert y_idx == list(top_k_mask_idx[idx])
                for jdx, oo in enumerate(y_idx):
                    y_multi_map[idx, jdx] = sample[dict_idx2spk[oo]]
            y_multi_map = Variable(torch.from_numpy(y_multi_map)).cuda()
            y_multi_map_real = y_multi_map[:, :, :, :, 0]
            y_multi_map_fake = y_multi_map[:, :, :, :, 1]

            loss_multi_speech_real = loss_multi_func(predict_map_real,
                                                     y_multi_map_real)
            loss_multi_speech_fake = loss_multi_func(predict_map_fake,
                                                     y_multi_map_fake)
            loss_multi_speech = loss_multi_speech_fake + loss_multi_speech_real

            print 'loss_real:', loss_multi_speech_real.data.cpu().numpy()
            print 'loss_fake:', loss_multi_speech_fake.data.cpu().numpy()
            print 'loss:', loss_multi_speech.data.cpu().numpy()
            bss_eval_cRM(predict_map_real, predict_map_fake, y_multi_map,
                         top_k_mask_idx, dict_idx2spk, eval_data)

        SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output/', 2))
        print 'SDR_aver_now:', SDR_SUM.mean()

    SDR_aver = SDR_SUM.mean()
    print 'SDR_SUM (len:{}) for epoch eval : {}'.format(
        SDR_SUM.shape, SDR_aver)
    print '#' * 40
def main():
    print('go to model')
    print '*' * 80

    spk_global_gen=prepare_data(mode='global',train_or_test='train') # a fake data generator, handy for building the model first
    global_para=spk_global_gen.next()
    print global_para
    spk_all_list,dict_spk2idx,dict_idx2spk,mix_speech_len,speech_fre,total_frames,spk_num_total=global_para
    del spk_global_gen
    num_labels=len(spk_all_list)

    # data_generator=prepare_data('once','train')
    # data_generator=prepare_data_fake(train_or_test='train',num_labels=num_labels) #写一个假的数据生成,可以用来写模型先

    # the order here is mix_speechs.shape, mix_feas.shape, aim_fea.shape, aim_spkid.shape, query.shape
    # an example: (5, 17040) (5, 134, 129) (5, 134, 129) (5,) (5, 32, 400, 300, 3)
    # datasize=prepare_datasize(data_generator)
    # mix_speech_len,speech_fre,total_frames,spk_num_total,video_size=datasize
    print 'Begin to build the main model for the Multi_Modal Cocktail Problem.'
    # data=data_generator.next()

    # This part is to build the 3D mix speech embedding maps.
    mix_hidden_layer_3d=MIX_SPEECH(speech_fre,mix_speech_len).cuda()
    mix_speech_classifier=MIX_SPEECH_classifier(speech_fre,mix_speech_len,num_labels).cuda()
    mix_speech_multiEmbedding=SPEECH_EMBEDDING(num_labels,config.EMBEDDING_SIZE,spk_num_total+config.UNK_SPK_SUPP).cuda()
    print mix_hidden_layer_3d
    print mix_speech_classifier
    # mix_speech_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(data[1])).cuda())

    # mix_speech_output=mix_speech_classifier(Variable(torch.from_numpy(data[1])).cuda())
    # trick: with alpha=0 this picks the top_k; with a very large top_k it picks everything above alpha
    # top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=config.ALPHA,top_k=config.MAX_MIX)
    # top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=config.ALPHA,top_k=3)
    # print top_k_mask_mixspeech
    # mix_speech_multiEmbs=mix_speech_multiEmbedding(top_k_mask_mixspeech) # bs*num_labels (max number of mixed speakers) x embedding size
    # mix_speech_multiEmbs=mix_speech_multiEmbedding(Variable(torch.from_numpy(top_k_mask_mixspeech),requires_grad=False).cuda()) # bs*num_labels (max number of mixed speakers) x embedding size

    # Need the attention between mix_speech_hidden [bs,len,fre,emb] and mix_mulEmbedding [bs,num_labels,EMB]:
    # expand the former (and the latter) to bs*num_labels,..., run them through the ATT function, then reshape back.
    # mix_speech_hidden_5d=mix_speech_hidden.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
    # mix_speech_hidden_5d=mix_speech_hidden_5d.expand(config.BATCH_SIZE,num_labels,mix_speech_len,speech_fre,config.EMBEDDING_SIZE).contiguous()
    # mix_speech_hidden_5d=mix_speech_hidden_5d.view(-1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
    # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
    # att_multi_speech=att_speech_layer(mix_speech_hidden_5d,mix_speech_multiEmbs.view(-1,config.EMBEDDING_SIZE))
    # print att_multi_speech.size()
    # att_multi_speech=att_multi_speech.view(config.BATCH_SIZE,num_labels,mix_speech_len,speech_fre,-1)
    # print att_multi_speech.size()


    # This part is to conduct the video inputs.
    # query_video_layer=VIDEO_QUERY(total_frames,config.VideoSize,spk_num_total).cuda()
    query_video_layer=None
    # print query_video_layer
    # query_video_output,xx=query_video_layer(Variable(torch.from_numpy(data[4])))

    # This part is to conduct the memory.
    # hidden_size=(config.HIDDEN_UNITS)
    hidden_size=(config.EMBEDDING_SIZE)
    # x=torch.arange(0,24).view(2,3,4)
    # y=torch.ones([2,4])
    att_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
    att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
    # att=ATTENTION(4,'align')
    # mask=att(x,y)#bs*max_len

    # del data_generator
    # del data

    optimizer = torch.optim.Adam([{'params':mix_hidden_layer_3d.parameters()},
                                 {'params':mix_speech_multiEmbedding.parameters()},
                                 {'params':mix_speech_classifier.parameters()},
                                 # {'params':query_video_layer.lstm_layer.parameters()},
                                 # {'params':query_video_layer.dense.parameters()},
                                 # {'params':query_video_layer.Linear.parameters()},
                                 {'params':att_layer.parameters()},
                                 {'params':att_speech_layer.parameters()},
                                 # ], lr=0.02,momentum=0.9)
                                 ], lr=0.0002)
    if 0 and config.Load_param:
        # query_video_layer.load_state_dict(torch.load('param_video_layer_19'))
        # mix_speech_classifier.load_state_dict(torch.load('params/param_speech_multilabel_epoch249'))
        mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_WSJ0_hidden3d_180'))
        mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_WSJ0_emblayer_180'))
        att_speech_layer.load_state_dict(torch.load('params/param_mix101_WSJ0_attlayer_180'))
    loss_func = torch.nn.MSELoss()  # the target label is NOT one-hot encoded
    loss_multi_func = torch.nn.MSELoss()  # the target label is NOT one-hot encoded
    # loss_multi_func = torch.nn.L1Loss()  # the target label is NOT one-hot encoded
    loss_query_class=torch.nn.CrossEntropyLoss()

    print '''Begin to calculate.'''
    for epoch_idx in range(config.MAX_EPOCH):
        if epoch_idx%50==0:
            for ee in optimizer.param_groups:
                ee['lr']/=2
        if epoch_idx>0:
            print 'SDR_SUM (len:{}) for epoch {} : {}'.format(SDR_SUM.shape,epoch_idx-1,SDR_SUM.mean())
        SDR_SUM=np.array([])
        # print_memory_state(memory.memory)
        print 'SDR_SUM for epoch {}:{}'.format(epoch_idx - 1, SDR_SUM.mean())
        for batch_idx in range(config.EPOCH_SIZE):
            print '*' * 40,epoch_idx,batch_idx,'*'*40
            train_data_gen=prepare_data('once','train')
            # train_data_gen=prepare_data('once','test')
            train_data=train_data_gen.next()

            '''Mixed-speech 3D representation layer: len, fre, Emb'''
            mix_speech_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # temporarily disable the video part, since the video data for s2/s3/s4 is incomplete for now

            '''Speech self-separation part'''
            mix_speech_output=mix_speech_classifier(Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # get the ground-truth speaker names and vectors from the data
            # y_spk_list=[one.keys() for one in train_data['multi_spk_fea_list']]
            y_spk_list= train_data['multi_spk_fea_list']
            y_spk_gtruth,y_map_gtruth=multi_label_vector(y_spk_list,dict_spk2idx)
            # if the ground-truth separation result is used as the supervision signal during training
            if config.Ground_truth:
                mix_speech_output=Variable(torch.from_numpy(y_map_gtruth)).cuda()
                if test_all_outputchannel: # set the input mask to all ones, useful for testing output on all channels
                    mix_speech_output=Variable(torch.ones(config.BATCH_SIZE,num_labels,))
                    y_map_gtruth=np.ones([config.BATCH_SIZE,num_labels])

            top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=0.5,top_k=num_labels) # a torch.FloatTensor
            top_k_mask_idx=[np.where(line==1)[0] for line in top_k_mask_mixspeech.numpy()]
            mix_speech_multiEmbs=mix_speech_multiEmbedding(top_k_mask_mixspeech,top_k_mask_idx) # bs*num_labels (max number of mixed speakers) x embedding size

            assert len(top_k_mask_idx[0])==len(top_k_mask_idx[-1])
            top_k_num=len(top_k_mask_idx[0])

            # Need the attention between mix_speech_hidden [bs,len,fre,emb] and mix_mulEmbedding [bs,num_labels,EMB]:
            # expand the former (and the latter) to bs*num_labels,..., run them through the ATT function, then reshape back.
            mix_speech_hidden_5d=mix_speech_hidden.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
            mix_speech_hidden_5d=mix_speech_hidden_5d.expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre,config.EMBEDDING_SIZE).contiguous()
            mix_speech_hidden_5d_last=mix_speech_hidden_5d.view(-1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
            # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
            att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'dot').cuda()
            att_multi_speech=att_speech_layer(mix_speech_hidden_5d_last,mix_speech_multiEmbs.view(-1,config.EMBEDDING_SIZE))
            # print att_multi_speech.size()
            att_multi_speech=att_multi_speech.view(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre) # bs, num_labels, len, fre
            # print att_multi_speech.size()
            multi_mask=att_multi_speech
            # top_k_mask_mixspeech_multi=top_k_mask_mixspeech.view(config.BATCH_SIZE,top_k_num,1,1).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
            # multi_mask=multi_mask*Variable(top_k_mask_mixspeech_multi).cuda()

            x_input_map=Variable(torch.from_numpy(train_data['mix_feas'])).cuda()
            # print x_input_map.size()
            x_input_map_multi=x_input_map.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
            # predict_multi_map=multi_mask*x_input_map_multi
            predict_multi_map=multi_mask*x_input_map_multi
            if 0 and batch_idx%100==0:
                print multi_mask
            # print predict_multi_map

            y_multi_map=np.zeros([config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre],dtype=np.float32)
            batch_spk_multi_dict=train_data['multi_spk_fea_list']
            for idx,sample in enumerate(batch_spk_multi_dict):
                y_idx=sorted([dict_spk2idx[spk] for spk in sample.keys()])
                assert y_idx==list(top_k_mask_idx[idx])
                for jdx,oo in enumerate(y_idx):
                    y_multi_map[idx,jdx]=sample[dict_idx2spk[oo]]
            y_multi_map= Variable(torch.from_numpy(y_multi_map)).cuda()

            loss_multi_speech=loss_multi_func(predict_multi_map,y_multi_map)

            # loss term that pushes the per-channel masks to sum to 1, which should bring out more differences between channels
            y_sum_map=Variable(torch.ones(config.BATCH_SIZE,mix_speech_len,speech_fre)).cuda()
            predict_sum_map=torch.sum(multi_mask,1)
            loss_multi_sum_speech=loss_multi_func(predict_sum_map,y_sum_map)
            # loss_multi_speech=loss_multi_speech  # TODO: compare the effect of the sum-to-one term later; plain MSE already works quite well for now.
            print 'loss 1, losssum : ',loss_multi_speech.data.cpu().numpy(),loss_multi_sum_speech.data.cpu().numpy()
            loss_multi_speech=loss_multi_speech+0.5*loss_multi_sum_speech
            print 'training multi-abs norm this batch:',torch.abs(y_multi_map-predict_multi_map).norm().data.cpu().numpy()
            print 'loss:',loss_multi_speech.data.cpu().numpy()

            if 1 or batch_idx==config.EPOCH_SIZE-1:
                bss_eval(predict_multi_map,y_multi_map,top_k_mask_idx,dict_idx2spk,train_data)
                SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output/', 2))


            optimizer.zero_grad()   # clear gradients for next train
            loss_multi_speech.backward()         # backpropagation, compute gradients
            optimizer.step()        # apply gradients

            if 1 and epoch_idx>80 and epoch_idx%10==0 and batch_idx==config.EPOCH_SIZE-1:
                torch.save(mix_speech_multiEmbedding.state_dict(),'params/param_mix_{}_emblayer_{}'.format(config.DATASET,epoch_idx))
                torch.save(mix_hidden_layer_3d.state_dict(),'params/param_mix_{}_hidden3d_{}'.format(config.DATASET,epoch_idx))
                torch.save(att_speech_layer.state_dict(),'params/param_mix_{}_attlayer_{}'.format(config.DATASET,epoch_idx))
Example #5
def main():
    print('go to model')
    print '*' * 80

    spk_global_gen = prepare_data(mode='global',
                                  train_or_test='train')  # a fake data generator, handy for building the model first
    global_para = spk_global_gen.next()
    print global_para
    spk_all_list, dict_spk2idx, dict_idx2spk, mix_speech_len, speech_fre, total_frames, spk_num_total = global_para
    del spk_global_gen
    num_labels = len(spk_all_list)

    # data_generator=prepare_data('once','train')
    # data_generator=prepare_data_fake(train_or_test='train',num_labels=num_labels) #写一个假的数据生成,可以用来写模型先

    # the order here is mix_speechs.shape, mix_feas.shape, aim_fea.shape, aim_spkid.shape, query.shape
    # an example: (5, 17040) (5, 134, 129) (5, 134, 129) (5,) (5, 32, 400, 300, 3)
    # datasize=prepare_datasize(data_generator)
    # mix_speech_len,speech_fre,total_frames,spk_num_total,video_size=datasize
    print 'Begin to build the main model for the Multi_Modal Cocktail Problem.'
    # data=data_generator.next()

    # This part is to build the 3D mix speech embedding maps.
    mix_hidden_layer_3d = MIX_SPEECH(speech_fre, mix_speech_len).cuda()
    mix_speech_classifier = MIX_SPEECH_classifier(speech_fre, mix_speech_len,
                                                  num_labels).cuda()
    mix_speech_multiEmbedding = SPEECH_EMBEDDING(
        num_labels, config.EMBEDDING_SIZE,
        spk_num_total + config.UNK_SPK_SUPP).cuda()
    print mix_hidden_layer_3d
    print mix_speech_classifier
    # mix_speech_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(data[1])).cuda())

    # mix_speech_output=mix_speech_classifier(Variable(torch.from_numpy(data[1])).cuda())
    # trick: with alpha=0 this picks the top_k; with a very large top_k it picks everything above alpha
    # top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=config.ALPHA,top_k=config.MAX_MIX)
    # top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=config.ALPHA,top_k=3)
    # print top_k_mask_mixspeech
    # mix_speech_multiEmbs=mix_speech_multiEmbedding(top_k_mask_mixspeech) # bs*num_labels (max number of mixed speakers) x embedding size
    # mix_speech_multiEmbs=mix_speech_multiEmbedding(Variable(torch.from_numpy(top_k_mask_mixspeech),requires_grad=False).cuda()) # bs*num_labels (max number of mixed speakers) x embedding size

    # Need the attention between mix_speech_hidden [bs,len,fre,emb] and mix_mulEmbedding [bs,num_labels,EMB]:
    # expand the former (and the latter) to bs*num_labels,..., run them through the ATT function, then reshape back.
    # mix_speech_hidden_5d=mix_speech_hidden.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
    # mix_speech_hidden_5d=mix_speech_hidden_5d.expand(config.BATCH_SIZE,num_labels,mix_speech_len,speech_fre,config.EMBEDDING_SIZE).contiguous()
    # mix_speech_hidden_5d=mix_speech_hidden_5d.view(-1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
    # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
    # att_multi_speech=att_speech_layer(mix_speech_hidden_5d,mix_speech_multiEmbs.view(-1,config.EMBEDDING_SIZE))
    # print att_multi_speech.size()
    # att_multi_speech=att_multi_speech.view(config.BATCH_SIZE,num_labels,mix_speech_len,speech_fre,-1)
    # print att_multi_speech.size()

    # This part is to conduct the video inputs.
    query_video_layer = VIDEO_QUERY(total_frames, config.VideoSize,
                                    spk_num_total).cuda()
    # print query_video_layer
    # query_video_output,xx=query_video_layer(Variable(torch.from_numpy(data[4])))

    # This part is to conduct the memory.
    # hidden_size=(config.HIDDEN_UNITS)
    hidden_size = (config.EMBEDDING_SIZE)
    memory = MEMORY(spk_num_total + config.UNK_SPK_SUPP, hidden_size)
    memory.register_spklist(spk_all_list)  # register spk_list into the (initially empty) memory

    # Memory function test.
    print 'memory all spkid:', memory.get_all_spkid()
    # print memory.get_image_num('Unknown_id')
    # print memory.get_video_vector('Unknown_id')
    # print memory.add_video('Unknown_id',Variable(torch.ones(300)))

    # This part is to test the ATTENTION methond from query(~) to mix_speech
    # x=torch.arange(0,24).view(2,3,4)
    # y=torch.ones([2,4])
    att_layer = ATTENTION(config.EMBEDDING_SIZE, 'align').cuda()
    att_speech_layer = ATTENTION(config.EMBEDDING_SIZE, 'align').cuda()
    # att=ATTENTION(4,'align')
    # mask=att(x,y)#bs*max_len

    # del data_generator
    # del data

    optimizer = torch.optim.Adam(
        [
            {
                'params': mix_hidden_layer_3d.parameters()
            },
            {
                'params': mix_speech_multiEmbedding.parameters()
            },
            {
                'params': mix_speech_classifier.parameters()
            },
            # {'params':query_video_layer.lstm_layer.parameters()},
            # {'params':query_video_layer.dense.parameters()},
            # {'params':query_video_layer.Linear.parameters()},
            {
                'params': att_layer.parameters()
            },
            {
                'params': att_speech_layer.parameters()
            },
            # ], lr=0.02,momentum=0.9)
        ],
        lr=0.0002)
    if 0 and config.Load_param:
        # query_video_layer.load_state_dict(torch.load('param_video_layer_19'))
        mix_speech_classifier.load_state_dict(
            torch.load('params/param_speech_multilabel_epoch249'))
    loss_func = torch.nn.MSELoss()  # the target label is NOT one-hot encoded
    loss_multi_func = torch.nn.MSELoss()  # the target label is NOT one-hot encoded
    # loss_multi_func = torch.nn.L1Loss()  # the target label is NOT one-hot encoded
    loss_query_class = torch.nn.CrossEntropyLoss()

    print '''Begin to calculate.'''
    for epoch_idx in range(config.MAX_EPOCH):
        print_memory_state(memory.memory)
        for batch_idx in range(config.EPOCH_SIZE):
            print '*' * 40, epoch_idx, batch_idx, '*' * 40
            train_data_gen = prepare_data('once', 'train')
            train_data = train_data_gen.next()
            '''Mixed-speech 3D representation layer: len, fre, Emb'''
            mix_speech_hidden = mix_hidden_layer_3d(
                Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # temporarily disable the video part, since the video data for s2/s3/s4 is incomplete for now
            '''Speech self-separation part'''
            mix_speech_output = mix_speech_classifier(
                Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # get the ground-truth speaker names and vectors from the data
            y_spk_list = [
                one.keys() for one in train_data['multi_spk_fea_list']
            ]
            y_spk_gtruth, y_map_gtruth = multi_label_vector(
                y_spk_list, dict_spk2idx)
            # if the ground-truth separation result is used as the supervision signal during training
            if config.Ground_truth:
                mix_speech_output = Variable(
                    torch.from_numpy(y_map_gtruth)).cuda()

            top_k_mask_mixspeech = top_k_mask(mix_speech_output,
                                              alpha=0.5,
                                              top_k=num_labels)  # a torch.FloatTensor
            mix_speech_multiEmbs = mix_speech_multiEmbedding(
                top_k_mask_mixspeech)  # bs * num_labels (max number of mixed speakers) x embedding size

            # Need the attention between mix_speech_hidden [bs,len,fre,emb] and mix_mulEmbedding [bs,num_labels,EMB]:
            # expand the former (and the latter) to bs*num_labels,..., run them through the ATT function, then reshape back.
            mix_speech_hidden_5d = mix_speech_hidden.view(
                config.BATCH_SIZE, 1, mix_speech_len, speech_fre,
                config.EMBEDDING_SIZE)
            mix_speech_hidden_5d = mix_speech_hidden_5d.expand(
                config.BATCH_SIZE, num_labels, mix_speech_len, speech_fre,
                config.EMBEDDING_SIZE).contiguous()
            mix_speech_hidden_5d_last = mix_speech_hidden_5d.view(
                -1, mix_speech_len, speech_fre, config.EMBEDDING_SIZE)
            # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
            att_speech_layer = ATTENTION(config.EMBEDDING_SIZE, 'dot').cuda()
            att_multi_speech = att_speech_layer(
                mix_speech_hidden_5d_last,
                mix_speech_multiEmbs.view(-1, config.EMBEDDING_SIZE))
            # print att_multi_speech.size()
            att_multi_speech = att_multi_speech.view(
                config.BATCH_SIZE, num_labels, mix_speech_len,
                speech_fre)  # bs, num_labels, len, fre
            # print att_multi_speech.size()
            multi_mask = att_multi_speech
            top_k_mask_mixspeech_multi = top_k_mask_mixspeech.view(
                config.BATCH_SIZE, num_labels, 1,
                1).expand(config.BATCH_SIZE, num_labels, mix_speech_len,
                          speech_fre)
            multi_mask = multi_mask * Variable(
                top_k_mask_mixspeech_multi).cuda()
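            # Multiplying by the broadcast top-k mask zeroes the estimated masks of speakers the
            # classifier did not select, so only active channels contribute to the loss below.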

            x_input_map = Variable(torch.from_numpy(
                train_data['mix_feas'])).cuda()
            # print x_input_map.size()
            x_input_map_multi = x_input_map.view(
                config.BATCH_SIZE, 1, mix_speech_len,
                speech_fre).expand(config.BATCH_SIZE, num_labels,
                                   mix_speech_len, speech_fre)
            predict_multi_map = multi_mask * x_input_map_multi
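            # Each per-speaker spectrogram estimate is the predicted mask applied to the mixture features.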
            if batch_idx % 100 == 0:
                print multi_mask
            # print predict_multi_map

            y_multi_map = np.zeros(
                [config.BATCH_SIZE, num_labels, mix_speech_len, speech_fre],
                dtype=np.float32)
            batch_spk_multi_dict = train_data['multi_spk_fea_list']
            for idx, sample in enumerate(batch_spk_multi_dict):
                for spk in sample.keys():
                    y_multi_map[idx, dict_spk2idx[spk]] = sample[spk]
            y_multi_map = Variable(torch.from_numpy(y_multi_map)).cuda()
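            # The target tensor holds each reference speaker's spectrogram in the channel given by its
            # index in dict_spk2idx; channels of absent speakers stay zero.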

            loss_multi_speech = loss_multi_func(predict_multi_map, y_multi_map)

            # Loss term that pushes the per-speaker masks to sum to one at every time-frequency bin; it should encourage more separation between channels
            y_sum_map = Variable(
                torch.ones(config.BATCH_SIZE, mix_speech_len,
                           speech_fre)).cuda()
            predict_sum_map = torch.sum(predict_multi_map, 1)
            loss_multi_sum_speech = loss_multi_func(predict_sum_map, y_sum_map)
            loss_multi_speech = loss_multi_speech  # TODO: compare against the sum-to-one constraint later; plain MSE already works quite well for now.
            # loss_multi_speech=loss_multi_speech+0.5*loss_multi_sum_speech

            if batch_idx == config.EPOCH_SIZE - 1:
                bss_eval(predict_multi_map, y_multi_map, y_map_gtruth,
                         dict_idx2spk, train_data)

            print 'training multi-abs norm this batch:', torch.abs(
                y_multi_map - predict_multi_map).norm().data.cpu().numpy()
            print 'loss:', loss_multi_speech.data.cpu().numpy()
            optimizer.zero_grad()  # clear gradients for next train
            loss_multi_speech.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients

            if 1 and epoch_idx > 20 and epoch_idx % 10 == 0 and batch_idx == config.EPOCH_SIZE - 1:
                torch.save(
                    mix_speech_multiEmbedding.state_dict(),
                    'params/param_mix_speech_emblayer_{}'.format(epoch_idx))
                torch.save(
                    mix_hidden_layer_3d.state_dict(),
                    'params/param_mix_speech_hidden3d_{}'.format(epoch_idx))
                torch.save(
                    att_speech_layer.state_dict(),
                    'params/param_mix_speech_attlayer_{}'.format(epoch_idx))

            # print 'Parameter history:'
            # for pa_gen in [{'params':mix_hidden_layer_3d.parameters()},
            #                                  {'params':mix_speech_multiEmbedding.parameters()},
            #                                  {'params':mix_hidden_layer_3d.parameters()},
            #                                  {'params':att_speech_layer.parameters()},
            #                                  {'params':att_layer.parameters()},
            #                                  {'params':mix_speech_classifier.parameters()},
            #                                  ]:
            #     print pa_gen['params'].next().data.cpu().numpy()[0]

            continue  # the video-stimulus branch below is currently disabled and never reached
            1 / 0
            '''Video-stimulus separation part'''
            # try:
            #     query_video_output,query_video_hidden=query_video_layer(Variable(torch.from_numpy(train_data[4])).cuda())
            # except RuntimeError:
            #     print 'RuntimeError here.'+'#'*30
            #     continue

            query_video_output, query_video_hidden = query_video_layer(
                Variable(torch.from_numpy(train_data[4])).cuda())
            if config.Comm_with_Memory:
                #TODO: double-check the query update below; it should be refactored into a function, it is a bit ugly for now.
                aim_idx_FromVideoQuery = torch.max(query_video_output,
                                                   dim=1)[1]  # argmax indices
                aim_spk_batch = [
                    dict_idx2spk[int(idx.data.cpu().numpy())]
                    for idx in aim_idx_FromVideoQuery
                ]
                print 'Query class result:', aim_spk_batch, 'p:', query_video_output.data.cpu(
                ).numpy()

                for idx, aim_spk in enumerate(aim_spk_batch):
                    batch_vector = torch.stack(
                        [memory.get_video_vector(aim_spk)])
                    memory.add_video(aim_spk, query_video_hidden[idx])
                query_video_hidden = query_video_hidden + Variable(
                    batch_vector)
                query_video_hidden = query_video_hidden / torch.sum(
                    query_video_hidden * query_video_hidden, 0)
                y_class = Variable(torch.from_numpy(
                    np.array([
                        dict_spk2idx[spk] for spk in train_data['aim_spkname']
                    ])),
                                   requires_grad=False).cuda()
                print y_class
                loss_video_class = loss_query_class(query_video_output,
                                                    y_class)

            mask = att_layer(mix_speech_hidden,
                             query_video_hidden)  #bs*max_len*fre

            predict_map = mask * Variable(
                torch.from_numpy(train_data['mix_feas'])).cuda()
            y_map = Variable(torch.from_numpy(train_data['aim_fea'])).cuda()
            print 'training abs norm this batch:', torch.abs(
                y_map - predict_map).norm().data.cpu().numpy()
            loss_all = loss_func(predict_map, y_map)
            if 0 and config.Save_param:
                torch.save(query_video_layer.state_dict(),
                           'param_video_layer_19_forS1S5')

            if 0 and epoch_idx < 20:
                loss = loss_video_class
                if epoch_idx % 1 == 0 and batch_idx == config.EPOCH_SIZE - 1:
                    torch.save(query_video_layer.state_dict(),
                               'param_video_layer_19_forS1S5')
            else:
                # loss=loss_all+0.1*loss_video_class
                loss = loss_all
            optimizer.zero_grad()  # clear gradients for next train
            loss.backward(
                retain_graph=True)  # backpropagation, compute gradients
            optimizer.step()  # apply gradients
def main():
    print('go to model')
    print '*' * 80

    spk_global_gen=prepare_data(mode='global',train_or_test='train') # a fake data generator, useful for prototyping the model first
    global_para=spk_global_gen.next()
    print global_para
    spk_all_list,dict_spk2idx,dict_idx2spk,mix_speech_len,speech_fre,total_frames,spk_num_total,batch_total=global_para
    del spk_global_gen
    num_labels=len(spk_all_list)

    print 'Begin to build the main model for Multi_Modal Cocktail Problem.'

    # This part is to build the 3D mix speech embedding maps.
    mix_hidden_layer_3d=MIX_SPEECH(speech_fre,mix_speech_len).cuda()
    mix_speech_classifier=MIX_SPEECH_classifier(speech_fre,mix_speech_len,num_labels).cuda()
    mix_speech_multiEmbedding=SPEECH_EMBEDDING(num_labels,config.EMBEDDING_SIZE,spk_num_total+config.UNK_SPK_SUPP).cuda()
    print mix_hidden_layer_3d
    print mix_speech_classifier
    print mix_speech_multiEmbedding
    att_layer=ATTENTION(config.EMBEDDING_SIZE,'dot').cuda()
    att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'dot').cuda()
    adjust_layer=ADDJUST(2*config.HIDDEN_UNITS,config.EMBEDDING_SIZE)
    dis_layer=Discriminator().cuda()
    print att_speech_layer
    print att_speech_layer.mode
    print adjust_layer
    print dis_layer
    lr_data=0.0002
    optimizer = torch.optim.Adam([{'params':mix_hidden_layer_3d.parameters()},
                                 {'params':mix_speech_multiEmbedding.parameters()},
                                 {'params':mix_speech_classifier.parameters()},
                                 {'params':adjust_layer.parameters()},
                                 {'params':att_speech_layer.parameters()},
                                 {'params':dis_layer.parameters()},
                                 ], lr=lr_data)
    if 1 and config.Load_param:
        class_dict=torch.load('params/param_speech_2mix3lstm_best',map_location={'cuda:3':'cuda:0'})
        for key in class_dict.keys():
            if 'cnn' in key:
                class_dict.pop(key)
        mix_speech_classifier.load_state_dict(class_dict)
        mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mixdotadjust4lstmdot_WSJ0_hidden3d_125',map_location={'cuda:1':'cuda:0'}))
        mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mixdotadjust4lstmdot_WSJ0_emblayer_125',map_location={'cuda:1':'cuda:0'}))
        att_speech_layer.load_state_dict(torch.load('params/param_mixdotadjust4lstmdot_WSJ0_attlayer_125',map_location={'cuda:1':'cuda:0'}))
        adjust_layer.load_state_dict(torch.load('params/param_mixdotadjust4lstmdot_WSJ0_adjlayer_125',map_location={'cuda:1':'cuda:0'}))
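        # map_location remaps tensors that were saved from other GPUs (cuda:1 / cuda:3) onto cuda:0,
        # so the checkpoints can be restored on a single-GPU machine.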
    loss_func = torch.nn.MSELoss()  # the target label is NOT one-hot encoded
    loss_multi_func = torch.nn.MSELoss()  # the target label is NOT one-hot encoded
    # loss_multi_func = torch.nn.L1Loss()  # the target label is NOT one-hot encoded
    loss_dis_class=torch.nn.MSELoss()

    lrs.send({
        'title': 'TDAA classifier',
        'batch_size':config.BATCH_SIZE,
        'batch_total':batch_total,
        'epoch_size':config.EPOCH_SIZE,
        'loss func':loss_func.__str__(),
        'initial lr':lr_data
    })

    print '''Begin to calculate.'''
    for epoch_idx in range(config.MAX_EPOCH):
        if epoch_idx%10==0:
            for ee in optimizer.param_groups:
                if ee['lr']>=1e-7:
                    ee['lr']/=2
                lr_data=ee['lr']
        lrs.send('lr',lr_data)
        if epoch_idx>0:
            print 'SDR_SUM (len:{}) for epoch {} : {}'.format(SDR_SUM.shape,epoch_idx-1,SDR_SUM.mean())
        SDR_SUM=np.array([])
        train_data_gen=prepare_data('once','train')
        # train_data_gen=prepare_data('once','test')
        while True:
            train_data=train_data_gen.next()
            if train_data==False:
                break # if the generator has no data left for this epoch, move straight on to the next epoch
            '''3D (len, fre, Emb) representation layer for the mixed speech'''
            mix_speech_hidden,mix_tmp_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # mix_tmp_hidden:[bs*T*hidden_units]
            # Temporarily disable the video part, because the video data for s2/s3/s4 is incomplete for now

            '''Speech self-separation part'''
            mix_speech_output=mix_speech_classifier(Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # Get the ground-truth speaker names and label vector from the data
            # y_spk_list=[one.keys() for one in train_data['multi_spk_fea_list']]
            y_spk_list= train_data['multi_spk_fea_list']
            y_spk_gtruth,y_map_gtruth=multi_label_vector(y_spk_list,dict_spk2idx)
            # During training, optionally use the ground-truth speaker labels instead of the classifier output
            if config.Ground_truth:
                mix_speech_output=Variable(torch.from_numpy(y_map_gtruth)).cuda()
                if test_all_outputchannel: # set the input mask to all ones, to test outputting every channel
                    mix_speech_output=Variable(torch.ones(config.BATCH_SIZE,num_labels,))
                    y_map_gtruth=np.ones([config.BATCH_SIZE,num_labels])

            top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=0.5,top_k=num_labels) # torch.FloatTensor
            top_k_mask_idx=[np.where(line==1)[0] for line in top_k_mask_mixspeech.numpy()]
            mix_speech_multiEmbs=mix_speech_multiEmbedding(top_k_mask_mixspeech,top_k_mask_idx) # bs * num_labels (max number of mixed speakers) x embedding size
            mix_adjust=adjust_layer(mix_tmp_hidden,mix_speech_multiEmbs)
            mix_speech_multiEmbs=mix_adjust+mix_speech_multiEmbs
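            # adjust_layer appears to produce a residual correction to the speaker embeddings,
            # conditioned on the mixture's temporal hidden state, and is added back onto them.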

            assert len(top_k_mask_idx[0])==len(top_k_mask_idx[-1])
            top_k_num=len(top_k_mask_idx[0])

            # Compute the attention between mix_speech_hidden [bs, len, fre, emb] and mix_speech_multiEmbs [bs, num_labels, EMB]:
            # expand both to a bs*num_labels leading dimension, run the ATT function, then reshape back.
            mix_speech_hidden_5d=mix_speech_hidden.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
            mix_speech_hidden_5d=mix_speech_hidden_5d.expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre,config.EMBEDDING_SIZE).contiguous()
            mix_speech_hidden_5d_last=mix_speech_hidden_5d.view(-1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
            att_multi_speech=att_speech_layer(mix_speech_hidden_5d_last,mix_speech_multiEmbs.view(-1,config.EMBEDDING_SIZE))
            print att_multi_speech.size()
            att_multi_speech=att_multi_speech.view(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre) # bs, top_k_num, len, fre
            # print att_multi_speech.size()
            multi_mask=att_multi_speech
            # top_k_mask_mixspeech_multi=top_k_mask_mixspeech.view(config.BATCH_SIZE,top_k_num,1,1).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
            # multi_mask=multi_mask*Variable(top_k_mask_mixspeech_multi).cuda()

            x_input_map=Variable(torch.from_numpy(train_data['mix_feas'])).cuda()
            # print x_input_map.size()
            x_input_map_multi=x_input_map.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
            # predict_multi_map=multi_mask*x_input_map_multi
            predict_multi_map=multi_mask*x_input_map_multi

            y_multi_map=np.zeros([config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre],dtype=np.float32)
            batch_spk_multi_dict=train_data['multi_spk_fea_list']
            for idx,sample in enumerate(batch_spk_multi_dict):
                y_idx=sorted([dict_spk2idx[spk] for spk in sample.keys()])
                assert y_idx==list(top_k_mask_idx[idx])
                for jdx,oo in enumerate(y_idx):
                    y_multi_map[idx,jdx]=sample[dict_idx2spk[oo]]
            y_multi_map= Variable(torch.from_numpy(y_multi_map)).cuda()

            loss_multi_speech=loss_multi_func(predict_multi_map,y_multi_map)

            score_true=dis_layer(y_multi_map)
            score_false=dis_layer(predict_multi_map)
            acc_true=torch.sum(score_true>0.5).data.cpu().numpy()/float(score_true.size()[0])
            acc_false=torch.sum(score_false<0.5).data.cpu().numpy()/float(score_true.size()[0])
            acc_dis=(acc_false+acc_true)/2
            print 'acc for dis:(true,false,aver)',acc_true,acc_false,acc_dis

            loss_dis_true=loss_dis_class(score_true,Variable(torch.ones(config.BATCH_SIZE*top_k_num,1)).cuda())
            loss_dis_false=loss_dis_class(score_false,Variable(torch.zeros(config.BATCH_SIZE*top_k_num,1)).cuda())
            loss_dis=loss_dis_true+loss_dis_false
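            # Discriminator update (least-squares, LSGAN-style): ground-truth maps are pushed toward a
            # score of 1 and the current predictions toward 0, using the MSE criterion defined above.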
            print 'loss for dis:(true,false)',loss_dis_true.data.cpu().numpy(),loss_dis_false.data.cpu().numpy()
            optimizer.zero_grad()   # clear gradients for next train
            loss_dis.backward(retain_graph=True)         # backpropagation, compute gradients
            optimizer.step()        # apply gradients

            # Loss term that pushes the per-speaker masks to sum to one at every time-frequency bin; it should encourage more separation between channels
            y_sum_map=Variable(torch.ones(config.BATCH_SIZE,mix_speech_len,speech_fre)).cuda()
            predict_sum_map=torch.sum(multi_mask,1)
            loss_multi_sum_speech=loss_multi_func(predict_sum_map,y_sum_map)
            # loss_multi_speech=loss_multi_speech # TODO: compare against the sum-to-one constraint later; plain MSE already works quite well.
            print 'loss 1, losssum : ',loss_multi_speech.data.cpu().numpy(),loss_multi_sum_speech.data.cpu().numpy()
            lrs.send('loss mask:',loss_multi_speech.data.cpu()[0])
            lrs.send('loss sum:',loss_multi_sum_speech.data.cpu()[0])
            loss_multi_speech=loss_multi_speech+0.5*loss_multi_sum_speech
            print 'training multi-abs norm this batch:',torch.abs(y_multi_map-predict_multi_map).norm().data.cpu().numpy()
            print 'loss:',loss_multi_speech.data.cpu().numpy()

            loss_dis_false=loss_dis_class(score_false,Variable(torch.ones(config.BATCH_SIZE*top_k_num,1)).cuda())
            loss_multi_speech=loss_multi_speech+loss_dis_false
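            # Adversarial term for the separator: the same fake scores are now pushed toward 1, so the
            # masks are trained to fool the discriminator in addition to matching the MSE targets.
            # (retain_graph=True on the discriminator backward above keeps the graph alive for this second backward.)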

            optimizer.zero_grad()   # clear gradients for next train
            loss_multi_speech.backward()         # backpropagation, compute gradients
            optimizer.step()        # apply gradients

        if 1 and epoch_idx >= 10 and epoch_idx % 5 == 0:
            # torch.save(mix_speech_multiEmbedding.state_dict(),'params/param_mixalignag_{}_emblayer_{}'.format(config.DATASET,epoch_idx))
            # torch.save(mix_hidden_layer_3d.state_dict(),'params/param_mixalignag_{}_hidden3d_{}'.format(config.DATASET,epoch_idx))
            # torch.save(att_speech_layer.state_dict(),'params/param_mixalignag_{}_attlayer_{}'.format(config.DATASET,epoch_idx))
            torch.save(mix_speech_multiEmbedding.state_dict(),
                       'params/param_mix{}ag_{}_emblayer_{}'.format(att_speech_layer.mode, config.DATASET, epoch_idx))
            torch.save(mix_hidden_layer_3d.state_dict(),
                       'params/param_mix{}ag_{}_hidden3d_{}'.format(att_speech_layer.mode, config.DATASET, epoch_idx))
            torch.save(att_speech_layer.state_dict(),
                       'params/param_mix{}ag_{}_attlayer_{}'.format(att_speech_layer.mode, config.DATASET, epoch_idx))
            torch.save(adjust_layer.state_dict(),
                       'params/param_mix{}ag_{}_adjlayer_{}'.format(att_speech_layer.mode, config.DATASET, epoch_idx))
            torch.save(dis_layer.state_dict(),
                       'params/param_mix{}ag_{}_dislayer_{}'.format(att_speech_layer.mode, config.DATASET, epoch_idx))
        if 1 and epoch_idx % 3 == 0:
            eval_bss(mix_hidden_layer_3d,adjust_layer, mix_speech_classifier, mix_speech_multiEmbedding, att_speech_layer,
                     loss_multi_func, dict_spk2idx, dict_idx2spk, num_labels, mix_speech_len, speech_fre)
def eval_bss(mix_hidden_layer_3d,adjust_layer,mix_speech_classifier,mix_speech_multiEmbedding,att_speech_layer,
             loss_multi_func,dict_spk2idx,dict_idx2spk,num_labels,mix_speech_len,speech_fre):
    for i in [mix_speech_multiEmbedding,adjust_layer,mix_speech_classifier,mix_hidden_layer_3d,att_speech_layer]:
        i.eval()  # use eval() so the flag propagates to all submodules, not just the top-level module
    print '#' * 40
    eval_data_gen=prepare_data('once','valid')
    SDR_SUM=np.array([])
    while True:
        print '\n\n'
        eval_data=eval_data_gen.next()
        if eval_data==False:
            break # the evaluation generator is exhausted; stop this evaluation pass
        '''3D (len, fre, Emb) representation layer for the mixed speech'''
        mix_speech_hidden,mix_tmp_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(eval_data['mix_feas'])).cuda())
        # Temporarily disable the video part, because the video data for s2/s3/s4 is incomplete for now

        '''Speech self-separation part'''
        mix_speech_output=mix_speech_classifier(Variable(torch.from_numpy(eval_data['mix_feas'])).cuda())

        if not test_mode:
            y_spk_list= eval_data['multi_spk_fea_list']
            y_spk_gtruth,y_map_gtruth=multi_label_vector(y_spk_list,dict_spk2idx)
            # During evaluation, optionally use the ground-truth speaker labels instead of the classifier output
            if not test_mode and config.Ground_truth:
                mix_speech_output=Variable(torch.from_numpy(y_map_gtruth)).cuda()
                if test_all_outputchannel: # set the input mask to all ones, to test outputting every channel
                    mix_speech_output=Variable(torch.ones(config.BATCH_SIZE,num_labels,))
                    y_map_gtruth=np.ones([config.BATCH_SIZE,num_labels])

        if test_mode:
            num_labels=2
            alpha0=-0.5
        else:
            alpha0=0.5
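        # In test mode the number of output channels is fixed to 2 and the threshold is lowered to -0.5,
        # which presumably makes top_k_mask keep exactly the two highest-scoring speakers.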
        top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=alpha0,top_k=num_labels) # torch.FloatTensor
        top_k_mask_idx=[np.where(line==1)[0] for line in top_k_mask_mixspeech.numpy()]
        print 'Predict spk list:',print_spk_name(dict_idx2spk,top_k_mask_idx)
        mix_speech_multiEmbs=mix_speech_multiEmbedding(top_k_mask_mixspeech,top_k_mask_idx) # bs * num_labels (max number of mixed speakers) x embedding size
        mix_adjust=adjust_layer(mix_tmp_hidden,mix_speech_multiEmbs)
        mix_speech_multiEmbs=mix_adjust+mix_speech_multiEmbs

        assert len(top_k_mask_idx[0])==len(top_k_mask_idx[-1])
        top_k_num=len(top_k_mask_idx[0])

        # Compute the attention between mix_speech_hidden [bs, len, fre, emb] and mix_speech_multiEmbs [bs, num_labels, EMB]:
        # expand both to a bs*num_labels leading dimension, run the ATT function, then reshape back.
        mix_speech_hidden_5d=mix_speech_hidden.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
        mix_speech_hidden_5d=mix_speech_hidden_5d.expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre,config.EMBEDDING_SIZE).contiguous()
        mix_speech_hidden_5d_last=mix_speech_hidden_5d.view(-1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
        # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
        # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'dot').cuda()
        att_multi_speech=att_speech_layer(mix_speech_hidden_5d_last,mix_speech_multiEmbs.view(-1,config.EMBEDDING_SIZE))
        # print att_multi_speech.size()
        att_multi_speech=att_multi_speech.view(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre) # bs, top_k_num, len, fre
        # print att_multi_speech.size()
        multi_mask=att_multi_speech
        # top_k_mask_mixspeech_multi=top_k_mask_mixspeech.view(config.BATCH_SIZE,top_k_num,1,1).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
        # multi_mask=multi_mask*Variable(top_k_mask_mixspeech_multi).cuda()

        x_input_map=Variable(torch.from_numpy(eval_data['mix_feas'])).cuda()
        # print x_input_map.size()
        x_input_map_multi=x_input_map.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
        # predict_multi_map=multi_mask*x_input_map_multi
        predict_multi_map=multi_mask*x_input_map_multi


        y_multi_map=np.zeros([config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre],dtype=np.float32)
        batch_spk_multi_dict=eval_data['multi_spk_fea_list']
        if test_mode:
            for iiii in range(config.BATCH_SIZE):
                y_multi_map[iiii]=np.array(batch_spk_multi_dict[iiii].values())
        else:
            for idx,sample in enumerate(batch_spk_multi_dict):
                y_idx=sorted([dict_spk2idx[spk] for spk in sample.keys()])
                if not test_mode:
                    assert y_idx==list(top_k_mask_idx[idx])
                for jdx,oo in enumerate(y_idx):
                    y_multi_map[idx,jdx]=sample[dict_idx2spk[oo]]
        y_multi_map= Variable(torch.from_numpy(y_multi_map)).cuda()

        loss_multi_speech=loss_multi_func(predict_multi_map,y_multi_map)

        # Loss term that pushes the per-speaker masks to sum to one at every time-frequency bin; it should encourage more separation between channels
        y_sum_map=Variable(torch.ones(config.BATCH_SIZE,mix_speech_len,speech_fre)).cuda()
        predict_sum_map=torch.sum(multi_mask,1)
        loss_multi_sum_speech=loss_multi_func(predict_sum_map,y_sum_map)
        # loss_multi_speech=loss_multi_speech # TODO: compare against the sum-to-one constraint later; plain MSE already works quite well.
        print 'loss 1 eval, losssum eval : ',loss_multi_speech.data.cpu().numpy(),loss_multi_sum_speech.data.cpu().numpy()
        lrs.send('loss mask eval:',loss_multi_speech.data.cpu()[0])
        lrs.send('loss sum eval:',loss_multi_sum_speech.data.cpu()[0])
        loss_multi_speech=loss_multi_speech+0.5*loss_multi_sum_speech
        print 'evaluating multi-abs norm this eval batch:',torch.abs(y_multi_map-predict_multi_map).norm().data.cpu().numpy()
        print 'loss:',loss_multi_speech.data.cpu().numpy()
        bss_eval(predict_multi_map,y_multi_map,top_k_mask_idx,dict_idx2spk,eval_data)
        SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output/', 2))
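        # bss_test.cal presumably reads the separated outputs written under 'batch_output/' and returns
        # SDR scores for the 2-speaker mixtures; they are accumulated here for the running average.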
        print 'SDR_aver_now:',SDR_SUM.mean()

    SDR_aver=SDR_SUM.mean()
    print 'SDR_SUM (len:{}) for epoch eval : {}'.format(SDR_SUM.shape,SDR_aver)
    lrs.send('SDR eval aver',SDR_aver)
    print '#'*40
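
# --- Illustrative sketch (not part of the original pipeline) ---
# A minimal guess at what the top_k_mask helper used above could look like, assuming it
# thresholds the classifier scores at `alpha` and keeps at most `top_k` speakers per sample.
# The real top_k_mask may differ; this sketch only documents the assumed behaviour.
def top_k_mask_sketch(scores, alpha=0.5, top_k=2):
    import numpy as np
    import torch
    # accept either an old-style Variable/tensor or a plain numpy array
    scores_np = scores.data.cpu().numpy() if hasattr(scores, 'data') else np.asarray(scores)
    mask = np.zeros_like(scores_np, dtype=np.float32)
    for b in range(scores_np.shape[0]):
        # candidate speakers whose score clears the threshold
        cand = np.where(scores_np[b] > alpha)[0]
        # keep at most top_k of them, ordered by descending score
        keep = cand[np.argsort(-scores_np[b][cand])][:top_k]
        mask[b, keep] = 1.0
    return torch.from_numpy(mask)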