def eval_bss(model, loss_multi_func, mix_speech_len, speech_fre):
    model.eval()  # eval() propagates to submodules; assigning model.training directly would not
    print '#' * 40
    eval_data_gen = prepare_data('once', 'valid')
    SDR_SUM = np.array([])
    while True:
        print '\n\n'
        eval_data = eval_data_gen.next()
        if eval_data is False:
            break  # the generator for this epoch is exhausted; move on to the next epoch

        top_k_num = eval_data['top_k']  # top-k for this batch
        print 'top-k this batch:', top_k_num

        mix_speech_orignial = Variable(torch.from_numpy(
            eval_data['mix_feas'])).cuda()
        mix_speech = torch.transpose(mix_speech_orignial, 1, 3)
        mix_speech = torch.transpose(mix_speech, 2, 3)
        # (2L, 301L, 257L, 2L) >>> (2L, 2L, 301L, 257L)
        print 'mix_speech_shape:', mix_speech.size()

        images_query = Variable(
            torch.from_numpy(
                convert2numpy(eval_data['multi_video_fea_list'],
                              top_k_num))).cuda()  # shape: bs, topk, 75, 3, 299, 299
        y_map = convert2numpy(eval_data['multi_spk_fea_list'],
                              top_k_num)  # the final target map
        print 'final map shape:', y_map.shape
        predict_multi_masks = model(mix_speech, images_query)
        print 'predict results shape:', predict_multi_masks.size(
        )  #(2L, topk, 301L, 257L, 2L)

        mix_speech_multi=mix_speech_orignial.view(config.BATCH_SIZE,1,speech_fre,mix_speech_len,2) \
            .expand(config.BATCH_SIZE,top_k_num,speech_fre,mix_speech_len,2)
        # (2L, 301L, 257L, 2L) >> (2L, topk,301L, 257L, 2L)

        predict_multi_masks_real = predict_multi_masks[:, :, :, :, 0]
        predict_multi_masks_fake = predict_multi_masks[:, :, :, :, 1]
        mix_speech_real = mix_speech_multi[:, :, :, :, 0]
        mix_speech_fake = mix_speech_multi[:, :, :, :, 1]
        y_map_real = Variable(torch.from_numpy(y_map[:, :, :, :, 0])).cuda()
        y_map_fake = Variable(torch.from_numpy(y_map[:, :, :, :, 1])).cuda()

        predict_real = predict_multi_masks_real * mix_speech_real - predict_multi_masks_fake * mix_speech_fake
        predict_fake = predict_multi_masks_real * mix_speech_fake + predict_multi_masks_fake * mix_speech_real
        print 'predict real/fake size:', predict_real.size()

        loss_real = loss_multi_func(predict_real, y_map_real)  #/top_k_num
        loss_fake = loss_multi_func(predict_fake, y_map_fake)  #/top_k_num
        loss_all = loss_real + loss_fake
        print 'loss:', loss_real.data[0], loss_fake.data[0]
        bss_eval_cRM(predict_real, predict_fake, y_map, eval_data)

        SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output23234/', 2))
        print 'SDR_aver_now:', SDR_SUM.mean()

    SDR_aver = SDR_SUM.mean()
    print 'SDR_SUM (len:{}) for epoch eval : {}'.format(
        SDR_SUM.shape, SDR_aver)
    print '#' * 40
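# ---------------------------------------------------------------------------
# A minimal standalone sketch (toy tensors, not the project's own code) of the
# complex-mask arithmetic used above: with the last dim holding (real, imag),
# applying a complex ratio mask M to a mixture X is the complex product M * X:
#   real = Mr*Xr - Mi*Xi,   imag = Mr*Xi + Mi*Xr.
import torch

mask = torch.randn(2, 3, 301, 257, 2)  # (bs, topk, T, F, real/imag)
mix = torch.randn(2, 3, 301, 257, 2)
pred_real = mask[..., 0] * mix[..., 0] - mask[..., 1] * mix[..., 1]
pred_imag = mask[..., 0] * mix[..., 1] + mask[..., 1] * mix[..., 0]

# the same result via torch.view_as_complex (available in recent PyTorch):
pred_c = torch.view_as_complex(mask) * torch.view_as_complex(mix)
assert torch.allclose(pred_c.real, pred_real)
assert torch.allclose(pred_c.imag, pred_imag)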
Example 2
def eval(epoch):
    # config.batch_size=1
    model.eval()
    # print('\n\nPlease set batch_size to 1 in config when testing!')
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    # test_or_valid = 'test'
    test_or_valid = 'valid'
    # test_or_valid = 'train'
    print(('Test or valid:', test_or_valid))
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX, config.MAX_MIX)
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    batch_idx = 0
    global best_SDR, Var
    # for iii in range(2000):
    while True:
        print(('-' * 30))
        eval_data = next(eval_data_gen)
        if eval_data == False:
            print(('SDR_aver_eval_epoch:', SDR_SUM.mean()))
            print(('SDRi_aver_eval_epoch:', SDRi_SUM.mean()))
            break  # the generator for this epoch is exhausted; move on to the next epoch
        src = Variable(torch.from_numpy(eval_data['mix_feas']))

        # raw_tgt = [sorted(spk.keys()) for spk in eval_data['multi_spk_fea_list']]
        raw_tgt = eval_data['batch_order']
        feas_tgt = models.rank_feas(raw_tgt, eval_data['multi_spk_fea_list'])  # target spectrograms

        top_k = len(raw_tgt[0])
        # make sure the tensors below are all LongTensors
        # tgt = Variable(torch.from_numpy(np.array([[0]+[dict_spk2idx[spk] for spk in spks]+[dict_spk2idx['<EOS>']] for spks in raw_tgt],dtype=np.int))).transpose(0,1) # map to indices, with BOS/EOS added around them.
        tgt = Variable(torch.from_numpy(np.array([[0,1,2,102] for __ in range(config.batch_size)], dtype=np.int))).transpose(0, 1)  # map to indices, with BOS/EOS added around them.

        src_len = Variable(torch.LongTensor(config.batch_size).zero_() + mix_speech_len).unsqueeze(0)
        tgt_len = Variable(torch.LongTensor([len(one_spk) for one_spk in eval_data['multi_spk_fea_list']])).unsqueeze(0)
        # tgt_len = Variable(torch.LongTensor(config.batch_size).zero_()+len(eval_data['multi_spk_fea_list'][0])).unsqueeze(0)
        if config.WFM:
            tmp_size = feas_tgt.size()
            assert len(tmp_size) == 3
            feas_tgt_square = feas_tgt * feas_tgt
            feas_tgt_sum_square = torch.sum(feas_tgt_square, dim=0, keepdim=True).expand(tmp_size)
            WFM_mask = feas_tgt_square / (feas_tgt_sum_square + 1e-15)
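            # WFM is a Wiener-filter-like magnitude mask: for each source,
            # mask = |S|^2 / (sum over sources of |S|^2 + eps), so the masks
            # sum to (approximately) one at every time-frequency bin.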

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
            if config.WFM:
                WFM_mask = WFM_mask.cuda()

        # try:
        if 1 and len(opt.gpus) > 1:
            samples,  predicted_masks = model.module.pit_sample(src, src_len, dict_spk2idx, tgt,
                                                                                    beam_size=config.beam_size)
        else:
            samples,  predicted_masks = model.pit_sample(src, src_len, dict_spk2idx, tgt,
                                                                             beam_size=config.beam_size)
        samples=samples.max(2)[1].data.cpu().numpy()
        # except:
        #     continue

        # '''
        # expand the raw mixed-features to topk_max channel.
        src = src.transpose(0, 1)
        siz = src.size()
        assert len(siz) == 3
        # if samples[0][-1] != dict_spk2idx['<EOS>']:
        #     print '*'*40+'\nThe model is far from good. End the evaluation.\n'+'*'*40
        #     break
        topk_max = len(samples[0]) - 1
        x_input_map_multi = torch.unsqueeze(src, 1).expand(siz[0], topk_max, siz[1], siz[2])
        if 1 and config.WFM:
            feas_tgt = x_input_map_multi.data * WFM_mask

        if test_or_valid != 'test':
            if 1 and len(opt.gpus) > 1:
                ss_loss = model.module.separation_loss(x_input_map_multi, predicted_masks, feas_tgt, )
            else:
                ss_loss = model.separation_loss(x_input_map_multi, predicted_masks, feas_tgt)
            print(('loss for ss,this batch:', ss_loss.cpu().item()))
            lera.log({
                'ss_loss_' + test_or_valid: ss_loss.cpu().item(),
            })
            del ss_loss

        # '''''
        if 1 and batch_idx <= (500 / config.batch_size):  # only the first batches count toward the SDR
            predicted_maps = predicted_masks * x_input_map_multi
            # predicted_maps=Variable(feas_tgt)
            utils.bss_eval2(config, predicted_maps, eval_data['multi_spk_fea_list'], raw_tgt, eval_data,
                            dst='batch_output_test')
            del predicted_maps, predicted_masks, x_input_map_multi
            try:
                sdr_aver_batch, sdri_aver_batch = bss_test.cal('batch_output_test/')
                SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
                SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
            except AssertionError:
                print('Errors in calculating the SDR')
            print(('SDR_aver_now:', SDR_SUM.mean()))
            print(('SDRi_aver_now:', SDRi_SUM.mean()))
            lera.log({'SDR sample'+test_or_valid: SDR_SUM.mean()})
            lera.log({'SDRi sample'+test_or_valid: SDRi_SUM.mean()})
            writer.add_scalars('scalar/loss',{'SDR_sample_'+test_or_valid:sdr_aver_batch},updates)
            # raw_input('Press any key to continue......')
        elif batch_idx == (200 / config.batch_size) + 1 and SDR_SUM.mean() > best_SDR:  # only record the best SDR once.
            print(('Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean())))
            best_SDR = SDR_SUM.mean()
            # save_model(log_path+'checkpoint_bestSDR{}.pt'.format(best_SDR))

        # '''
        candidate += [convertToLabels(dict_idx2spk, s, dict_spk2idx['<EOS>']) for s in samples]
        # source += raw_src
        reference += raw_tgt
        print(('samples:', samples))
        print(('can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:], reference[-1 * config.batch_size:])))
        # alignments += [align for align in alignment]
        batch_idx += 1

        result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path)
        print(('hamming_loss: %.8f | micro_f1: %.4f |recall: %.4f | precision: %.4f'
                   % (result['hamming_loss'], result['micro_f1'], result['micro_recall'], result['micro_precision'], )))

    score = {}
    result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path)
    logging_csv([e, updates, result['hamming_loss'], \
                 result['micro_f1'], result['micro_precision'], result['micro_recall'],SDR_SUM.mean()])
    print(('hamming_loss: %.8f | micro_f1: %.4f'
          % (result['hamming_loss'], result['micro_f1'])))
    score['hamming_loss'] = result['hamming_loss']
    score['micro_f1'] = result['micro_f1']
    1/0  # intentional crash: stop evaluation here (the return below is never reached)
    return score
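# ---------------------------------------------------------------------------
# A rough standalone equivalent (an assumption, not the project's
# utils.eval_metrics) of the multi-label metrics reported above, using
# scikit-learn on 0/1 indicator arrays; the real function works on speaker
# label sequences instead.
import numpy as np
from sklearn.metrics import hamming_loss, f1_score, precision_score, recall_score

def multilabel_metrics_sketch(y_true, y_pred):
    # y_true, y_pred: (n_samples, n_labels) binary indicator arrays
    return {
        'hamming_loss': hamming_loss(y_true, y_pred),
        'micro_f1': f1_score(y_true, y_pred, average='micro'),
        'micro_precision': precision_score(y_true, y_pred, average='micro'),
        'micro_recall': recall_score(y_true, y_pred, average='micro'),
    }

# toy check:
# multilabel_metrics_sketch(np.array([[1, 0, 1]]), np.array([[1, 0, 0]]))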
Example 3
def eval_bss(mix_hidden_layer_3d, adjust_layer, mix_speech_classifier,
             mix_speech_multiEmbedding, att_speech_layer, loss_multi_func,
             dict_spk2idx, dict_idx2spk, num_labels, mix_speech_len,
             speech_fre):
    for i in [
            mix_speech_multiEmbedding, adjust_layer, mix_speech_classifier,
            mix_hidden_layer_3d, att_speech_layer
    ]:
        i.eval()  # eval() propagates to submodules; assigning .training directly would not
    print '#' * 40
    eval_data_gen = prepare_data('once', 'valid')
    SDR_SUM = np.array([])
    while True:
        print '\n\n'
        eval_data = eval_data_gen.next()
        if eval_data is False:
            break  # the generator for this epoch is exhausted; move on to the next epoch
        '''3D (len, fre, Emb) representation layer for the mixed speech'''
        mix_speech_hidden, mix_tmp_hidden = mix_hidden_layer_3d(
            Variable(torch.from_numpy(eval_data['mix_feas'])).cuda())
        # mix_tmp_hidden:[bs*T*hidden_units]
        # temporarily disable the video part, since the video data for s2/s3/s4 is incomplete for now
        '''Speech self-separation part'''
        mix_speech_output = mix_speech_classifier(
            Variable(torch.from_numpy(eval_data['mix_feas'])).cuda())
        # get the ground-truth speaker names and vectors from the data
        # y_spk_list=[one.keys() for one in eval_data['multi_spk_fea_list']]
        y_spk_list = eval_data['multi_spk_fea_list']
        y_spk_gtruth, y_map_gtruth = multi_label_vector(
            y_spk_list, dict_spk2idx)
        # use the ground-truth separation result as the classifier output during training
        if config.Ground_truth:
            mix_speech_output = Variable(torch.from_numpy(y_map_gtruth)).cuda()
            if test_all_outputchannel:  # set the input mask to all ones, to test outputting every channel
                mix_speech_output = Variable(
                    torch.ones(
                        config.BATCH_SIZE,
                        num_labels,
                    ))
                y_map_gtruth = np.ones([config.BATCH_SIZE, num_labels])

        top_k_mask_mixspeech = top_k_mask(mix_speech_output,
                                          alpha=0.5,
                                          top_k=num_labels)  # a torch.FloatTensor
        top_k_mask_idx = [
            np.where(line == 1)[0] for line in top_k_mask_mixspeech.numpy()
        ]
        mix_speech_multiEmbs = mix_speech_multiEmbedding(
            top_k_mask_mixspeech,
            top_k_mask_idx)  # bs x num_labels (max number of mixed speakers) x embedding size
        if config.is_SelfTune:
            mix_adjust = adjust_layer(mix_tmp_hidden, mix_speech_multiEmbs)
            mix_speech_multiEmbs = mix_adjust + mix_speech_multiEmbs

        assert len(top_k_mask_idx[0]) == len(top_k_mask_idx[-1])
        top_k_num = len(top_k_mask_idx[0])

        # compute the attention between mix_speech_hidden [bs,len,fre,emb] and mix_mulEmbedding [bs,num_labels,EMB]:
        # expand the former to bs*num_labels x ..., likewise the latter, run the ATT function, then reshape back
        mix_speech_hidden_5d = mix_speech_hidden.view(config.BATCH_SIZE, 1,
                                                      mix_speech_len,
                                                      speech_fre,
                                                      config.EMBEDDING_SIZE)
        mix_speech_hidden_5d = mix_speech_hidden_5d.expand(
            config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre,
            config.EMBEDDING_SIZE).contiguous()
        mix_speech_hidden_5d_last = mix_speech_hidden_5d.view(
            -1, mix_speech_len, speech_fre, config.EMBEDDING_SIZE)
        if not config.is_ComlexMask:
            mix_speech_multiEmbs = mix_speech_multiEmbs.view(
                -1, config.EMBEDDING_SIZE)
        else:
            mix_speech_multiEmbs = mix_speech_multiEmbs.view(
                -1, 2 * config.EMBEDDING_SIZE)
        att_multi_speech = att_speech_layer(mix_speech_hidden_5d_last,
                                            mix_speech_multiEmbs)
        print att_multi_speech.size()

        if not config.is_ComlexMask:
            att_multi_speech = att_multi_speech.view(
                config.BATCH_SIZE, top_k_num, mix_speech_len,
                speech_fre)  # bs, num_labels, len, fre
        else:
            att_multi_speech = att_multi_speech.view(
                config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre,
                2)  # bs, num_labels, len, fre, 2
            att_multi_speech = -1 / cRM_C * torch.log(
                (cRM_k - att_multi_speech) / (cRM_k + att_multi_speech))
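            # This is apparently the inverse of a compressed cRM: if the network
            # outputs m = cRM_k * tanh(cRM_C * M / 2), then
            # M = -1/cRM_C * log((cRM_k - m) / (cRM_k + m)) recovers the raw mask M.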

        multi_mask = att_multi_speech
        if not config.is_ComlexMask:
            x_input_map = Variable(torch.from_numpy(
                eval_data['mix_feas'])).cuda()
            # print x_input_map.size()
            x_input_map_multi = x_input_map.view(
                config.BATCH_SIZE, 1, mix_speech_len,
                speech_fre).expand(config.BATCH_SIZE, top_k_num,
                                   mix_speech_len, speech_fre)
            # predict_multi_map=multi_mask*x_input_map_multi
            predict_multi_map = multi_mask * x_input_map_multi

            y_multi_map = np.zeros(
                [config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre],
                dtype=np.float32)
            batch_spk_multi_dict = eval_data['multi_spk_fea_list']
            for idx, sample in enumerate(batch_spk_multi_dict):
                y_idx = sorted([dict_spk2idx[spk] for spk in sample.keys()])
                assert y_idx == list(top_k_mask_idx[idx])
                for jdx, oo in enumerate(y_idx):
                    y_multi_map[idx, jdx] = sample[dict_idx2spk[oo]]
            y_multi_map = Variable(torch.from_numpy(y_multi_map)).cuda()

            loss_multi_speech = loss_multi_func(predict_multi_map, y_multi_map)

            # loss term pushing the per-channel masks to sum to one; should encourage more difference between channels
            y_sum_map = Variable(
                torch.ones(config.BATCH_SIZE, mix_speech_len,
                           speech_fre)).cuda()
            predict_sum_map = torch.sum(multi_mask, 1)
            loss_multi_sum_speech = loss_multi_func(predict_sum_map, y_sum_map)
            # todo: compare the effect of this sum-to-one term later; plain MSE already works quite well for now.
            print 'loss 1, losssum : ', loss_multi_speech.data.cpu().numpy(
            ), loss_multi_sum_speech.data.cpu().numpy()
            loss_multi_speech = loss_multi_speech + 0.5 * loss_multi_sum_speech
            print 'training multi-abs norm this batch:', torch.abs(
                y_multi_map - predict_multi_map).norm().data.cpu().numpy()
            print 'loss:', loss_multi_speech.data.cpu().numpy()
            bss_eval(predict_multi_map, y_multi_map, top_k_mask_idx,
                     dict_idx2spk, eval_data)

        else:
            x_input_map = Variable(torch.from_numpy(
                eval_data['mix_mag'])).cuda()  # bs,len,fre,2
            x_input_map_multi = x_input_map.view(
                config.BATCH_SIZE, 1, mix_speech_len, speech_fre,
                2).expand(config.BATCH_SIZE, top_k_num, mix_speech_len,
                          speech_fre, 2)

            multi_mask_real = multi_mask[:, :, :, :, 0]
            multi_mask_fake = multi_mask[:, :, :, :, 1]
            x_input_map_real = x_input_map_multi[:, :, :, :, 0]
            x_input_map_fake = x_input_map_multi[:, :, :, :, 1]
            predict_map_real = multi_mask_real * x_input_map_real - multi_mask_fake * x_input_map_fake
            predict_map_fake = multi_mask_real * x_input_map_fake + multi_mask_fake * x_input_map_real

            y_multi_map = np.zeros(
                [config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre, 2],
                dtype=np.float32)
            batch_spk_multi_dict = eval_data['multi_spk_fea_list']
            for idx, sample in enumerate(batch_spk_multi_dict):
                y_idx = sorted([dict_spk2idx[spk] for spk in sample.keys()])
                assert y_idx == list(top_k_mask_idx[idx])
                for jdx, oo in enumerate(y_idx):
                    y_multi_map[idx, jdx] = sample[dict_idx2spk[oo]]
            y_multi_map = Variable(torch.from_numpy(y_multi_map)).cuda()
            y_multi_map_real = y_multi_map[:, :, :, :, 0]
            y_multi_map_fake = y_multi_map[:, :, :, :, 1]

            loss_multi_speech_real = loss_multi_func(predict_map_real,
                                                     y_multi_map_real)
            loss_multi_speech_fake = loss_multi_func(predict_map_fake,
                                                     y_multi_map_fake)
            loss_multi_speech = loss_multi_speech_fake + loss_multi_speech_real

            print 'loss_real:', loss_multi_speech_real.data.cpu().numpy()
            print 'loss_fake:', loss_multi_speech_fake.data.cpu().numpy()
            print 'loss:', loss_multi_speech.data.cpu().numpy()
            bss_eval_cRM(predict_map_real, predict_map_fake, y_multi_map,
                         top_k_mask_idx, dict_idx2spk, eval_data)

        SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output/', 2))
        print 'SDR_aver_now:', SDR_SUM.mean()

    SDR_aver = SDR_SUM.mean()
    print 'SDR_SUM (len:{}) for epoch eval : {}'.format(
        SDR_SUM.shape, SDR_aver)
    print '#' * 40
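# ---------------------------------------------------------------------------
# Standalone sketch (an assumption, not code from the original project) of the
# compressed-cRM pair implied by the inverse transform in eval_bss above:
# compress(M) = K * tanh(C * M / 2) bounds the raw mask, and
# uncompress(m) = -1/C * log((K - m) / (K + m)) recovers it.
import numpy as np

cRM_k, cRM_C = 10.0, 0.1  # hypothetical values; the real ones live in config

def crm_compress(M, K=cRM_k, C=cRM_C):
    return K * np.tanh(C * M / 2.0)

def crm_uncompress(m, K=cRM_k, C=cRM_C):
    return -1.0 / C * np.log((K - m) / (K + m))

# round-trip check: uncompress(compress(M)) == M
M = np.linspace(-20, 20, 5)
assert np.allclose(crm_uncompress(crm_compress(M)), M)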
Example 4
def eval(epoch):
    # config.batch_size=1
    model.eval()
    # print('\n\nPlease set batch_size to 1 in config when testing!')
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    test_or_valid = 'test'
    # test_or_valid = 'valid'
    print(('Test or valid:', test_or_valid))
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX,
                                 config.MAX_MIX)
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    batch_idx = 0
    global best_SDR, Var
    # for iii in range(2000):
    while True:
        print(('-' * 30))
        eval_data = next(eval_data_gen)
        if eval_data == False:
            print(('SDR_aver_eval_epoch:', SDR_SUM.mean()))
            print(('SDRi_aver_eval_epoch:', SDRi_SUM.mean()))
            break  # the generator for this epoch is exhausted; move on to the next epoch
        src = Variable(torch.from_numpy(eval_data['mix_feas']))

        # raw_tgt = [sorted(spk.keys()) for spk in eval_data['multi_spk_fea_list']]
        raw_tgt = eval_data['batch_order']
        feas_tgt = models.rank_feas(
            raw_tgt, eval_data['multi_spk_fea_list'])  # target spectrograms

        top_k = len(raw_tgt[0])
        # make sure the tensors below are all LongTensors
        # tgt = Variable(torch.from_numpy(np.array([[0]+[dict_spk2idx[spk] for spk in spks]+[dict_spk2idx['<EOS>']] for spks in raw_tgt],dtype=np.int))).transpose(0,1) # map to indices, with BOS/EOS added around them.
        # tgt = Variable(torch.from_numpy(np.array([[0,1,2,102] for __ in range(config.batch_size)], dtype=np.int))).transpose(0, 1)  # map to indices, with BOS/EOS added around them.
        # tgt = Variable(torch.from_numpy(np.array([[0,1,2,3,102] for __ in range(config.batch_size)], dtype=np.int))).transpose(0, 1)  # map to indices, with BOS/EOS added around them.
        tgt = Variable(
            torch.from_numpy(
                np.array([
                    list(range(top_k + 1)) + [102]
                    for __ in range(config.batch_size)
                ],
                         dtype=np.int))).transpose(0, 1)  # map to indices, with BOS/EOS added around them.

        padded_mixture, mixture_lengths, padded_source = eval_data['tas_zip']
        padded_mixture = torch.from_numpy(padded_mixture).float()
        mixture_lengths = torch.from_numpy(mixture_lengths)
        padded_source = torch.from_numpy(padded_source).float()

        padded_mixture = padded_mixture.cuda().transpose(0, 1)
        mixture_lengths = mixture_lengths.cuda()
        padded_source = padded_source.cuda()

        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor([
                len(one_spk) for one_spk in eval_data['multi_spk_fea_list']
            ])).unsqueeze(0)
        # tgt_len = Variable(torch.LongTensor(config.batch_size).zero_()+len(eval_data['multi_spk_fea_list'][0])).unsqueeze(0)
        if config.WFM:
            tmp_size = feas_tgt.size()
            assert len(tmp_size) == 3
            feas_tgt_square = feas_tgt * feas_tgt
            feas_tgt_sum_square = torch.sum(feas_tgt_square,
                                            dim=0,
                                            keepdim=True).expand(tmp_size)
            WFM_mask = feas_tgt_square / (feas_tgt_sum_square + 1e-15)

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
            if config.WFM:
                WFM_mask = WFM_mask.cuda()

        if 1 and len(opt.gpus) > 1:
            outputs, pred, targets, multi_mask, dec_enc_attn_list = model(
                src,
                src_len,
                tgt,
                tgt_len,
                dict_spk2idx,
                None,
                mix_wav=padded_mixture
            )  # outputs here are the hidden_outputs, the hidden layer before the final classification; usable directly
        else:
            outputs, pred, targets, multi_mask, dec_enc_attn_list = model(
                src,
                src_len,
                tgt,
                tgt_len,
                dict_spk2idx,
                None,
                mix_wav=padded_mixture
            )  # outputs here are the hidden_outputs, the hidden layer before the final classification; usable directly
        samples = list(
            pred.view(config.batch_size, top_k + 1,
                      -1).max(2)[1].data.cpu().numpy())
        '''

        if 1 and len(opt.gpus) > 1:
            samples,  predicted_masks = model.module.beam_sample(src, src_len, dict_spk2idx, tgt, config.beam_size,None,padded_mixture)
        else:
            samples,  predicted_masks = model.beam_sample(src, src_len, dict_spk2idx, tgt, config.beam_size, None, padded_mixture)
            multi_mask = predicted_masks
            samples=[samples]
        # except:
        #     continue

        # '''
        # expand the raw mixed-features to topk_max channel.
        src = src.transpose(0, 1)
        siz = src.size()
        assert len(siz) == 3
        # if samples[0][-1] != dict_spk2idx['<EOS>']:
        #     print '*'*40+'\nThe model is far from good. End the evaluation.\n'+'*'*40
        #     break
        topk_max = top_k
        x_input_map_multi = torch.unsqueeze(src,
                                            1).expand(siz[0], topk_max, siz[1],
                                                      siz[2])
        multi_mask = multi_mask.transpose(0, 1)

        if test_or_valid != 'test':
            if config.use_tas:
                if 1 and len(opt.gpus) > 1:
                    ss_loss, pmt_list, max_snr_idx, *__ = model.module.separation_tas_loss(
                        padded_mixture, multi_mask, padded_source,
                        mixture_lengths)
                else:
                    ss_loss, pmt_list, max_snr_idx, *__ = model.separation_tas_loss(
                        padded_mixture, multi_mask, padded_source,
                        mixture_lengths)
            print(('loss for ss,this batch:', ss_loss.cpu().item()))
            lera.log({
                'ss_loss_' + test_or_valid: ss_loss.cpu().item(),
            })
            del ss_loss

        # '''''
        if 1 and batch_idx <= (500 / config.batch_size):  # only the first batches count toward the SDR
            utils.bss_eval_tas(config,
                               multi_mask,
                               eval_data['multi_spk_fea_list'],
                               raw_tgt,
                               eval_data,
                               dst=log_path + 'batch_output/')
            del multi_mask, x_input_map_multi
            try:
                sdr_aver_batch, sdri_aver_batch = bss_test.cal(log_path +
                                                               'batch_output/')
                SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
                SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
            except AssertionError:
                print('Errors in calculating the SDR')
            print(('SDR_aver_now:', SDR_SUM.mean()))
            print(('SDRi_aver_now:', SDRi_SUM.mean()))
            lera.log({'SDR sample' + test_or_valid: SDR_SUM.mean()})
            lera.log({'SDRi sample' + test_or_valid: SDRi_SUM.mean()})
            writer.add_scalars('scalar/loss',
                               {'SDR_sample_' + test_or_valid: sdr_aver_batch},
                               updates)
            # raw_input('Press any key to continue......')
        elif batch_idx == (200 / config.batch_size) + 1 and SDR_SUM.mean(
        ) > best_SDR:  # only record the best SDR once.
            print(('Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean())))
            best_SDR = SDR_SUM.mean()
            # save_model(log_path+'checkpoint_bestSDR{}.pt'.format(best_SDR))
        '''
        import matplotlib.pyplot as plt
        ax = plt.gca()
        ax.invert_yaxis()

        raw_src=models.rank_feas(raw_tgt,eval_data['multi_spk_fea_list'])
        att_idx=0
        att =dec_enc_attn_list.data.cpu().numpy()[:,att_idx] # head,topk,T
        for spk in range(3):
            xx=att[:,spk]
            plt.matshow(xx.reshape(8,1,-1).repeat(50,1).reshape(-1,751), cmap=plt.cm.hot, vmin=0,vmax=0.05)
            plt.colorbar()
            plt.savefig(log_path+'batch_output/'+'spk_{}.png'.format(spk))
            plt.matshow(xx.sum(0).reshape(1, 1, -1).repeat(50, 1).reshape(-1, 751), cmap=plt.cm.hot, vmin=0, vmax=0.05)
            plt.colorbar()
            plt.savefig(log_path + 'batch_output/' + 'spk_sum_{}.png'.format(spk))
        for head in range(8):
            xx=att[head]
            plt.matshow(xx.reshape(3,1,-1).repeat(100,1).reshape(-1,751), cmap=plt.cm.hot, vmin=0,vmax=0.05)
            plt.colorbar()
            plt.savefig(log_path+'batch_output/'+'head_{}.png'.format(head))
        plt.matshow(raw_src[att_idx*2+0].transpose(0,1), cmap=plt.cm.hot, vmin=0,vmax=2)
        plt.colorbar()
        plt.savefig(log_path+'batch_output/'+'source0.png')
        plt.matshow(raw_src[att_idx*2+1].transpose(0,1), cmap=plt.cm.hot, vmin=0,vmax=2)
        plt.colorbar()
        plt.savefig(log_path+'batch_output/'+'source1.png')
        # '''
        candidate += [
            convertToLabels(dict_idx2spk, s, dict_spk2idx['<EOS>'])
            for s in samples
        ]
        # source += raw_src
        reference += raw_tgt
        print(('samples:', samples))
        print(('can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:],
                                         reference[-1 * config.batch_size:])))
        # alignments += [align for align in alignment]
        batch_idx += 1

        result = utils.eval_metrics(reference, candidate, dict_spk2idx,
                                    log_path)
        print((
            'hamming_loss: %.8f | micro_f1: %.4f |recall: %.4f | precision: %.4f'
            % (
                result['hamming_loss'],
                result['micro_f1'],
                result['micro_recall'],
                result['micro_precision'],
            )))

    score = {}
    result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path)
    logging_csv([e, updates, result['hamming_loss'], \
                 result['micro_f1'], result['micro_precision'], result['micro_recall'],SDR_SUM.mean()])
    print(('hamming_loss: %.8f | micro_f1: %.4f' %
           (result['hamming_loss'], result['micro_f1'])))
    score['hamming_loss'] = result['hamming_loss']
    score['micro_f1'] = result['micro_f1']
    1 / 0  # intentional crash: stop evaluation here (the return below is never reached)
    return score
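# ---------------------------------------------------------------------------
# Minimal sketch (assumed toy shapes, not the project's code) of the
# "expand the raw mixed-features to topk channel" pattern used in eval():
# broadcast the (bs, T, F) mixture to (bs, topk, T, F), then multiply by the
# per-speaker masks to get the predicted spectrograms.
import torch

bs, topk, T, F = 2, 3, 301, 257          # hypothetical sizes
src = torch.rand(bs, T, F)               # mixture spectrogram
multi_mask = torch.rand(bs, topk, T, F)  # per-speaker masks in [0, 1]

x_input_map_multi = src.unsqueeze(1).expand(bs, topk, T, F)
predicted_maps = multi_mask * x_input_map_multi
assert predicted_maps.shape == (bs, topk, T, F)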
Example 5
def main():
    print('go to model')
    print '*' * 80

    spk_global_gen=prepare_data(mode='global',train_or_test='train') # a fake data generator, handy for drafting the model first
    global_para=spk_global_gen.next()
    print global_para
    spk_all_list,dict_spk2idx,dict_idx2spk,mix_speech_len,speech_fre,total_frames,spk_num_total=global_para
    del spk_global_gen
    num_labels=len(spk_all_list)


    # the order here is: mix_speechs.shape, mix_feas.shape, aim_fea.shape, aim_spkid.shape, query.shape
    # an example: (5, 17040) (5, 134, 129) (5, 134, 129) (5,) (5, 32, 400, 300, 3)
    # datasize=prepare_datasize(data_generator)
    # mix_speech_len,speech_fre,total_frames,spk_num_total,video_size=datasize
    print 'Begin to build the main model for the Multi-Modal Cocktail Problem.'
    # data=data_generator.next()

    # This part is to build the 3D mix speech embedding maps.
    mix_hidden_layer_3d=MIX_SPEECH(speech_fre,mix_speech_len).cuda()
    mix_speech_classifier=MIX_SPEECH_classifier(speech_fre,mix_speech_len,num_labels).cuda()
    mix_speech_multiEmbedding=SPEECH_EMBEDDING(num_labels,config.EMBEDDING_SIZE,spk_num_total+config.UNK_SPK_SUPP).cuda()
    print mix_hidden_layer_3d
    print mix_speech_classifier
    # mix_speech_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(data[1])).cuda())

    hidden_size=(config.EMBEDDING_SIZE)
    # x=torch.arange(0,24).view(2,3,4)
    # y=torch.ones([2,4])
    att_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
    att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
    print att_speech_layer

    optimizer = torch.optim.Adam([{'params':mix_hidden_layer_3d.parameters()},
                                 {'params':mix_speech_multiEmbedding.parameters()},
                                 {'params':mix_speech_classifier.parameters()},
                                 # {'params':query_video_layer.lstm_layer.parameters()},
                                 # {'params':query_video_layer.dense.parameters()},
                                 # {'params':query_video_layer.Linear.parameters()},
                                 {'params':att_layer.parameters()},
                                 {'params':att_speech_layer.parameters()},
                                 # ], lr=0.02,momentum=0.9)
                                 ], lr=0.0002)
    if 1 and config.Load_param:
        # query_video_layer.load_state_dict(torch.load('param_video_layer_19'))
        # mix_speech_classifier.load_state_dict(torch.load('params/param_speech_123onezeroag3_WSJ0_multilabel_epoch40'))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_WSJ0_hidden3d_180'))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_WSJ0_emblayer_180'))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix101_WSJ0_attlayer_180'))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_dbag1nosum_WSJ0_hidden3d_250',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_dbag1nosum_WSJ0_emblayer_250',map_location={'cuda:1':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix101_dbag1nosum_WSJ0_attlayer_250',map_location={'cuda:1':'cuda:0'}))

        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix2or3_db_WSJ0_hidden3d_560',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix2or3_db_WSJ0_emblayer_560',map_location={'cuda:1':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix2or3_db_WSJ0_attlayer_560',map_location={'cuda:1':'cuda:0'}))

        mix_speech_classifier.load_state_dict(torch.load('params/param_speech_4lstm_multilabelloss30map_epoch440'))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_dbag2sum_WSJ0_hidden3d_460',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_dbag2sum_WSJ0_emblayer_460',map_location={'cuda:1':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix101_dbag2sum_WSJ0_attlayer_460',map_location={'cuda:1':'cuda:0'}))

        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_dbdropout_WSJ0_hidden3d_370',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_dbdropout_WSJ0_emblayer_370',map_location={'cuda:1':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix101_dbdropout_WSJ0_attlayer_370',map_location={'cuda:1':'cuda:0'}))

        # att_speech_layer.load_state_dict(torch.load('params/param_mix101_dbdropoutag_WSJ0_attlayer_220',map_location={'cuda:1':'cuda:0'}))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_dbdropoutag_WSJ0_hidden3d_220',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_dbdropoutag_WSJ0_emblayer_220',map_location={'cuda:1':'cuda:0'}))

        # att_speech_layer.load_state_dict(torch.load('params/param_mix2or3_dbdropoutag_WSJ0_attlayer_180',map_location={'cuda:1':'cuda:0'}))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix2or3_dbdropoutag_WSJ0_hidden3d_180',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix2or3_dbdropoutag_WSJ0_emblayer_180',map_location={'cuda:1':'cuda:0'}))

        att_speech_layer.load_state_dict(torch.load('params/param_mix2_db2dropout_WSJ0_attlayer_495',map_location={'cuda:1':'cuda:0'}))
        mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix2_db2dropout_WSJ0_hidden3d_495',map_location={'cuda:1':'cuda:0'}))
        mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix2_db2dropout_WSJ0_emblayer_495',map_location={'cuda:1':'cuda:0'}))

        # att_speech_layer.load_state_dict(torch.load('params/param_mix2or3_db2dropout_WSJ0_attlayer_95',map_location={'cuda:1':'cuda:0'}))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix2or3_db2dropout_WSJ0_hidden3d_95',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix2or3_db2dropout_WSJ0_emblayer_95',map_location={'cuda:1':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix1to3_dbdropoutag1_WSJ0_attlayer_500',map_location={'cuda:1':'cuda:0'}))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix1to3_dbdropoutag1_WSJ0_hidden3d_500',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix1to3_dbdropoutag1_WSJ0_emblayer_500',map_location={'cuda:1':'cuda:0'}))

        # att_speech_layer.load_state_dict(torch.load('params/param_mix2_40lstm2dbdro_WSJ0_attlayer_835',map_location={'cuda:1':'cuda:0'}))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix2_40lstm2dbdro_WSJ0_hidden3d_835',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix2_40lstm2dbdro_WSJ0_emblayer_835',map_location={'cuda:1':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix2_40lstmdbdropout_WSJ0_attlayer_200',map_location={'cuda:1':'cuda:0'}))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix2_40lstm3dbdropout_WSJ0_hidden3d_200',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix2_40lstm3dbdropout_WSJ0_emblayer_200',map_location={'cuda:1':'cuda:0'}))

        '''with Noise'''
        # att_speech_layer.load_state_dict(torch.load('params/param_noicemix2or3_db2dropout_WSJ0_attlayer_80',map_location={'cuda:1':'cuda:0'}))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_noicemix2or3_db2dropout_WSJ0_hidden3d_80',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_noicemix2or3_db2dropout_WSJ0_emblayer_80',map_location={'cuda:1':'cuda:0'}))
    loss_func = torch.nn.MSELoss()  # the target label is NOT one-hot
    loss_multi_func = torch.nn.MSELoss()  # the target label is NOT one-hot
    # loss_multi_func = torch.nn.L1Loss()  # the target label is NOT one-hot
    loss_query_class=torch.nn.CrossEntropyLoss()

    print '''Begin to calculate.'''
    SDR_SUM_total=np.array([])
    for epoch_idx in range(config.MAX_EPOCH):
        if epoch_idx>0:
            print 'SDR_SUM (len:{}) for epoch {} : {}'.format(SDR_SUM.shape,epoch_idx-1,SDR_SUM.mean())
        SDR_SUM=np.array([])
        # print 'SDR_SUM for epoch {}:{}'.format(epoch_idx - 1, SDR_SUM.mean())
        dst='batch_output'
        if os.path.exists(dst):
            print " cleanup: " + dst + "/"
            shutil.rmtree(dst)
        os.makedirs(dst)
        for batch_idx in range(config.EPOCH_SIZE):
            print '*' * 40,epoch_idx,batch_idx,'*'*40
            train_data_gen=prepare_data('once','train')
            # train_data_gen=prepare_data('once','test')
            train_data_gen=prepare_data('once','eval_test')
            train_data=train_data_gen.next()
            mix_feas=train_data['mix_feas']
            '''3D (len, fre, Emb) representation layer for the mixed speech'''
            mix_speech_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # temporarily disable the video part, since the video data for s2/s3/s4 is incomplete for now

            '''Speech self-separation part'''
            mix_speech_output=mix_speech_classifier(Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # get the ground-truth speaker names and vectors from the data
            # y_spk_list=[one.keys() for one in train_data['multi_spk_fea_list']]
            # y_spk_list= train_data['multi_spk_fea_list']
            # y_spk_gtruth,y_map_gtruth=multi_label_vector(y_spk_list,dict_spk2idx)
            # use the ground-truth separation result as the classifier output during training
            if 0 and config.Ground_truth:
                mix_speech_output=Variable(torch.from_numpy(y_map_gtruth)).cuda()
                if test_all_outputchannel: # set the input mask to all ones, to test outputting every channel
                    mix_speech_output=Variable(torch.ones(config.BATCH_SIZE,num_labels,))
                    y_map_gtruth=np.ones([config.BATCH_SIZE,num_labels])
            recu_spk_list=OrderedDict() # per-step mapping from spk to its separated target speech
            speech_history=[] # history of the residual speech spectrogram after each step
            bss_eval_groundtrue(train_data,batch_idx)

            now_feas=train_data['mix_feas']
            while True:
                speech_history.append(now_feas)
                max_num_labels=3
                top_k_mask_mixspeech,top_k_sort_index=top_k_mask(mix_speech_output,alpha=-0.3,top_k=max_num_labels) # a torch.FloatTensor
                # top_k_mask_idx=[np.where(line==1)[0] for line in top_k_mask_mixspeech.numpy()]
                top_k_mask_idx=top_k_sort_index
                # filter out speakers that have already been seen
                print 'predict spk:',top_k_mask_idx[0]
                for k in top_k_mask_idx[0]:
                    if k not in recu_spk_list.keys():
                        top_k_mask_idx=[[k]]
                        break
                print 'filtered spk:',top_k_mask_idx[0]
                # if nothing is left after filtering, we are done
                if len(top_k_mask_idx[0])==0:
                    break
                # elif top_k_mask_idx[0][0] in speech_history

                mix_speech_multiEmbs=mix_speech_multiEmbedding(top_k_mask_mixspeech,top_k_mask_idx) # bs*num_labels (max number of mixed speakers) x embedding size

                assert len(top_k_mask_idx[0])==len(top_k_mask_idx[-1])
                top_k_num=len(top_k_mask_idx[0])

                # compute the attention between mix_speech_hidden [bs,len,fre,emb] and mix_mulEmbedding [bs,num_labels,EMB]:
                # expand the former to bs*num_labels x ..., likewise the latter, run the ATT function, then reshape back
                mix_speech_hidden_5d=mix_speech_hidden.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
                mix_speech_hidden_5d=mix_speech_hidden_5d.expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre,config.EMBEDDING_SIZE).contiguous()
                mix_speech_hidden_5d_last=mix_speech_hidden_5d.view(-1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
                # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
                att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'dot').cuda()
                att_multi_speech=att_speech_layer(mix_speech_hidden_5d_last,mix_speech_multiEmbs.view(-1,config.EMBEDDING_SIZE))
                # print att_multi_speech.size()
                att_multi_speech=att_multi_speech.view(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre) # bs,num_labels,len,fre
                # print att_multi_speech.size()
                multi_mask=att_multi_speech
                # multi_mask=(att_multi_speech>0.5)
                # multi_mask=Variable(torch.from_numpy(np.float32(multi_mask.data.cpu().numpy()))).cuda()
                # top_k_mask_mixspeech_multi=top_k_mask_mixspeech.view(config.BATCH_SIZE,top_k_num,1,1).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
                # multi_mask=multi_mask*Variable(top_k_mask_mixspeech_multi).cuda()

                x_input_map=Variable(torch.from_numpy(now_feas)).cuda()
                # print x_input_map.size()
                x_input_map_multi=x_input_map.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
                # predict_multi_map=multi_mask*x_input_map_multi
                predict_multi_map=multi_mask*x_input_map_multi # the spectrogram predicted for this speaker
                recu_spk_list[top_k_mask_idx[0][0]]=predict_multi_map

                pre_spk=dict_idx2spk[top_k_mask_idx[0][0]]
                num_step=len(recu_spk_list)
                print 'Now outputting the {}-th spk, closest to spk <{}> in the train list.'.format(num_step,pre_spk)
                # bss_eval_recu(multi_mask,x_input_map,top_k_mask_mixspeech,pre_spk,train_data,num_step-1,batch_idx)

                if num_step>=2:
                    # bss_eval_recu(multi_mask,x_input_map,top_k_mask_mixspeech,pre_spk,train_data,num_step,batch_idx)
                    break

                now_feas=((1-multi_mask)*x_input_map_multi).data.cpu().numpy().reshape(1,mix_speech_len,speech_fre)
                mix_speech_output=mix_speech_classifier(Variable(torch.from_numpy(now_feas)).cuda())
                mix_speech_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(now_feas)).cuda())


            cal_spk=recu_spk_list.keys()
            mix_speech_multiEmbs=mix_speech_multiEmbedding(top_k_mask_mixspeech,cal_spk) # bs*num_labels (max number of mixed speakers) x embedding size

            top_k_num=len(cal_spk)

            # compute the attention between mix_speech_hidden [bs,len,fre,emb] and mix_mulEmbedding [bs,num_labels,EMB]:
            # expand the former to bs*num_labels x ..., likewise the latter, run the ATT function, then reshape back
            mix_speech_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            mix_speech_hidden_5d=mix_speech_hidden.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
            mix_speech_hidden_5d=mix_speech_hidden_5d.expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre,config.EMBEDDING_SIZE).contiguous()
            mix_speech_hidden_5d_last=mix_speech_hidden_5d.view(-1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
            # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
            att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'dot').cuda()
            att_multi_speech=att_speech_layer(mix_speech_hidden_5d_last,mix_speech_multiEmbs.view(-1,config.EMBEDDING_SIZE))
            # print att_multi_speech.size()
            att_multi_speech=att_multi_speech.view(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre) # bs,num_labels,len,fre
            # print att_multi_speech.size()
            multi_mask=att_multi_speech
            # multi_mask=(att_multi_speech>0.5)
            # multi_mask=Variable(torch.from_numpy(np.float32(multi_mask.data.cpu().numpy()))).cuda()
            # top_k_mask_mixspeech_multi=top_k_mask_mixspeech.view(config.BATCH_SIZE,top_k_num,1,1).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
            # multi_mask=multi_mask*Variable(top_k_mask_mixspeech_multi).cuda()

            x_input_map=Variable(torch.from_numpy(train_data['mix_feas'])).cuda()
            # print x_input_map.size()
            x_input_map_multi=x_input_map.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
            bss_eval_fromGenMap(multi_mask,x_input_map,top_k_mask_mixspeech,dict_idx2spk,train_data,batch_idx)



        # SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output/', 2))
        # print 'SDR_SUM (len:{}) for epoch {} : {}'.format(SDR_SUM.shape,epoch_idx,SDR_SUM.mean())
        # 1/0
        SDR_SUM_total = np.append(SDR_SUM_total, bss_test.cal('batch_output/', 2))
        print 'SDR_SUM (len:{}) for epoch {} : {}'.format(SDR_SUM_total.shape,epoch_idx,SDR_SUM_total.mean())
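# ---------------------------------------------------------------------------
# Sketch (toy shapes, not the original code) of the recursive separation loop
# in main(): each step masks out one speaker, and the residual
# (1 - mask) * mixture is fed back as the next step's input.
import torch

T, F = 301, 257                    # hypothetical spectrogram size
now_feas = torch.rand(1, T, F)     # current residual mixture
speech_history, separated = [], []
for step in range(2):              # main() above also stops after 2 speakers
    speech_history.append(now_feas)
    mask = torch.sigmoid(torch.randn(1, T, F))  # stand-in for the attention mask
    separated.append(mask * now_feas)           # this speaker's spectrogram
    now_feas = (1 - mask) * now_feas            # residual for the next step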
Example 6
def train(epoch):
    global e, updates, total_loss, start_time, report_total, report_correct, total_loss_sgm, total_loss_ss
    e = epoch
    model.train()
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])

    if config.schedule and scheduler.get_lr()[0] > 5e-5:
        scheduler.step()
        print("Decaying learning rate to %g" % scheduler.get_lr()[0])
        lera.log({
            'lr': scheduler.get_lr()[0],
        })

    if opt.model == 'gated':
        model.current_epoch = epoch


    train_data_gen = prepare_data('once', 'train')
    while True:
        print '\n'
        train_data = train_data_gen.next()
        if train_data is False:
            print('SDR_aver_epoch:', SDR_SUM.mean())
            print('SDRi_aver_epoch:', SDRi_SUM.mean())
            break  # the generator for this epoch is exhausted; move on to the next epoch

        src = Variable(torch.from_numpy(train_data['mix_feas']))
        # raw_tgt = [spk.keys() for spk in train_data['multi_spk_fea_list']]
        raw_tgt = [
            sorted(spk.keys()) for spk in train_data['multi_spk_fea_list']
        ]
        feas_tgt = models.rank_feas(
            raw_tgt,
            train_data['multi_spk_fea_list'])  # target spectrograms: aim_size, len, fre

        # make sure the tensors below are all LongTensors
        tgt_max_len = config.MAX_MIX + 2  # with bos and eos.
        tgt = Variable(
            torch.from_numpy(
                np.array(
                    [[0] + [dict_spk2idx[spk] for spk in spks] +
                     (tgt_max_len - len(spks) - 1) * [dict_spk2idx['<EOS>']]
                     for spks in raw_tgt],
                    dtype=np.int))).transpose(0, 1)  # map to indices, with BOS/EOS added around them.
        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor([
                len(one_spk) for one_spk in train_data['multi_spk_fea_list']
            ])).unsqueeze(0)
        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()

        model.zero_grad()

        # aim_list holds the flattened indices of positions with real (non-EOS) speakers
        aim_list = (tgt[1:-1].transpose(0, 1).contiguous().view(-1) !=
                    dict_spk2idx['<EOS>']).nonzero().squeeze()
        aim_list = aim_list.data.cpu().numpy()

        outputs, targets, multi_mask, gamma = model(
            src, src_len, tgt, tgt_len,
            dict_spk2idx)  # outputs here are the hidden_outputs, the hidden layer before the final classification; usable directly
        # print('mask size:', multi_mask.size())
        writer.add_histogram('global gamma', gamma, updates)

        if 1 and len(opt.gpus) > 1:
            sgm_loss, num_total, num_correct = model.module.compute_loss(
                outputs, targets, opt.memory)
        else:
            sgm_loss, num_total, num_correct = model.compute_loss(
                outputs, targets, opt.memory)
        print('loss for SGM,this batch:', sgm_loss.cpu().item())
        writer.add_scalars('scalar/loss', {'sgm_loss': sgm_loss.cpu().item()},
                           updates)

        src = src.transpose(0, 1)
        # expand the raw mixed-features to topk_max channel.
        siz = src.size()
        assert len(siz) == 3
        topk_max = config.MAX_MIX  # maximum possible top-k count
        x_input_map_multi = torch.unsqueeze(src, 1).expand(
            siz[0], topk_max, siz[1],
            siz[2]).contiguous().view(-1, siz[1], siz[2])
        x_input_map_multi = x_input_map_multi[aim_list]
        multi_mask = multi_mask.transpose(0, 1)

        if 1 and len(opt.gpus) > 1:
            ss_loss = model.module.separation_loss(x_input_map_multi,
                                                   multi_mask, feas_tgt)
        else:
            ss_loss = model.separation_loss(x_input_map_multi, multi_mask,
                                            feas_tgt)
        print('loss for SS,this batch:', ss_loss.cpu().item())
        writer.add_scalars('scalar/loss', {'ss_loss': ss_loss.cpu().item()},
                           updates)

        loss = sgm_loss + 5 * ss_loss

        loss.backward()
        # print 'totallllllllllll loss:',loss
        total_loss_sgm += sgm_loss.cpu().item()
        total_loss_ss += ss_loss.cpu().item()
        lera.log({
            'sgm_loss': sgm_loss.cpu().item(),
            'ss_loss': ss_loss.cpu().item(),
            'loss:': loss.cpu().item(),
        })

        if updates > 10 and updates % config.eval_interval in [
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
        ]:
            predicted_maps = multi_mask * x_input_map_multi
            # predicted_maps=Variable(feas_tgt)
            utils.bss_eval(config,
                           predicted_maps,
                           train_data['multi_spk_fea_list'],
                           raw_tgt,
                           train_data,
                           dst='batch_output')
            del predicted_maps, multi_mask, x_input_map_multi
            sdr_aver_batch, sdri_aver_batch = bss_test.cal('batch_output/')
            lera.log({'SDR sample': sdr_aver_batch})
            lera.log({'SDRi sample': sdri_aver_batch})
            writer.add_scalars('scalar/loss', {
                'SDR_sample': sdr_aver_batch,
                'SDRi_sample': sdri_aver_batch
            }, updates)
            SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
            SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
            print('SDR_aver_now:', SDR_SUM.mean())
            print('SDRi_aver_now:', SDRi_SUM.mean())

        total_loss += loss.cpu().item()
        report_correct += num_correct.cpu().item()
        report_total += num_total.cpu().item()
        optim.step()

        updates += 1
        if updates % 30 == 0:
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss this batch: %6.3f,sgm loss: %6.6f,ss loss: %6.6f,label acc: %6.6f\n"
                % (time.time() - start_time, epoch, updates, loss / num_total,
                   total_loss_sgm / 30.0, total_loss_ss / 30.0,
                   report_correct / report_total))
            lera.log({'label_acc': report_correct / report_total})
            writer.add_scalars('scalar/loss',
                               {'label_acc': report_correct / report_total},
                               updates)
            total_loss_sgm, total_loss_ss = 0, 0

        # continue

        if 0 and updates % config.eval_interval == 0 and epoch > 3:  # run at least a few epochs before evaluating; otherwise the model has learned nothing yet and many problems arise.
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss: %6.5f\n" %
                (time.time() - start_time, epoch, updates,
                 total_loss / report_total))
            print('evaluating after %d updates...\r' % updates)
            original_bs = config.batch_size
            score = eval(epoch)  # batch_size becomes 1 during eval
            print 'Original bs:', original_bs
            config.batch_size = original_bs
            print 'Now bs:', config.batch_size
            for metric in config.metric:
                scores[metric].append(score[metric])
                lera.log({
                    'sgm_micro_f1': score[metric],
                })
                if metric == 'micro_f1' and score[metric] >= max(
                        scores[metric]):
                    save_model(log_path + 'best_' + metric + '_checkpoint.pt')
                if metric == 'hamming_loss' and score[metric] <= min(
                        scores[metric]):
                    save_model(log_path + 'best_' + metric + '_checkpoint.pt')

            model.train()
            total_loss = 0
            start_time = 0
            report_total = 0
            report_correct = 0

        if updates % config.save_interval == 1:
            save_model(log_path + 'TDAAv3_{}.pt'.format(updates))
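# ---------------------------------------------------------------------------
# Sketch (toy vocabulary, not the project's dict_spk2idx) of how train() builds
# the speaker-sequence targets: index each speaker, prepend BOS (0), and pad
# with <EOS> out to MAX_MIX + 2 positions.
import numpy as np

dict_spk2idx = {'<BOS>': 0, 'spkA': 1, 'spkB': 2, '<EOS>': 102}  # hypothetical
MAX_MIX = 3
raw_tgt = [['spkA', 'spkB'], ['spkB']]
tgt_max_len = MAX_MIX + 2  # with bos and eos.
tgt = np.array(
    [[0] + [dict_spk2idx[spk] for spk in spks]
     + (tgt_max_len - len(spks) - 1) * [dict_spk2idx['<EOS>']]
     for spks in raw_tgt], dtype=np.int64).T  # (tgt_max_len, batch)
assert tgt.shape == (tgt_max_len, len(raw_tgt))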
Example 7
def train(epoch):
    global e, updates, total_loss, start_time, report_total, report_correct, total_loss_sgm, total_loss_ss
    e = epoch
    model.train()
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])

    if updates <= config.warmup:  # during warmup, skip the normal schedule below
        pass
    elif config.schedule and scheduler.get_lr()[0] > 5e-7:
        scheduler.step()
        print(("Decaying learning rate to %g" % scheduler.get_lr()[0]))
        lera.log({
            'lr': [group['lr'] for group in optim.optimizer.param_groups][0],
        })

    if opt.model == 'gated':
        model.current_epoch = epoch

    train_data_gen = prepare_data('once', 'train')
    while True:
        if updates <= config.warmup:  # still in warmup: ramp the learning rate up
            tmp_lr = config.learning_rate * min(
                max(updates, 1)**(-0.5),
                max(updates, 1) * (config.warmup**(-1.5)))
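            # Transformer-style (Noam) warmup: lr rises linearly for the first
            # `warmup` updates, then decays proportionally to updates**-0.5.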
            for param_group in optim.optimizer.param_groups:
                param_group['lr'] = tmp_lr
            scheduler.base_lrs = list(
                [group['lr'] for group in optim.optimizer.param_groups])
            if updates % 100 == 0:  # log every 100 updates
                print(updates)
                print("Warmup learning rate to %g" % tmp_lr)
                lera.log({
                    'lr':
                    [group['lr'] for group in optim.optimizer.param_groups][0],
                })

        train_data = next(train_data_gen)
        if train_data == False:
            print(('SDR_aver_epoch:', SDR_SUM.mean()))
            print(('SDRi_aver_epoch:', SDRi_SUM.mean()))
            break  # the generator for this epoch is exhausted; move on to the next epoch

        src = Variable(torch.from_numpy(train_data['mix_feas']))
        # raw_tgt = [spk.keys() for spk in train_data['multi_spk_fea_list']]
        # raw_tgt = [sorted(spk.keys()) for spk in train_data['multi_spk_fea_list']]
        raw_tgt = train_data['batch_order']
        feas_tgt = models.rank_feas(
            raw_tgt,
            train_data['multi_spk_fea_list'])  # target spectrograms: aim_size, len, fre

        padded_mixture, mixture_lengths, padded_source = train_data['tas_zip']
        padded_mixture = torch.from_numpy(padded_mixture).float()
        mixture_lengths = torch.from_numpy(mixture_lengths)
        padded_source = torch.from_numpy(padded_source).float()

        padded_mixture = padded_mixture.cuda().transpose(0, 1)
        mixture_lengths = mixture_lengths.cuda()
        padded_source = padded_source.cuda()

        # make sure the tensors below are all LongTensors
        tgt_max_len = config.MAX_MIX + 2  # with bos and eos.
        tgt = Variable(
            torch.from_numpy(
                np.array(
                    [[0] + [dict_spk2idx[spk] for spk in spks] +
                     (tgt_max_len - len(spks) - 1) * [dict_spk2idx['<EOS>']]
                     for spks in raw_tgt],
                    dtype=np.int))).transpose(0, 1)  # map speakers to indices, with BOS prepended and EOS padding appended
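        # Worked example (hypothetical indices, assuming dict_spk2idx gives
        # 'spkA'->1, 'spkB'->2, '<EOS>'->102 and tgt_max_len=4): the row
        # ['spkA', 'spkB'] becomes [0, 1, 2, 102] -- BOS, speaker ids, EOS padding.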
        # tgt = Variable(torch.from_numpy(np.array([[0,1,2,102] for __ in range(config.batch_size)], dtype=np.int))).transpose(0, 1)  # map speakers to indices, with BOS and EOS added
        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor([
                len(one_spk) for one_spk in train_data['multi_spk_fea_list']
            ])).unsqueeze(0)
        if config.WFM:
            siz = src.size()  # bs,T,F
            assert len(siz) == 3
            # topk_max = config.MAX_MIX  # maximum possible number of top-k speakers
            topk_max = 2  # maximum possible number of top-k speakers
            x_input_map_multi = torch.unsqueeze(src, 1).expand(
                siz[0], topk_max, siz[1],
                siz[2]).contiguous().view(-1, siz[1], siz[2])  # bs,topk,T,F
            feas_tgt_tmp = feas_tgt.view(siz[0], -1, siz[1], siz[2])

            feas_tgt_square = feas_tgt_tmp * feas_tgt_tmp
            feas_tgt_sum_square = torch.sum(feas_tgt_square,
                                            dim=1,
                                            keepdim=True).expand(
                                                siz[0], topk_max, siz[1],
                                                siz[2])
            WFM_mask = feas_tgt_square / (feas_tgt_sum_square + 1e-15)
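            # WFM is a Wiener-filter-like mask: each speaker's squared magnitude
            # over the summed squared magnitudes, so per T-F bin the masks are
            # nonnegative and sum to ~1 (the 1e-15 guards against division by zero).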
            feas_tgt = x_input_map_multi.view(
                siz[0], -1, siz[1], siz[2]).data * WFM_mask  # bs,topk,T,F
            feas_tgt = feas_tgt.view(-1, siz[1], siz[2])  # bs*topk,T,F
            WFM_mask = WFM_mask.cuda()
            del x_input_map_multi

        elif config.PSM:
            siz = src.size()  # bs,T,F
            assert len(siz) == 3
            # topk_max = config.MAX_MIX  # maximum possible number of top-k speakers
            topk_max = 2  # maximum possible number of top-k speakers
            x_input_map_multi = torch.unsqueeze(src, 1).expand(
                siz[0], topk_max, siz[1], siz[2]).contiguous()  # bs,topk,T,F
            feas_tgt_tmp = feas_tgt.view(siz[0], -1, siz[1], siz[2])

            IRM = feas_tgt_tmp / (x_input_map_multi + 1e-15)

            angle_tgt = models.rank_feas(
                raw_tgt, train_data['multi_spk_angle_list']).view(
                    siz[0], -1, siz[1], siz[2])
            angle_mix = Variable(
                torch.from_numpy(np.array(
                    train_data['mix_angle']))).unsqueeze(1).expand(
                        siz[0], topk_max, siz[1], siz[2]).contiguous()
            ang = np.cos(angle_mix - angle_tgt)
            ang = np.clip(ang, 0, None)

            feas_tgt = x_input_map_multi * IRM * ang  # bs,topk,T,F
            feas_tgt = feas_tgt.view(-1, siz[1], siz[2])  # bs*topk,T,F
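            # Phase-sensitive target: |mixture| * IRM * cos(angle_mix - angle_tgt),
            # with the cosine clipped to be nonnegative above, so the target
            # magnitude shrinks wherever mixture and source phases disagree.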
            del x_input_map_multi

        elif config.frame_mask:
            siz = src.size()  # bs,T,F
            assert len(siz) == 3
            # topk_max = config.MAX_MIX  # maximum possible number of top-k speakers
            topk_max = 2  # maximum possible number of top-k speakers
            x_input_map_multi = torch.unsqueeze(src, 1).expand(
                siz[0], topk_max, siz[1], siz[2]).contiguous()  # bs,topk,T,F
            feas_tgt_tmp = feas_tgt.view(siz[0], -1, siz[1], siz[2])

            feas_tgt_time = torch.sum(feas_tgt_tmp, 3).transpose(1, 2)  # bs,T,topk
            for v1 in feas_tgt_time:
                for v2 in v1:
                    if v2[0] > v2[1]:
                        v2[0] = 1
                        v2[1] = 0
                    else:
                        v2[0] = 0
                        v2[1] = 1
            frame_mask = feas_tgt_time.transpose(1, 2).unsqueeze(-1)  # bs,topk,t,1
            feas_tgt = x_input_map_multi * frame_mask
            feas_tgt = feas_tgt.view(-1, siz[1], siz[2])  # bs*topk,T,F
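            # frame_mask is a hard 0/1 gate: the loops above assign every time
            # frame entirely to whichever of the two speakers carries more energy
            # in it, so each target keeps the mixture only in its dominant frames.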

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()

        model.zero_grad()
        if config.use_center_loss:
            center_loss.zero_grad()

        # aim_list holds the indices of positions with real speakers (non-<EOS> slots)
        aim_list = (tgt[1:-1].transpose(0, 1).contiguous().view(-1) !=
                    dict_spk2idx['<EOS>']).nonzero().squeeze()
        aim_list = aim_list.data.cpu().numpy()

        outputs, pred, targets, multi_mask, dec_enc_attn_list = model(
            src,
            src_len,
            tgt,
            tgt_len,
            dict_spk2idx,
            None,
            mix_wav=padded_mixture
        )  # outputs are the hidden_outputs before the final classification layer; usable directly
        print('mask size:', multi_mask.size())
        # writer.add_histogram('global gamma',gamma, updates)

        src = src.transpose(0, 1)
        # expand the raw mixed-features to topk_max channel.
        siz = src.size()
        assert len(siz) == 3
        topk_max = config.MAX_MIX  # maximum possible number of top-k speakers
        x_input_map_multi = torch.unsqueeze(src, 1).expand(
            siz[0], topk_max, siz[1],
            siz[2]).contiguous()  #.view(-1, siz[1], siz[2])
        # x_input_map_multi = x_input_map_multi[aim_list]
        multi_mask = multi_mask.transpose(0, 1)
        # if config.WFM:
        #     feas_tgt = x_input_map_multi.data * WFM_mask

        if config.use_tas:
            if 1 and len(opt.gpus) > 1:
                ss_loss, pmt_list, max_snr_idx, *__ = model.module.separation_tas_loss(
                    padded_mixture, multi_mask, padded_source, mixture_lengths)
            else:
                ss_loss, pmt_list, max_snr_idx, *__ = model.separation_tas_loss(
                    padded_mixture, multi_mask, padded_source, mixture_lengths)
            best_pmt = [
                list(pmt_list[int(mm)].data.cpu().numpy())
                for mm in max_snr_idx
            ]
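            # Permutation-invariant training (PIT) in the time domain:
            # separation_tas_loss scores the speaker permutations per sample and
            # max_snr_idx picks the best one; best_pmt records that permutation
            # so the speaker-label targets can be reordered to match below.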
        else:
            if 1 and len(opt.gpus) > 1:  # run SS first to obtain the permutation
                ss_loss, best_pmt = model.module.separation_pit_loss(
                    x_input_map_multi, multi_mask, feas_tgt)
            else:
                ss_loss, best_pmt = model.separation_pit_loss(
                    x_input_map_multi, multi_mask, feas_tgt)

        print('loss for SS,this batch:', ss_loss.cpu().item())
        print('best perms for this batch:', best_pmt)
        writer.add_scalars('scalar/loss', {'ss_loss': ss_loss.cpu().item()},
                           updates)

        # reorder the speaker prediction targets according to best_pmt
        targets = targets.transpose(0, 1)  # bs, aim+1 (EOS included)
        # print('targets',targets)
        targets_old = targets
        for idx, (tar, per) in enumerate(zip(targets, best_pmt)):
            per.append(topk_max)  # append the tail index so the final EOS slot stays fixed
            targets_old[idx] = tar[per]
        targets = targets_old.transpose(0, 1)
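        # e.g., with topk_max=2, a best_pmt entry [1, 0] swaps that sample's two
        # speaker labels; the appended index topk_max keeps the trailing <EOS>
        # slot in place while only the speaker slots are permuted.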
        # print('targets',targets)

        if 1 and len(opt.gpus) > 1:
            sgm_loss, num_total, num_correct = model.module.compute_loss(
                outputs, targets, opt.memory)
        else:
            sgm_loss, num_total, num_correct = model.compute_loss(
                outputs, targets, opt.memory)
        print(('loss for SGM,this batch:', sgm_loss.cpu().item()))
        writer.add_scalars('scalar/loss', {'sgm_loss': sgm_loss.cpu().item()},
                           updates)
        if config.use_center_loss:
            cen_alpha = 0.01
            cen_loss = center_loss(outputs.view(-1, config.SPK_EMB_SIZE),
                                   targets.view(-1))
            print(('loss for SGM center loss,this batch:',
                   cen_loss.cpu().item()))
            writer.add_scalars('scalar/loss',
                               {'center_loss': cen_loss.cpu().item()}, updates)

        if not config.use_tas:
            loss = sgm_loss + 5 * ss_loss
        else:
            loss = 50 * sgm_loss + ss_loss
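        # Fixed weights balance the two objectives: 5x on the spectrogram
        # separation loss, or 50x on the SGM loss when the time-domain TasNet
        # loss is used -- presumably chosen to keep both terms at a comparable scale.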

        loss.backward()

        if config.use_center_loss:
            for c_param in center_loss.parameters():
                c_param.grad.data *= (0.01 /
                                      (cen_alpha * scheduler.get_lr()[0]))
        # print 'totallllllllllll loss:',loss
        total_loss_sgm += sgm_loss.cpu().item()
        total_loss_ss += ss_loss.cpu().item()
        lera.log({
            'sgm_loss': sgm_loss.cpu().item(),
            'ss_loss': ss_loss.cpu().item(),
            'loss:': loss.cpu().item(),
        })

        if updates > 10 and updates % config.eval_interval in [
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
        ]:
            if not config.use_tas:
                predicted_maps = multi_mask * x_input_map_multi.view(
                    siz[0] * topk_max, siz[1], siz[2])
                # predicted_maps=Variable(feas_tgt)  # this is the ground truth
                utils.bss_eval(config,
                               predicted_maps,
                               train_data['multi_spk_fea_list'],
                               raw_tgt,
                               train_data,
                               dst='batch_output1')
                del predicted_maps, multi_mask, x_input_map_multi
                sdr_aver_batch, sdri_aver_batch = bss_test.cal(
                    'batch_output1/')
            else:
                utils.bss_eval_tas(config,
                                   multi_mask,
                                   train_data['multi_spk_fea_list'],
                                   raw_tgt,
                                   train_data,
                                   dst='batch_output1')
                del x_input_map_multi
                sdr_aver_batch, sdri_aver_batch = bss_test.cal(
                    'batch_output1/')
            lera.log({'SDR sample': sdr_aver_batch})
            lera.log({'SDRi sample': sdri_aver_batch})
            writer.add_scalars('scalar/loss', {
                'SDR_sample': sdr_aver_batch,
                'SDRi_sample': sdri_aver_batch
            }, updates)
            SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
            SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
            print(('SDR_aver_now:', SDR_SUM.mean()))
            print(('SDRi_aver_now:', SDRi_SUM.mean()))

        total_loss += loss.cpu().item()
        report_correct += num_correct.cpu().item()
        report_total += num_total.cpu().item()
        optim.step()

        updates += 1
        if updates % 30 == 0:
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss this batch: %6.3f,sgm loss: %6.6f,ss loss: %6.6f,label acc: %6.6f\n"
                % (time.time() - start_time, epoch, updates, loss / num_total,
                   total_loss_sgm / 30.0, total_loss_ss / 30.0,
                   report_correct / report_total))
            lera.log({'label_acc': report_correct / report_total})
            writer.add_scalars('scalar/loss',
                               {'label_acc': report_correct / report_total},
                               updates)
            total_loss_sgm, total_loss_ss = 0, 0

        # continue

        if 0 and updates % config.eval_interval == 0 and epoch > 3:  # run at least a few epochs before evaluating; earlier the model has learned little and results are unreliable
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss: %6.5f\n" %
                (time.time() - start_time, epoch, updates,
                 total_loss / report_total))
            print(('evaluating after %d updates...\r' % updates))
            original_bs = config.batch_size
            score = eval(epoch)  # batch_size becomes 1 during eval
            # print 'Original bs:',original_bs
            config.batch_size = original_bs
            # print 'Now bs:',config.batch_size
            for metric in config.metric:
                scores[metric].append(score[metric])
                lera.log({
                    'sgm_micro_f1': score[metric],
                })
                if metric == 'micro_f1' and score[metric] >= max(
                        scores[metric]):
                    save_model(log_path + 'best_' + metric + '_checkpoint.pt')
                if metric == 'hamming_loss' and score[metric] <= min(
                        scores[metric]):
                    save_model(log_path + 'best_' + metric + '_checkpoint.pt')

            model.train()
            total_loss = 0
            start_time = 0
            report_total = 0
            report_correct = 0

        if 1 and updates % config.save_interval == 1:
            save_model(log_path + 'TDAAv3_PIT_{}.pt'.format(updates))
Ejemplo n.º 8
def eval(epoch):
    model.eval()
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    test_or_valid = 'test'
    print 'Test or valid:', test_or_valid
    eval_data_gen = prepare_data_aim('once', test_or_valid, config.MIN_MIX,
                                     config.MAX_MIX)
    # for raw_src, src, src_len, raw_tgt, tgt, tgt_len in validloader:
    SDR_SUM = np.array([])
    batch_idx = 0
    global best_SDR, Var
    while True:
        # for ___ in range(2):
        print '-' * 30
        eval_data = eval_data_gen.next()
        if eval_data == False:
            print 'SDR_aver_eval_epoch:', SDR_SUM.mean()
            break  # the generator for this epoch is exhausted; move on to the next epoch
        src = Variable(torch.from_numpy(eval_data['mix_feas']))

        raw_tgt = [
            sorted(spk.keys()) for spk in eval_data['multi_spk_fea_list']
        ]
        feas_tgt = models.rank_feas(raw_tgt,
                                    eval_data['multi_spk_fea_list'])  # target spectrograms

        top_k = len(raw_tgt[0])
        # make sure the tensors below are all LongTensors
        # tgt = Variable(torch.from_numpy(np.array([[0]+[dict_spk2idx[spk] for spk in spks]+[dict_spk2idx['<EOS>']] for spks in raw_tgt],dtype=np.int))).transpose(0,1)  # map speakers to indices, with BOS and EOS added
        tgt = Variable(torch.ones(
            top_k + 2, config.batch_size))  # a dummy tgt; the actual speaker names don't matter at test time

        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor([
                len(one_spk) for one_spk in eval_data['multi_spk_fea_list']
            ])).unsqueeze(0)
        # tgt_len = Variable(torch.LongTensor(config.batch_size).zero_()+len(eval_data['multi_spk_fea_list'][0])).unsqueeze(0)
        if config.WFM:
            tmp_size = feas_tgt.size()
            assert len(tmp_size) == 4
            feas_tgt_sum = torch.sum(feas_tgt, dim=1, keepdim=True)
            feas_tgt_sum_square = (feas_tgt_sum *
                                   feas_tgt_sum).expand(tmp_size)
            feas_tgt_square = feas_tgt * feas_tgt
            WFM_mask = feas_tgt_square / feas_tgt_sum_square

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
            if config.WFM:
                WFM_mask = WFM_mask.cuda()
        try:
            if 1 and len(opt.gpus) > 1:
                # samples, alignment = model.module.sample(src, src_len)
                samples, alignment, hiddens, predicted_masks = model.module.beam_sample(
                    src,
                    src_len,
                    dict_spk2idx,
                    tgt,
                    beam_size=config.beam_size)
            else:
                samples, alignment, hiddens, predicted_masks = model.beam_sample(
                    src,
                    src_len,
                    dict_spk2idx,
                    tgt,
                    beam_size=config.beam_size)
                # samples, alignment, hiddens, predicted_masks = model.beam_sample(src, src_len, dict_spk2idx, tgt, beam_size=config.beam_size)
        except TabError, info:
            print '**************Error occurs here************:', info
            continue

        if config.top1:
            predicted_masks = torch.cat([predicted_masks, 1 - predicted_masks],
                                        1)
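            # top1 mode predicts a single mask and takes the second speaker's
            # mask as its complement (1 - m), which is only meaningful for
            # exactly two sources.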

        # '''
        # expand the raw mixed-features to topk_max channel.
        src = src.transpose(0, 1)
        siz = src.size()
        assert len(siz) == 3
        topk_max = feas_tgt.size()[1]
        assert samples[0][-1] == dict_spk2idx['<EOS>']
        topk_max = len(samples[0]) - 1
        x_input_map_multi = torch.unsqueeze(src,
                                            1).expand(siz[0], topk_max, siz[1],
                                                      siz[2])
        if config.WFM:
            feas_tgt = x_input_map_multi.data * WFM_mask

        if test_or_valid == 'valid':
            if 1 and len(opt.gpus) > 1:
                ss_loss = model.module.separation_loss(x_input_map_multi,
                                                       predicted_masks,
                                                       feas_tgt, Var)
            else:
                ss_loss = model.separation_loss(x_input_map_multi,
                                                predicted_masks, feas_tgt)
            print 'loss for ss,this batch:', ss_loss.data[0]
            lera.log({
                'ss_loss_' + test_or_valid: ss_loss.data[0],
            })
            del ss_loss, hiddens

        # '''''
        if batch_idx <= (500 / config.batch_size):  # only the first batches count toward the SDR
            predicted_maps = predicted_masks * x_input_map_multi
            # predicted_maps=Variable(feas_tgt)
            utils.bss_eval2(config,
                            predicted_maps,
                            eval_data['multi_spk_fea_list'],
                            raw_tgt,
                            eval_data,
                            dst='batch_outputjaa')
            del predicted_maps, predicted_masks, x_input_map_multi
            SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_outputjaa/'))
            print 'SDR_aver_now:', SDR_SUM.mean()
            lera.log({'SDR sample': SDR_SUM.mean()})
            # raw_input('Press any key to continue......')
        elif batch_idx == (500 / config.batch_size) + 1 and SDR_SUM.mean() > best_SDR:  # only record the best SDR once.
            print 'Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean())
            best_SDR = SDR_SUM.mean()
            # save_model(log_path+'checkpoint_bestSDR{}.pt'.format(best_SDR))

        # '''
        candidate += [
            convertToLabels(dict_idx2spk, s, dict_spk2idx['<EOS>'])
            for s in samples
        ]
        # source += raw_src
        reference += raw_tgt
        print 'samples:', samples
        print 'can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:],
                                        reference[-1 * config.batch_size:])
        alignments += [align for align in alignment]
        batch_idx += 1
Ejemplo n.º 9
def train(epoch):
    global e, updates, total_loss, start_time, report_total,report_correct, total_loss_sgm, total_loss_ss
    e = epoch
    model.train()
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])

    if updates<=config.warmup: # during warmup do nothing here; past warmup the scheduler below decays the lr
        pass
    elif config.schedule and scheduler.get_lr()[0]>4e-5:
        scheduler.step()
        print(("Decaying learning rate to %g" % scheduler.get_lr()[0],updates))
        lera.log({
            'lr': [group['lr'] for group in optim.optimizer.param_groups][0],
        })

    if opt.model == 'gated':
        model.current_epoch = epoch


    train_data_gen = prepare_data('once', 'train')
    while True:
        if updates <= config.warmup:  # still in the warmup stage: ramp the lr up
            tmp_lr =  config.learning_rate * min(max(updates,1)** (-0.5),
                                             max(updates,1) * (config.warmup ** (-1.5)))
            for param_group in optim.optimizer.param_groups:
                param_group['lr'] = tmp_lr
            scheduler.base_lrs=list([group['lr'] for group in optim.optimizer.param_groups])
            if updates%100==0: # log every 100 updates
                print(updates)
                print("Warmup learning rate to %g" % tmp_lr)
                lera.log({
                    'lr': [group['lr'] for group in optim.optimizer.param_groups][0],
                })

        train_data = next(train_data_gen)
        if train_data == False:
            print(('SDR_aver_epoch:', SDR_SUM.mean()))
            print(('SDRi_aver_epoch:', SDRi_SUM.mean()))
            break  # the generator for this epoch is exhausted; move on to the next epoch

        src = Variable(torch.from_numpy(train_data['mix_feas']))
        # raw_tgt = [spk.keys() for spk in train_data['multi_spk_fea_list']]
        # raw_tgt = [sorted(spk.keys()) for spk in train_data['multi_spk_fea_list']]
        raw_tgt = train_data['batch_order']
        feas_tgt = models.rank_feas(raw_tgt, train_data['multi_spk_fea_list'])  # target spectrograms, bs*Topk, len, fre

        # make sure the tensors below are all LongTensors
        tgt_max_len = config.MAX_MIX + 2  # with bos and eos.
        tgt = Variable(torch.from_numpy(np.array(
            [[0] + [dict_spk2idx[spk] for spk in spks] + (tgt_max_len - len(spks) - 1) * [dict_spk2idx['<EOS>']] for
             spks in raw_tgt], dtype=np.int))).transpose(0, 1)  # map speakers to indices, with BOS and EOS added
        # tgt = Variable(torch.from_numpy(np.array([[0,1,2,102] for __ in range(config.batch_size)], dtype=np.int))).transpose(0, 1)  # map speakers to indices, with BOS and EOS added
        src_len = Variable(torch.LongTensor(config.batch_size).zero_() + mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor([len(one_spk) for one_spk in train_data['multi_spk_fea_list']])).unsqueeze(0)
        if config.WFM:
            siz = src.size()  # bs,T,F
            assert len(siz) == 3
            # topk_max = config.MAX_MIX  # maximum possible number of top-k speakers
            topk_max = 2  # maximum possible number of top-k speakers
            x_input_map_multi = torch.unsqueeze(src, 1).expand(siz[0], topk_max, siz[1], siz[2]).contiguous().view(-1, siz[1], siz[2])  # bs,topk,T,F
            feas_tgt_tmp = feas_tgt.view(siz[0], -1, siz[1], siz[2])

            feas_tgt_square = feas_tgt_tmp * feas_tgt_tmp
            feas_tgt_sum_square = torch.sum(feas_tgt_square, dim=1, keepdim=True).expand(siz[0], topk_max, siz[1], siz[2])
            WFM_mask = feas_tgt_square / (feas_tgt_sum_square + 1e-15)
            feas_tgt = x_input_map_multi.view(siz[0], -1, siz[1], siz[2]).data * WFM_mask  # bs,topk,T,F
            feas_tgt = feas_tgt.view(-1, siz[1], siz[2])  # bs*topk,T,F
            WFM_mask = WFM_mask.cuda()
            del x_input_map_multi

        elif config.PSM:
            siz = src.size()  # bs,T,F
            assert len(siz) == 3
            # topk_max = config.MAX_MIX  # maximum possible number of top-k speakers
            topk_max = 2  # maximum possible number of top-k speakers
            x_input_map_multi = torch.unsqueeze(src, 1).expand(siz[0], topk_max, siz[1], siz[2]).contiguous()  # bs,topk,T,F
            feas_tgt_tmp = feas_tgt.view(siz[0], -1, siz[1], siz[2])

            IRM=feas_tgt_tmp/(x_input_map_multi+1e-15)

            angle_tgt=models.rank_feas(raw_tgt, train_data['multi_spk_angle_list']).view(siz[0],-1,siz[1],siz[2]) # bs,topk,T,F
            angle_mix=Variable(torch.from_numpy(np.array(train_data['mix_angle']))).unsqueeze(1).expand(siz[0], topk_max, siz[1], siz[2]).contiguous()
            ang=np.cos(angle_mix-angle_tgt)
            ang=np.clip(ang,0,None)

            # feas_tgt = x_input_map_multi *np.clip(IRM.numpy()*ang,0,1) # bs,topk,T,F
            # feas_tgt = x_input_map_multi *IRM*ang # bs,topk,T,F
            feas_tgt = feas_tgt.view(siz[0],-1,siz[1],siz[2])*ang # bs,topk,T,F
            feas_tgt = feas_tgt.view(-1, siz[1], siz[2])  # bs*topk,T,F
            del x_input_map_multi

        elif config.frame_mask:
            siz = src.size()  # bs,T,F
            assert len(siz) == 3
            # topk_max = config.MAX_MIX  # maximum possible number of top-k speakers
            topk_max = 2  # maximum possible number of top-k speakers
            x_input_map_multi = torch.unsqueeze(src, 1).expand(siz[0], topk_max, siz[1], siz[2]).contiguous()  # bs,topk,T,F
            feas_tgt_tmp = feas_tgt.view(siz[0], -1, siz[1], siz[2])

            feas_tgt_time=torch.sum(feas_tgt_tmp,3).transpose(1,2) #bs,T,topk
            for v1 in feas_tgt_time:
                for v2 in v1:
                    if v2[0]>v2[1]:
                        v2[0]=1
                        v2[1]=0
                    else:
                        v2[0]=0
                        v2[1]=1
            frame_mask=feas_tgt_time.transpose(1,2).unsqueeze(-1) #bs,topk,t,1
            feas_tgt=x_input_map_multi*frame_mask
            feas_tgt = feas_tgt.view(-1, siz[1], siz[2])  # bs*topk,T,F


        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()

        model.zero_grad()
        if config.use_center_loss:
            center_loss.zero_grad()

        # aim_list holds the indices of positions with real speakers (non-<EOS> slots)
        aim_list = (tgt[1:-1].transpose(0, 1).contiguous().view(-1) != dict_spk2idx['<EOS>']).nonzero().squeeze()
        aim_list = aim_list.data.cpu().numpy()

        multi_mask, enc_attn_list = model(src, src_len, tgt, tgt_len,
                                          dict_spk2idx)  # outputs are the hidden_outputs before the final classification layer; usable directly
        print('mask size:', multi_mask.size()) # topk,bs,T,F
        # print('mask:', multi_mask[0,0,:3:3]) # topk,bs,T,F
        # writer.add_histogram('global gamma',gamma, updates)


        src = src.transpose(0, 1)
        # expand the raw mixed-features to topk_max channel.
        siz = src.size()
        assert len(siz) == 3
        topk_max = config.MAX_MIX  # maximum possible number of top-k speakers
        x_input_map_multi = torch.unsqueeze(src, 1).expand(siz[0], topk_max, siz[1], siz[2]).contiguous()#.view(-1, siz[1], siz[2])
        # x_input_map_multi = x_input_map_multi[aim_list]
        # x_input_map_multi = x_input_map_multi.transpose(0, 1) #topk,bs,T,F
        multi_mask = multi_mask.transpose(0, 1)
        # if config.WFM:
        #     feas_tgt = x_input_map_multi.data * WFM_mask

        # note the batch dimension: after the transpose above, bs comes first (asserted below)
        assert multi_mask.shape == x_input_map_multi.shape
        assert multi_mask.size(0) == config.batch_size

        if 1 and len(opt.gpus) > 1: # run SS first to obtain the permutation
            ss_loss, best_pmt = model.module.separation_pit_loss(x_input_map_multi, multi_mask, feas_tgt)
        else:
            ss_loss, best_pmt = model.separation_pit_loss(x_input_map_multi, multi_mask, feas_tgt)
        print('loss for SS,this batch:', ss_loss.cpu().item())
        print('best perms for this batch:', best_pmt)
        writer.add_scalars('scalar/loss',{'ss_loss':ss_loss.cpu().item()},updates)

        loss = ss_loss
        loss.backward()

        total_loss_ss += ss_loss.cpu().item()
        lera.log({
            'ss_loss': ss_loss.cpu().item(),
        })

        if updates>3 and updates % config.eval_interval in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9,]:
            assert multi_mask.shape==x_input_map_multi.shape
            assert multi_mask.size(0)==config.batch_size
            predicted_maps = (multi_mask * x_input_map_multi).view(siz[0]*topk_max,siz[1],siz[2])

            # predicted_maps=Variable(feas_tgt)
            # utils.bss_eval(config, predicted_maps, train_data['multi_spk_fea_list'], raw_tgt, train_data, dst=log_path+'batch_output/')
            utils.bss_eval2(config, predicted_maps, train_data['multi_spk_fea_list'], raw_tgt, train_data, dst=log_path+'batch_output')
            del predicted_maps, multi_mask, x_input_map_multi
            sdr_aver_batch, sdri_aver_batch=  bss_test.cal(log_path+'batch_output/')
            lera.log({'SDR sample': sdr_aver_batch})
            lera.log({'SDRi sample': sdri_aver_batch})
            writer.add_scalars('scalar/loss',{'SDR_sample':sdr_aver_batch,'SDRi_sample':sdri_aver_batch},updates)
            SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
            SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
            print(('SDR_aver_now:', SDR_SUM.mean()))
            print(('SDRi_aver_now:', SDRi_SUM.mean()))

            # Heatmap here
            # n_layer tensors of shape (head*bs) x lq x dk
            '''
            import matplotlib.pyplot as plt
            ax = plt.gca()
            ax.invert_yaxis()

            raw_src=models.rank_feas(raw_tgt, train_data['multi_spk_fea_list'])
            att_idx=1
            att = enc_attn_list[-1].view(config.trans_n_head,config.batch_size,mix_speech_len,mix_speech_len).data.cpu().numpy()[:,att_idx]
            for head in range(config.trans_n_head):
                xx=att[head]
                plt.matshow(xx, cmap=plt.cm.hot, vmin=0,vmax=0.05)
                plt.colorbar()
                plt.savefig(log_path+'batch_output/'+'head_{}.png'.format(head))
            plt.matshow(raw_src[att_idx*2+0].transpose(0,1), cmap=plt.cm.hot, vmin=0,vmax=2)
            plt.colorbar()
            plt.savefig(log_path+'batch_output/'+'source0.png')
            plt.matshow(raw_src[att_idx*2+1].transpose(0,1), cmap=plt.cm.hot, vmin=0,vmax=2)
            plt.colorbar()
            plt.savefig(log_path+'batch_output/'+'source1.png')
            1/0
            '''

        total_loss += loss.cpu().item()
        optim.step()

        updates += 1
        if updates % 30 == 0:
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss this batch: %6.3f,ss loss: %6.6f\n"
                % (time.time() - start_time, epoch, updates, loss, total_loss_ss / 30.0))
            total_loss_sgm, total_loss_ss = 0, 0

        # continue

        if 0 and updates % config.eval_interval == 0 and epoch > 3: # run at least a few epochs before evaluating; earlier the model has learned little and results are unreliable
            logging("time: %6.3f, epoch: %3d, updates: %8d, train loss: %6.5f\n"
                    % (time.time() - start_time, epoch, updates, total_loss/config.eval_interval))
            print(('evaluating after %d updates...\r' % updates))
            eval(epoch,'valid') # batch_size becomes 1 during eval
            eval(epoch,'test') # batch_size becomes 1 during eval

            model.train()
            total_loss = 0
            start_time = 0
            report_total = 0
            report_correct = 0

        if 1 and updates % config.save_interval == 1:
            save_model(log_path + 'Transformer_PIT_{}.pt'.format(updates))
Ejemplo n.º 10
def main():
    print('go to model')
    print '*' * 80

    spk_global_gen=prepare_data(mode='global',train_or_test='train') # a dummy data generator, useful for sketching out the model first
    global_para=spk_global_gen.next()
    print global_para
    spk_all_list,dict_spk2idx,dict_idx2spk,mix_speech_len,speech_fre,total_frames,spk_num_total,batch_total=global_para
    del spk_global_gen
    num_labels=len(spk_all_list)

    print 'Begin to build the main model for Multi_Modal Cocktail Problem.'

    # This part is to build the 3D mix speech embedding maps.
    mix_hidden_layer_3d=MIX_SPEECH(speech_fre,mix_speech_len).cuda()
    mix_speech_classifier=MIX_SPEECH_classifier(speech_fre,mix_speech_len,num_labels).cuda()
    mix_speech_multiEmbedding=SPEECH_EMBEDDING(num_labels,config.EMBEDDING_SIZE,spk_num_total+config.UNK_SPK_SUPP).cuda()
    print mix_hidden_layer_3d
    print mix_speech_classifier
    print mix_speech_multiEmbedding
    att_layer=ATTENTION(config.EMBEDDING_SIZE,'dot').cuda()
    att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'dot').cuda()
    adjust_layer=ADDJUST(2*config.HIDDEN_UNITS,config.EMBEDDING_SIZE)
    dis_layer=Discriminator().cuda()
    print att_speech_layer
    print att_speech_layer.mode
    print adjust_layer
    print dis_layer
    lr_data=0.0002
    optimizer = torch.optim.Adam([{'params':mix_hidden_layer_3d.parameters()},
                                 {'params':mix_speech_multiEmbedding.parameters()},
                                 {'params':mix_speech_classifier.parameters()},
                                 {'params':adjust_layer.parameters()},
                                 {'params':att_speech_layer.parameters()},
                                 {'params':dis_layer.parameters()},
                                 ], lr=lr_data)
    if 1 and config.Load_param:
        class_dict=torch.load('params/param_speech_2mix3lstm_best',map_location={'cuda:3':'cuda:0'})
        for key in class_dict.keys():
            if 'cnn' in key:
                class_dict.pop(key)
        mix_speech_classifier.load_state_dict(class_dict)
        # the four loads below are the strongest TDAA-basic version
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mixdotadjust4lstmdot_WSJ0_hidden3d_125',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mixdotadjust4lstmdot_WSJ0_emblayer_125',map_location={'cuda:1':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mixdotadjust4lstmdot_WSJ0_attlayer_125',map_location={'cuda:1':'cuda:0'}))
        # adjust_layer.load_state_dict(torch.load('params/param_mixdotadjust4lstmdot_WSJ0_adjlayer_125',map_location={'cuda:1':'cuda:0'}))

        # add the dis-ss results
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mixdotadjust4lstmdotdis_3434.13436424_hidden3d_395',map_location={'cuda:2':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mixdotadjust4lstmdotdis_3434.13436424_emblayer_395',map_location={'cuda:2':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mixdotadjust4lstmdotdis_3434.13436424_attlayer_395',map_location={'cuda:2':'cuda:0'}))
        # adjust_layer.load_state_dict(torch.load('params/param_mixdotadjust4lstmdotdis_3434.13436424_adjlayer_395',map_location={'cuda:2':'cuda:0'}))

        # add the dis-sp results
        mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mixdotadjust4lstmdotdissp_33401_hidden3d_185',map_location={'cuda:1':'cuda:0'}))
        mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mixdotadjust4lstmdotdissp_33401_emblayer_185',map_location={'cuda:1':'cuda:0'}))
        att_speech_layer.load_state_dict(torch.load('params/param_mixdotadjust4lstmdotdissp_33401_attlayer_185',map_location={'cuda:1':'cuda:0'}))
        adjust_layer.load_state_dict(torch.load('params/param_mixdotadjust4lstmdotdissp_33401_adjlayer_185',map_location={'cuda:1':'cuda:0'}))
    loss_func = torch.nn.MSELoss()  # the target label is NOT one-hot
    loss_multi_func = torch.nn.MSELoss()  # the target label is NOT one-hot
    # loss_multi_func = torch.nn.L1Loss()  # the target label is NOT one-hot
    loss_dis_class=torch.nn.MSELoss()

    lrs.send({
        'title': 'TDAA classifier',
        'batch_size':config.BATCH_SIZE,
        'batch_total':batch_total,
        'epoch_size':config.EPOCH_SIZE,
        'loss func':loss_func.__str__(),
        'initial lr':lr_data
    })

    print '''Begin to calculate.'''
    for epoch_idx in range(1):
        if epoch_idx%10==0:
            for ee in optimizer.param_groups:
                if ee['lr']>=1e-7:
                    ee['lr']/=2
                lr_data=ee['lr']
        lrs.send('lr',lr_data)
        if epoch_idx>0:
            print 'SDR_SUM (len:{}) for epoch {} : {}'.format(SDR_SUM.shape,epoch_idx-1,SDR_SUM.mean())
        SDR_SUM=np.array([])
        eval_data_gen=prepare_data('once','valid')
        # eval_data_gen=prepare_data('once','test')
        while 1 and True:
            print '\n'
            eval_data=eval_data_gen.next()
            if eval_data==False:
                break # the generator for this epoch is exhausted; move on to the next epoch

            now_data=eval_data['mix_feas']
            top_k_num=3
            # while True:

            candidates=[]
            predict_multi_map=np.zeros([config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre],dtype=np.float32)
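            # Recursive extraction: each of the 3 steps below infers one speaker,
            # stores that speaker's estimated spectrogram, and subtracts it from
            # now_data so the next step sees the residual mixture.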
            for ____ in range(3):
                print 'Recu step:',____
                out_this_step,spk_this_step=model_step_output(now_data,mix_speech_classifier,mix_hidden_layer_3d,\
                          mix_speech_multiEmbedding,adjust_layer,att_speech_layer,\
                          dict_spk2idx,dict_idx2spk,num_labels,mix_speech_len,speech_fre)
                out_this_step=out_this_step[0].data.cpu().numpy()
                predict_multi_map[0,____]=out_this_step
                now_data=now_data-out_this_step
                candidates.append(spk_this_step)

            y_multi_map=np.zeros([config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre],dtype=np.float32)
            batch_spk_multi_dict=eval_data['multi_spk_fea_list']
            if test_mode:
                for iiii in range(config.BATCH_SIZE):
                    y_multi_map[iiii]=np.array(batch_spk_multi_dict[iiii].values())
            y_multi_map= Variable(torch.from_numpy(y_multi_map)).cuda()

            if 0: # use only the inferred speakers, then go back and do the separation
                print 'Recu only for spks.'
                top_mask=torch.zeros(num_labels)
                for jjj in candidates:
                    top_mask[int(jjj[0])]=1
                top_mask=top_mask.view(1,num_labels)
                try:
                    ccc=eval_bss(top_mask,eval_data,mix_hidden_layer_3d,adjust_layer, mix_speech_classifier, mix_speech_multiEmbedding, att_speech_layer,
                         loss_multi_func, dict_spk2idx, dict_idx2spk, num_labels, mix_speech_len, speech_fre)
                    SDR_SUM = np.append(SDR_SUM, ccc)
                except:
                    pass
            else:
                print 'Recu for spks and maps.'
                predict_multi_map=Variable(torch.from_numpy(predict_multi_map)).cuda()
                try:
                    bss_eval(predict_multi_map,y_multi_map,2,dict_idx2spk,eval_data)
                    SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output2/', 2))
                except:
                    pass
            if SDR_SUM[-3:].mean()>8:
                raw_input()

            print 'SDR_aver_now:',SDR_SUM.mean()

        print 'SDR_SUM (len:{}) for epoch eval : {}'.format(SDR_SUM.shape, SDR_SUM.mean())
        print '#'*40
Ejemplo n.º 11
def eval_bss(candidates,eval_data,mix_hidden_layer_3d,adjust_layer,mix_speech_classifier,mix_speech_multiEmbedding,att_speech_layer,
             loss_multi_func,dict_spk2idx,dict_idx2spk,num_labels,mix_speech_len,speech_fre):
    for i in [mix_speech_multiEmbedding,adjust_layer,mix_speech_classifier,mix_hidden_layer_3d,att_speech_layer]:
        i.evaling=False
    fea_now=eval_data['mix_feas']
    while True:
        '''3D embedding representation of the mixed speech: len, fre, Emb'''
        mix_speech_hidden,mix_tmp_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(fea_now)).cuda())
        # the video part is temporarily disabled because the video data for s2/s3/s4 is incomplete

        '''Speech self-separation part'''
        # mix_speech_output=mix_speech_classifier(Variable(torch.from_numpy(fea_now)).cuda())

        if test_mode:
            num_labels=2
            alpha0=-0.5
        else:
            alpha0=0.5
        # top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=alpha0,top_k=num_labels) # a torch.FloatTensor
        top_k_mask_mixspeech=candidates # a torch.FloatTensor
        top_k_mask_idx=[np.where(line==1)[0] for line in top_k_mask_mixspeech.numpy()]
        print 'Predict spk list:',print_spk_name(dict_idx2spk,top_k_mask_idx)
        mix_speech_multiEmbs=mix_speech_multiEmbedding(top_k_mask_mixspeech,top_k_mask_idx) # bs*num_labels (max number of mixed speakers) x embedding size
        mix_adjust=adjust_layer(mix_tmp_hidden,mix_speech_multiEmbs)
        mix_speech_multiEmbs=mix_adjust+mix_speech_multiEmbs

        assert len(top_k_mask_idx[0])==len(top_k_mask_idx[-1])
        top_k_num=len(top_k_mask_idx[0])

        # we need the attention between mix_speech_hidden[bs,len,fre,emb] and mix_mulEmbedding[bs,num_labels,EMB]
        # expand the former to bs*num_labels,XXXXXXXXX, do the same for the latter, run the ATT function on them, then reshape back
        mix_speech_hidden_5d=mix_speech_hidden.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
        mix_speech_hidden_5d=mix_speech_hidden_5d.expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre,config.EMBEDDING_SIZE).contiguous()
        mix_speech_hidden_5d_last=mix_speech_hidden_5d.view(-1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
        # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
        # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'dot').cuda()
        att_multi_speech=att_speech_layer(mix_speech_hidden_5d_last,mix_speech_multiEmbs.view(-1,config.EMBEDDING_SIZE))
        # print att_multi_speech.size()
        att_multi_speech=att_multi_speech.view(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre) # bs,num_labels,len,fre
        # print att_multi_speech.size()
        multi_mask=att_multi_speech
        # top_k_mask_mixspeech_multi=top_k_mask_mixspeech.view(config.BATCH_SIZE,top_k_num,1,1).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
        # multi_mask=multi_mask*Variable(top_k_mask_mixspeech_multi).cuda()

        x_input_map=Variable(torch.from_numpy(fea_now)).cuda()
        # print x_input_map.size()
        x_input_map_multi=x_input_map.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
        # predict_multi_map=multi_mask*x_input_map_multi
        predict_multi_map=multi_mask*x_input_map_multi


        y_multi_map=np.zeros([config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre],dtype=np.float32)
        batch_spk_multi_dict=eval_data['multi_spk_fea_list']
        if test_mode:
            for iiii in range(config.BATCH_SIZE):
                y_multi_map[iiii]=np.array(batch_spk_multi_dict[iiii].values())
        else:
            for idx,sample in enumerate(batch_spk_multi_dict):
                y_idx=sorted([dict_spk2idx[spk] for spk in sample.keys()])
                if not test_mode:
                    assert y_idx==list(top_k_mask_idx[idx])
                for jdx,oo in enumerate(y_idx):
                    y_multi_map[idx,jdx]=sample[dict_idx2spk[oo]]
        y_multi_map= Variable(torch.from_numpy(y_multi_map)).cuda()

        loss_multi_speech=loss_multi_func(predict_multi_map,y_multi_map)

        # loss term pushing the per-channel masks to sum to 1; should bring out more differences between channels
        y_sum_map=Variable(torch.ones(config.BATCH_SIZE,mix_speech_len,speech_fre)).cuda()
        predict_sum_map=torch.sum(multi_mask,1)
        loss_multi_sum_speech=loss_multi_func(predict_sum_map,y_sum_map)
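        # Regularizer: the masks summed over speakers are pushed toward an
        # all-ones map, i.e. the per-channel masks should partition the
        # mixture's T-F energy rather than overlap.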
        # loss_multi_speech=loss_multi_speech # todo: compare the effect of this sum-to-1 term later; plain MSE already works quite well
        print 'loss 1 eval, losssum eval : ',loss_multi_speech.data.cpu().numpy(),loss_multi_sum_speech.data.cpu().numpy()
        lrs.send('loss mask eval:',loss_multi_speech.data.cpu()[0])
        lrs.send('loss sum eval:',loss_multi_sum_speech.data.cpu()[0])
        loss_multi_speech=loss_multi_speech+0.5*loss_multi_sum_speech
        print 'evaling multi-abs norm this eval batch:',torch.abs(y_multi_map-predict_multi_map).norm().data.cpu().numpy()
        print 'loss:',loss_multi_speech.data.cpu().numpy()
        bss_eval(predict_multi_map,y_multi_map,top_k_mask_idx,dict_idx2spk,eval_data)
        return bss_test.cal('batch_output2/',2)
Ejemplo n.º 12
def eval(epoch, test_or_valid='valid'):
    # config.batch_size=1
    global updates, model
    model.eval()
    # print '\n\nPlease set the batch_size in config to 1 when testing!!!'
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    print(('Test or valid:', test_or_valid))
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX,
                                 config.MAX_MIX)
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    batch_idx = 0
    global best_SDR, Var
    # for iii in range(2000):
    while True:
        print(('-' * 30))
        eval_data = next(eval_data_gen)
        if eval_data == False:
            print(('SDR_aver_eval_epoch:', SDR_SUM.mean()))
            print(('SDRi_aver_eval_epoch:', SDRi_SUM.mean()))
            break  # the generator for this epoch is exhausted; move on to the next epoch
        src = Variable(torch.from_numpy(eval_data['mix_complex_two_channel'])
                       )  # bs,T,F,2 both real and imag values
        raw_tgt = eval_data['batch_order']
        feas_tgt = models.rank_feas(
            raw_tgt,
            eval_data['multi_spk_wav_list'])  # target signals, bs*Topk, time_len

        padded_mixture, mixture_lengths, padded_source = eval_data['tas_zip']
        padded_mixture = torch.from_numpy(padded_mixture).float()
        mixture_lengths = torch.from_numpy(mixture_lengths)
        padded_source = torch.from_numpy(padded_source).float()

        padded_mixture = padded_mixture.cuda().transpose(0, 1)
        mixture_lengths = mixture_lengths.cuda()
        padded_source = padded_source.cuda()

        # make sure the tensors below are all LongTensors
        tgt = Variable(
            torch.from_numpy(
                np.array([[0, 1, 2, 102] for __ in range(config.batch_size)],
                         dtype=np.int))).transpose(0, 1)  # map speakers to indices, with BOS and EOS added

        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor([
                len(one_spk) for one_spk in eval_data['multi_spk_fea_list']
            ])).unsqueeze(0)

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()

        model.zero_grad()
        if config.use_center_loss:
            center_loss.zero_grad()

        multi_mask_real, multi_mask_imag, enc_attn_list = model(
            src, src_len, tgt, tgt_len,
            dict_spk2idx)  # outputs are the hidden_outputs before the final classification layer; usable directly
        multi_mask_real = multi_mask_real.transpose(0, 1)
        multi_mask_imag = multi_mask_imag.transpose(0, 1)
        src_real = src[:, :, :, 0].transpose(0, 1)  # bs,T,F
        src_imag = src[:, :, :, 1].transpose(0, 1)  # bs,T,F
        print('mask size for real/imag:',
              multi_mask_real.size())  # bs,topk,T,F, already squeezed
        print('mixture size for real/imag:', src_real.size())  # bs,T,F

        predicted_maps0_real = multi_mask_real[:, 0] * src_real - multi_mask_imag[:, 0] * src_imag  # bs,T,F
        predicted_maps0_imag = multi_mask_real[:, 0] * src_imag + multi_mask_imag[:, 0] * src_real  # bs,T,F
        predicted_maps1_real = multi_mask_real[:, 1] * src_real - multi_mask_imag[:, 1] * src_imag  # bs,T,F
        predicted_maps1_imag = multi_mask_real[:, 1] * src_imag + multi_mask_imag[:, 1] * src_real  # bs,T,F

        stft_matrix_spk0 = torch.cat((predicted_maps0_real.unsqueeze(-1),
                                      predicted_maps0_imag.unsqueeze(-1)),
                                     3).transpose(1, 2)  # bs,F,T,2
        stft_matrix_spk1 = torch.cat((predicted_maps1_real.unsqueeze(-1),
                                      predicted_maps1_imag.unsqueeze(-1)),
                                     3).transpose(1, 2)  # bs,F,T,2
        wav_spk0 = models.istft_irfft(stft_matrix_spk0,
                                      length=config.MAX_LEN,
                                      hop_length=config.FRAME_SHIFT,
                                      win_length=config.FRAME_LENGTH,
                                      window='hann')
        wav_spk1 = models.istft_irfft(stft_matrix_spk1,
                                      length=config.MAX_LEN,
                                      hop_length=config.FRAME_SHIFT,
                                      win_length=config.FRAME_LENGTH,
                                      window='hann')
        predict_wav = torch.cat((wav_spk0.unsqueeze(1), wav_spk1.unsqueeze(1)),
                                1)  # bs,topk,time_len
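        # The masked complex spectrograms go back to waveforms via the inverse
        # STFT (istft_irfft), then are stacked to bs, topk, time_len so the
        # time-domain PIT loss below can compare them against padded_source.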
        if 1 and len(opt.gpus) > 1:
            ss_loss, pmt_list, max_snr_idx, *__ = model.module.separation_tas_loss(
                padded_mixture, predict_wav, padded_source, mixture_lengths)
        else:
            ss_loss, pmt_list, max_snr_idx, *__ = model.separation_tas_loss(
                padded_mixture, predict_wav, padded_source, mixture_lengths)

        best_pmt = [
            list(pmt_list[int(mm)].data.cpu().numpy()) for mm in max_snr_idx
        ]
        print('loss for SS,this batch:', ss_loss.cpu().item())
        print('best perms for this batch:', best_pmt)
        writer.add_scalars('scalar/loss', {'ss_loss': ss_loss.cpu().item()},
                           updates)
        lera.log({
            'ss_loss_' + test_or_valid: ss_loss.cpu().item(),
        })
        writer.add_scalars('scalar/loss',
                           {'ss_loss_' + test_or_valid: ss_loss.cpu().item()},
                           updates + batch_idx)
        del ss_loss
        # if batch_idx>10:
        #     break

        if False:  # this part is to test the checkpoints sequentially.
            batch_idx += 1
            if batch_idx % 100 == 0:
                updates = updates + 1000
                opt.restore = '/data1/shijing_data/2020-02-14-04:58:17/Transformer_PIT_{}.pt'.format(
                    updates)
                print('loading checkpoint...\n', opt.restore)
                checkpoints = torch.load(opt.restore)
                model.module.load_state_dict(checkpoints['model'])
                break
            continue
        # '''''
        if 1 and batch_idx <= (500 / config.batch_size):
            utils.bss_eval_tas(config,
                               predict_wav,
                               eval_data['multi_spk_fea_list'],
                               raw_tgt,
                               eval_data,
                               dst=log_path + 'batch_output')
            sdr_aver_batch, snri_aver_batch = bss_test.cal(log_path +
                                                           'batch_output/')
            lera.log({'SDR sample': sdr_aver_batch})
            lera.log({'SI-SNRi sample': snri_aver_batch})
            writer.add_scalars('scalar/loss', {
                'SDR_sample': sdr_aver_batch,
                'SDRi_sample': snri_aver_batch
            }, updates)
            SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
            SDRi_SUM = np.append(SDRi_SUM, snri_aver_batch)
            print(('SDR_aver_now:', SDR_SUM.mean()))
            print(('SNRi_aver_now:', SDRi_SUM.mean()))

        batch_idx += 1
        if batch_idx > 100:
            break
        result = utils.eval_metrics(reference, candidate, dict_spk2idx,
                                    log_path)
        print((
            'hamming_loss: %.8f | micro_f1: %.4f |recall: %.4f | precision: %.4f'
            % (
                result['hamming_loss'],
                result['micro_f1'],
                result['micro_recall'],
                result['micro_precision'],
            )))
Ejemplo n.º 13
def train(epoch):
    global e, updates, total_loss, start_time, report_total, report_correct, total_loss_sgm, total_loss_ss
    e = epoch
    model.train()
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])

    if updates <= config.warmup:  # during warmup do nothing here; past warmup the scheduler below decays the lr
        pass
    elif config.schedule and scheduler.get_lr()[0] > 4e-5:
        scheduler.step()
        print(
            ("Decaying learning rate to %g" % scheduler.get_lr()[0], updates))
        lera.log({
            'lr': [group['lr'] for group in optim.optimizer.param_groups][0],
        })

    if opt.model == 'gated':
        model.current_epoch = epoch

    train_data_gen = prepare_data('once', 'train')
    while True:
        if updates <= config.warmup:  # still in the warmup stage: ramp the lr up
            tmp_lr = config.learning_rate * min(
                max(updates, 1)**(-0.5),
                max(updates, 1) * (config.warmup**(-1.5)))
            for param_group in optim.optimizer.param_groups:
                param_group['lr'] = tmp_lr
            scheduler.base_lrs = list(
                [group['lr'] for group in optim.optimizer.param_groups])
            if updates % 100 == 0:  # log every 100 updates
                print(updates)
                print("Warmup learning rate to %g" % tmp_lr)
                lera.log({
                    'lr':
                    [group['lr'] for group in optim.optimizer.param_groups][0],
                })

        train_data = next(train_data_gen)
        if train_data == False:
            print(('SDR_aver_epoch:', SDR_SUM.mean()))
            print(('SDRi_aver_epoch:', SDRi_SUM.mean()))
            break  # the generator for this epoch is exhausted; move on to the next epoch

        src = Variable(torch.from_numpy(train_data['mix_complex_two_channel'])
                       )  # bs,T,F,2 both real and imag values
        raw_tgt = train_data['batch_order']
        feas_tgt = models.rank_feas(
            raw_tgt,
            train_data['multi_spk_wav_list'])  # target signals, bs*Topk, time_len

        padded_mixture, mixture_lengths, padded_source = train_data['tas_zip']
        padded_mixture = torch.from_numpy(padded_mixture).float()
        mixture_lengths = torch.from_numpy(mixture_lengths)
        padded_source = torch.from_numpy(padded_source).float()

        padded_mixture = padded_mixture.cuda().transpose(0, 1)
        mixture_lengths = mixture_lengths.cuda()
        padded_source = padded_source.cuda()

        # make sure the tensors below are all LongTensors
        tgt_max_len = config.MAX_MIX + 2  # with bos and eos.
        tgt = Variable(
            torch.from_numpy(
                np.array(
                    [[0] + [dict_spk2idx[spk] for spk in spks] +
                     (tgt_max_len - len(spks) - 1) * [dict_spk2idx['<EOS>']]
                     for spks in raw_tgt],
                    dtype=np.int))).transpose(0, 1)  # map speakers to indices, with BOS and EOS added
        # tgt = Variable(torch.from_numpy(np.array([[0,1,2,102] for __ in range(config.batch_size)], dtype=np.int))).transpose(0, 1)  # map speakers to indices, with BOS and EOS added
        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor([
                len(one_spk) for one_spk in train_data['multi_spk_fea_list']
            ])).unsqueeze(0)

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()

        model.zero_grad()
        if config.use_center_loss:
            center_loss.zero_grad()

        multi_mask_real, multi_mask_imag, enc_attn_list = model(
            src, src_len, tgt, tgt_len,
            dict_spk2idx)  # outputs are the hidden_outputs before the final classification layer; usable directly
        multi_mask_real = multi_mask_real.transpose(0, 1)
        multi_mask_imag = multi_mask_imag.transpose(0, 1)
        src_real = src[:, :, :, 0].transpose(0, 1)  # bs,T,F
        src_imag = src[:, :, :, 1].transpose(0, 1)  # bs,T,F
        print('mask size for real/imag:',
              multi_mask_real.size())  # bs,topk,T,F, already squeezed
        print('mixture size for real/imag:', src_real.size())  # bs,T,F

        predicted_maps0_real = multi_mask_real[:, 0] * src_real - multi_mask_imag[:, 0] * src_imag  # bs,T,F
        predicted_maps0_imag = multi_mask_real[:, 0] * src_imag + multi_mask_imag[:, 0] * src_real  # bs,T,F
        predicted_maps1_real = multi_mask_real[:, 1] * src_real - multi_mask_imag[:, 1] * src_imag  # bs,T,F
        predicted_maps1_imag = multi_mask_real[:, 1] * src_imag + multi_mask_imag[:, 1] * src_real  # bs,T,F

        stft_matrix_spk0 = torch.cat((predicted_maps0_real.unsqueeze(-1),
                                      predicted_maps0_imag.unsqueeze(-1)),
                                     3).transpose(1, 2)  # bs,F,T,2
        stft_matrix_spk1 = torch.cat((predicted_maps1_real.unsqueeze(-1),
                                      predicted_maps1_imag.unsqueeze(-1)),
                                     3).transpose(1, 2)  # bs,F,T,2
        wav_spk0 = models.istft_irfft(stft_matrix_spk0,
                                      length=config.MAX_LEN,
                                      hop_length=config.FRAME_SHIFT,
                                      win_length=config.FRAME_LENGTH,
                                      window='hann')
        wav_spk1 = models.istft_irfft(stft_matrix_spk1,
                                      length=config.MAX_LEN,
                                      hop_length=config.FRAME_SHIFT,
                                      win_length=config.FRAME_LENGTH,
                                      window='hann')
        predict_wav = torch.cat((wav_spk0.unsqueeze(1), wav_spk1.unsqueeze(1)),
                                1)  # bs,topk,time_len
        if 1 and len(opt.gpus) > 1:
            ss_loss, pmt_list, max_snr_idx, *__ = model.module.separation_tas_loss(
                padded_mixture, predict_wav, padded_source, mixture_lengths)
        else:
            ss_loss, pmt_list, max_snr_idx, *__ = model.separation_tas_loss(
                padded_mixture, predict_wav, padded_source, mixture_lengths)

        best_pmt = [
            list(pmt_list[int(mm)].data.cpu().numpy()) for mm in max_snr_idx
        ]
        print('loss for SS,this batch:', ss_loss.cpu().item())
        print('best perms for this batch:', best_pmt)
        writer.add_scalars('scalar/loss', {'ss_loss': ss_loss.cpu().item()},
                           updates)

        loss = ss_loss
        loss.backward()

        total_loss_ss += ss_loss.cpu().item()
        lera.log({
            'ss_loss': ss_loss.cpu().item(),
        })

        if epoch > 20 and updates > 5 and updates % config.eval_interval in [
                0, 1, 2, 3, 4
        ]:
            utils.bss_eval_tas(config,
                               predict_wav,
                               train_data['multi_spk_fea_list'],
                               raw_tgt,
                               train_data,
                               dst=log_path + 'batch_output')
            sdr_aver_batch, snri_aver_batch = bss_test.cal(log_path +
                                                           'batch_output/')
            lera.log({'SDR sample': sdr_aver_batch})
            lera.log({'SI-SNRi sample': snri_aver_batch})
            writer.add_scalars('scalar/loss', {
                'SDR_sample': sdr_aver_batch,
                'SDRi_sample': snri_aver_batch
            }, updates)
            SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
            SDRi_SUM = np.append(SDRi_SUM, snri_aver_batch)
            print(('SDR_aver_now:', SDR_SUM.mean()))
            print(('SNRi_aver_now:', SDRi_SUM.mean()))

            # Heatmap here
            # n_layer tensors of shape (head*bs) x lq x dk
            '''
            import matplotlib.pyplot as plt
            ax = plt.gca()
            ax.invert_yaxis()

            raw_src=models.rank_feas(raw_tgt, train_data['multi_spk_fea_list'])
            att_idx=1
            att = enc_attn_list[-1].view(config.trans_n_head,config.batch_size,mix_speech_len,mix_speech_len).data.cpu().numpy()[:,att_idx]
            for head in range(config.trans_n_head):
                xx=att[head]
                plt.matshow(xx, cmap=plt.cm.hot, vmin=0,vmax=0.05)
                plt.colorbar()
                plt.savefig(log_path+'batch_output/'+'head_{}.png'.format(head))
            plt.matshow(raw_src[att_idx*2+0].transpose(0,1), cmap=plt.cm.hot, vmin=0,vmax=2)
            plt.colorbar()
            plt.savefig(log_path+'batch_output/'+'source0.png')
            plt.matshow(raw_src[att_idx*2+1].transpose(0,1), cmap=plt.cm.hot, vmin=0,vmax=2)
            plt.colorbar()
            plt.savefig(log_path+'batch_output/'+'source1.png')
            1/0
            '''

        total_loss += loss.cpu().item()
        optim.step()

        updates += 1
        if updates % 30 == 0:
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss this batch: %6.3f,ss loss: %6.6f\n"
                % (time.time() - start_time, epoch, updates, loss,
                   total_loss_ss / 30.0))
            total_loss_sgm, total_loss_ss = 0, 0

        # continue

        if 1 and updates % config.save_interval == 1:
            save_model(log_path + 'Transformer_PIT_2ch_{}.pt'.format(updates))

        if 0 and updates > 0 and updates % config.eval_interval == 3:  # better to run a few epochs before evaluating; otherwise the model has learned little and results are unreliable.
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss: %6.5f\n" %
                (time.time() - start_time, epoch, updates,
                 total_loss / config.eval_interval))
            print(('evaluating after %d updates...\r' % updates))
            eval(epoch, 'valid')  # batch_size becomes 1 during eval
            eval(epoch, 'test')  # batch_size becomes 1 during eval

            model.train()
            total_loss = 0
            start_time = 0
            report_total = 0
            report_correct = 0
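A minimal, self-contained sketch of the complex-mask resynthesis step used in the loop above. It assumes a modern PyTorch torch.istft in place of the repo's custom models.istft_irfft; the shapes, the hann window, and all parameter values are illustrative only.

import torch

def apply_complex_mask_and_istft(mask_real, mask_imag, mix_real, mix_imag,
                                 n_fft=512, hop=128, length=16000):
    # complex multiplication M * X; all inputs are (bs, T, F)
    est_real = mask_real * mix_real - mask_imag * mix_imag
    est_imag = mask_real * mix_imag + mask_imag * mix_real
    # (bs, T, F) -> (bs, F, T) complex spectrogram for torch.istft
    spec = torch.complex(est_real, est_imag).transpose(1, 2)
    window = torch.hann_window(n_fft)
    return torch.istft(spec, n_fft=n_fft, hop_length=hop,
                       win_length=n_fft, window=window, length=length)

# toy usage: bs=2, T=126 frames, F=257 bins (n_fft // 2 + 1)
bs, T, F = 2, 126, 257
wav = apply_complex_mask_and_istft(torch.rand(bs, T, F), torch.rand(bs, T, F),
                                   torch.randn(bs, T, F), torch.randn(bs, T, F))
print(wav.shape)  # torch.Size([2, 16000])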
Example No. 14
def eval(epoch):
    model.eval()
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    test_or_valid = 'valid'
    print 'Test or valid:', test_or_valid
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX,
                                 config.MAX_MIX)
    # for raw_src, src, src_len, raw_tgt, tgt, tgt_len in validloader:
    SDR_SUM = np.array([])
    batch_idx = 0
    global best_SDR, Var
    while True:
        # for ___ in range(2):
        print '-' * 30
        eval_data = eval_data_gen.next()
        if eval_data == False:
            break  # the generator for this epoch is exhausted; move on to the next epoch
        src = Variable(torch.from_numpy(eval_data['mix_feas']))

        raw_tgt = [
            sorted(spk.keys()) for spk in eval_data['multi_spk_fea_list']
        ]
        top_k = len(raw_tgt[0])
        # make sure the tensors below are LongTensors
        # tgt = Variable(torch.from_numpy(np.array([[0]+[dict_spk2idx[spk] for spk in spks]+[dict_spk2idx['<EOS>']] for spks in raw_tgt],dtype=np.int))).transpose(0,1)  # map speakers to indices, add BOS/EOS.
        tgt = Variable(torch.ones(
            top_k + 2, config.batch_size))  # a dummy tgt; the speaker names are irrelevant at test time.

        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            len(eval_data['multi_spk_fea_list'][0])).unsqueeze(0)
        feas_tgt = models.rank_feas(raw_tgt,
                                    eval_data['multi_spk_fea_list'])  # target spectrograms
        if config.WFM:
            tmp_size = feas_tgt.size()
            assert len(tmp_size) == 4
            feas_tgt_sum = torch.sum(feas_tgt, dim=1, keepdim=True)
            feas_tgt_sum_square = (feas_tgt_sum *
                                   feas_tgt_sum).expand(tmp_size)
            feas_tgt_square = feas_tgt * feas_tgt
            WFM_mask = feas_tgt_square / feas_tgt_sum_square
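            # Wiener-filter-like mask (WFM): each speaker's squared magnitude
            # divided by the squared *sum* of the target magnitudes (a mixture
            # proxy), computed per time-frequency bin.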

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
            if config.WFM:
                WFM_mask = WFM_mask.cuda()

        if config.buffer_size or config.buffer_shift:  # first convet to realtime batches
            assert src.size()[1] == 1
            left_padding = Variable(
                torch.zeros(config.buffer_size,
                            src.size()[1],
                            src.size()[-1]).cuda())
            src = torch.cat((left_padding, src), dim=0)

            split_idx = 0
            src_new = Variable(
                torch.zeros(config.buffer_size + config.buffer_shift,
                            mix_speech_len / config.buffer_shift + 1,
                            src.size()[-1]).cuda())
            batch_counter = 0
            while True:
                print 'split_idx at:', split_idx
                split_len = config.buffer_size + config.buffer_shift  # the len of every split
                if split_idx + split_len > src.size(
                )[0]:  # if pass the right length
                    print 'Need to add right padding with len:', (
                        split_idx + split_len) - src.size()[0]
                    right_padding = Variable(
                        torch.zeros((split_idx + split_len) - src.size()[0],
                                    src.size()[1],
                                    src.size()[-1]).cuda())
                    src = torch.cat((src, right_padding), dim=0)
                    src_split = src[split_idx:(split_idx + split_len)]
                    src_new[:, batch_counter] = src_split
                    break
                src_split = src[split_idx:(split_idx + split_len)]
                src_new[:, batch_counter] = src_split
                split_idx += config.buffer_shift
                batch_counter += 1
            assert batch_counter + 1 == src_new.size()[1]
            src_len[0] = config.buffer_shift + config.buffer_size
            src_len = src_len.expand(1, src_new.size()[1])

        try:
            if 1 and len(opt.gpus) > 1:
                # samples, alignment = model.module.sample(src, src_len)
                samples, alignment, hiddens, predicted_masks = model.module.beam_sample(
                    src_new,
                    src_len,
                    dict_spk2idx,
                    tgt,
                    beam_size=config.beam_size)
            else:
                samples, alignment, hiddens, predicted_masks = model.beam_sample(
                    src_new,
                    src_len,
                    dict_spk2idx,
                    tgt,
                    beam_size=config.beam_size)
                # samples, alignment, hiddens, predicted_masks = model.beam_sample(src, src_len, dict_spk2idx, tgt, beam_size=config.beam_size)
        except Exception, info:
            print '**************Error occurs here************:', info
            continue

        if config.top1:
            predicted_masks = torch.cat([predicted_masks, 1 - predicted_masks],
                                        1)

        if config.buffer_size and config.buffer_shift:  # then recover the whole maps
            # masks:[7,topk,buffer_size+buffer_shift,fre]
            masks_recover = Variable(
                torch.zeros(1, predicted_masks.size(1), mix_speech_len,
                            speech_fre).cuda())
            recover_idx = 0
            for batch_counter in range(predicted_masks.size(0)):
                if not batch_counter == predicted_masks.size(0) - 1:
                    masks_recover[:, :, recover_idx:recover_idx +
                                  config.buffer_shift] = predicted_masks[
                                      batch_counter, :,
                                      -1 * config.buffer_shift:]
                else:  # the last shift
                    assert mix_speech_len - recover_idx == config.buffer_shift - right_padding.size(
                        0)
                    masks_recover[:, :, recover_idx:] = predicted_masks[
                        batch_counter, :,
                        -1 * config.buffer_shift:(-1 * right_padding.size(0))]
                recover_idx += config.buffer_shift
            predicted_masks = masks_recover
            src = Variable(torch.from_numpy(eval_data['mix_feas'])).transpose(
                0, 1).cuda()

        # '''
        # expand the raw mixed-features to topk channel.
        src = src.transpose(0, 1)
        siz = src.size()
        assert len(siz) == 3
        topk = feas_tgt.size()[1]
        x_input_map_multi = torch.unsqueeze(src,
                                            1).expand(siz[0], topk, siz[1],
                                                      siz[2])
        if config.WFM:
            feas_tgt = x_input_map_multi.data * WFM_mask
        '''
        if 1 and len(opt.gpus) > 1:
            ss_loss = model.module.separation_loss(x_input_map_multi, predicted_masks, feas_tgt,Var)
        else:
            ss_loss = model.separation_loss(x_input_map_multi, predicted_masks, feas_tgt,None)
        print 'loss for ss,this batch:',ss_loss.data[0]
        lera.log({
            'ss_loss_'+test_or_valid: ss_loss.data[0],
        })

        del ss_loss,hiddens

        # ''' ''
        if batch_idx <= (500 / config.batch_size
                         ):  # only the first batches count toward the SDR
            # x_input_map_multi=x_input_map_multi[:,:,:config.buffer_shift]
            predicted_maps = predicted_masks * x_input_map_multi
            # predicted_maps=Variable(feas_tgt)
            utils.bss_eval(config,
                           predicted_maps,
                           eval_data['multi_spk_fea_list'],
                           raw_tgt,
                           eval_data,
                           dst='batch_outputwaddd')
            del predicted_maps, predicted_masks, x_input_map_multi
            SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_outputwaddd/'))
            print 'SDR_aver_now:', SDR_SUM.mean()
            lera.log({'SDR sample': SDR_SUM.mean()})
            # raw_input('Press any key to continue......')
        elif batch_idx == (500 / config.batch_size) + 1 and SDR_SUM.mean(
        ) > best_SDR:  #only record the best SDR once.
            print 'Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean())
            best_SDR = SDR_SUM.mean()
            # save_model(log_path+'checkpoint_bestSDR{}.pt'.format(best_SDR))

        # '''
        candidate += [
            convertToLabels(dict_idx2spk, s, dict_spk2idx['<EOS>'])
            for s in samples
        ]
        # source += raw_src
        reference += raw_tgt
        print 'samples:', samples
        print 'can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:],
                                        reference[-1 * config.batch_size:])
        alignments += [align for align in alignment]
        batch_idx += 1
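The buffered-streaming logic above left-pads the features, cuts them into windows of buffer_size + buffer_shift frames that advance by buffer_shift, and stitches the trailing buffer_shift frames of each window back together. A hedged numpy sketch of that idea (names mirror the config fields above; values are toy):

import numpy as np

def split_buffers(x, buffer_size, buffer_shift):
    # x: (T, F) -> (num_windows, buffer_size + buffer_shift, F), zero-padded
    x = np.concatenate([np.zeros((buffer_size, x.shape[1])), x], axis=0)
    win = buffer_size + buffer_shift
    windows, idx = [], 0
    while idx < x.shape[0]:
        chunk = x[idx:idx + win]
        if chunk.shape[0] < win:  # right-pad the final window
            pad = np.zeros((win - chunk.shape[0], x.shape[1]))
            windows.append(np.concatenate([chunk, pad], axis=0))
            break
        windows.append(chunk)
        idx += buffer_shift
    return np.stack(windows)

def recover(windows, total_len, buffer_shift):
    # keep only the trailing buffer_shift frames of every window
    out = np.concatenate([w[-buffer_shift:] for w in windows], axis=0)
    return out[:total_len]

x = np.arange(20, dtype=float).reshape(10, 2)  # T=10, F=2
w = split_buffers(x, buffer_size=4, buffer_shift=2)
assert np.allclose(recover(w, 10, 2), x)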
Example No. 15
def eval(epoch, test_or_valid='train'):
    # config.batch_size=1
    global updates, model
    model.eval()
    # print '\n\nPlease set the batch_size in config to 1 for testing!!!'
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    print(('Test or valid:', test_or_valid))
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX,
                                 config.MAX_MIX)
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    batch_idx = 0
    global best_SDR, Var
    # for iii in range(2000):
    while True:
        print(('-' * 30))
        eval_data = next(eval_data_gen)
        if eval_data == False:
            print(('SDR_aver_eval_epoch:', SDR_SUM.mean()))
            print(('SDRi_aver_eval_epoch:', SDRi_SUM.mean()))
            break  # the generator for this epoch is exhausted; move on to the next epoch

        src = Variable(torch.from_numpy(eval_data['mix_feas']))
        # raw_tgt = [spk.keys() for spk in eval_data['multi_spk_fea_list']]
        # raw_tgt = [sorted(spk.keys()) for spk in eval_data['multi_spk_fea_list']]
        raw_tgt = eval_data['batch_order']
        feas_tgt = models.rank_feas(
            raw_tgt,
            eval_data['multi_spk_wav_list'])  # target signals, bs*topk, time_len

        padded_mixture, mixture_lengths, padded_source = eval_data['tas_zip']
        padded_mixture = torch.from_numpy(padded_mixture).float()
        mixture_lengths = torch.from_numpy(mixture_lengths)
        padded_source = torch.from_numpy(padded_source).float()

        padded_mixture = padded_mixture.cuda().transpose(0, 1)
        mixture_lengths = mixture_lengths.cuda()
        padded_source = padded_source.cuda()

        # make sure the tensors below are LongTensors
        tgt_max_len = config.MAX_MIX + 2  # with bos and eos.
        # tgt = Variable(torch.from_numpy(np.array(
        #     [[0] + [dict_spk2idx[spk] for spk in spks] + (tgt_max_len - len(spks) - 1) * [dict_spk2idx['<EOS>']] for
        #      spks in raw_tgt], dtype=np.int))).transpose(0, 1)  # map speakers to indices, add BOS/EOS padding.
        tgt = Variable(
            torch.from_numpy(
                np.array([[0, 1, 2, 102] for __ in range(config.batch_size)],
                         dtype=np.int))).transpose(0, 1)  # map to indices, add BOS/EOS.
        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor([
                len(one_spk) for one_spk in eval_data['multi_spk_fea_list']
            ])).unsqueeze(0)

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()

        model.zero_grad()
        if config.use_center_loss:
            center_loss.zero_grad()

        multi_mask, enc_attn_list = model(
            src, src_len, tgt, tgt_len,
            dict_spk2idx)  # the outputs are the hidden_outputs before the final classification layer; they can be used directly
        multi_mask = multi_mask.transpose(0, 1)
        print('mask size:', multi_mask.size())  # bs,topk,T,F

        predicted_maps0_spectrogram = multi_mask[:, 0] * src.transpose(
            0, 1)  #bs,T,F
        predicted_maps1_spectrogram = multi_mask[:, 1] * src.transpose(
            0, 1)  #bs,T,F

        if True:  # Analyze the optimal assignments
            predicted_spectrogram = torch.cat([
                predicted_maps0_spectrogram.unsqueeze(1),
                predicted_maps1_spectrogram.unsqueeze(1)
            ], 1)
            feas_tgt_tmp = models.rank_feas(
                raw_tgt,
                eval_data['multi_spk_fea_list'])  # target spectrograms, bs*topk, len, fre
            src = src.transpose(0, 1)
            siz = src.size()  # bs,T,F
            assert len(siz) == 3
            # topk_max = config.MAX_MIX  # maximum possible topk
            topk_max = 2  # maximum possible topk
            x_input_map_multi = torch.unsqueeze(src, 1).expand(
                siz[0], topk_max, siz[1], siz[2]).contiguous()  # bs,topk,T,F
            feas_tgt_tmp = feas_tgt_tmp.view(siz[0], -1, siz[1], siz[2])

            angle_tgt = models.rank_feas(
                raw_tgt, eval_data['multi_spk_angle_list']).view(
                    siz[0], -1, siz[1], siz[2])  # bs,topk,T,F
            angle_mix = Variable(
                torch.from_numpy(np.array(
                    eval_data['mix_angle']))).unsqueeze(1).expand(
                        siz[0], topk_max, siz[1], siz[2]).contiguous()
            ang = np.cos(angle_mix - angle_tgt)
            ang = np.clip(ang, 0, None)
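            # Phase-sensitive weighting: max(cos(theta_mix - theta_tgt), 0)
            # scales the target magnitudes, giving the phase-sensitive
            # approximation target |S| * cos(phase difference).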

            feas_tgt_tmp = feas_tgt_tmp.view(siz[0], -1, siz[1],
                                             siz[2]) * ang  # bs,topk,T,F
            feas_tgt_tmp = feas_tgt_tmp.cuda()
            del x_input_map_multi
            src = src.transpose(0, 1)
            MSE_func = nn.MSELoss().cuda()
            best_perms_this_batch = []
            for bs_idx in range(siz[0]):
                best_perms_this_sample = []
                for tt in range(siz[1]):  # for each frame
                    tar = feas_tgt_tmp[bs_idx, :, tt]  #topk,F
                    est = predicted_spectrogram[bs_idx, :, tt]  #topk,F
                    best_loss_mse_this_batch = -1
                    for idx, per in enumerate([[0, 1], [1, 0]]):
                        if idx == 0:
                            best_loss_mse_this_batch = MSE_func(est[per], tar)
                            perm_this_frame = per
                            predicted_spectrogram[bs_idx, :, tt] = est[per]
                        else:
                            loss = MSE_func(est[per], tar)
                            if loss <= best_loss_mse_this_batch:
                                best_loss_mse_this_batch = loss
                                perm_this_frame = per
                                predicted_spectrogram[bs_idx, :, tt] = est[per]

                    best_perms_this_sample.append(perm_this_frame)
                best_perms_this_batch.append(best_perms_this_sample)
            print(
                'different assignment ratio:',
                np.mean(np.min(
                    np.array(best_perms_this_batch).sum(1) / 751, 1)))  # 751 is presumably the frame count T
            # predicted_maps0_spectrogram = predicted_spectrogram[:,0]
            # predicted_maps1_spectrogram = predicted_spectrogram[:,1]

        _mix_spec = eval_data['mix_phase']  # bs,T,F,2
        angle_mix = np.angle(_mix_spec)
        predicted_maps0_real = predicted_maps0_spectrogram * torch.from_numpy(
            np.cos(angle_mix)).cuda()  # e^(ix) = cos(x) + i*sin(x)
        predicted_maps0_imag = predicted_maps0_spectrogram * torch.from_numpy(
            np.sin(angle_mix)).cuda()  # e^(ix) = cos(x) + i*sin(x)
        predicted_maps1_real = predicted_maps1_spectrogram * torch.from_numpy(
            np.cos(angle_mix)).cuda()  # e^(ix) = cos(x) + i*sin(x)
        predicted_maps1_imag = predicted_maps1_spectrogram * torch.from_numpy(
            np.sin(angle_mix)).cuda()  # e^(ix) = cos(x) + i*sin(x)
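        # Resynthesis with the mixture phase: S_hat = (M * |X|) * e^{j*theta_mix},
        # split above into a real part (cos) and an imaginary part (sin).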

        stft_matrix_spk0 = torch.cat((predicted_maps0_real.unsqueeze(-1),
                                      predicted_maps0_imag.unsqueeze(-1)),
                                     3).transpose(1, 2)  # bs,F,T,2
        stft_matrix_spk1 = torch.cat((predicted_maps1_real.unsqueeze(-1),
                                      predicted_maps1_imag.unsqueeze(-1)),
                                     3).transpose(1, 2)  # bs,F,T,2
        wav_spk0 = models.istft_irfft(stft_matrix_spk0,
                                      length=config.MAX_LEN,
                                      hop_length=config.FRAME_SHIFT,
                                      win_length=config.FRAME_LENGTH,
                                      window='hann')
        wav_spk1 = models.istft_irfft(stft_matrix_spk1,
                                      length=config.MAX_LEN,
                                      hop_length=config.FRAME_SHIFT,
                                      win_length=config.FRAME_LENGTH,
                                      window='hann')
        predict_wav = torch.cat((wav_spk0.unsqueeze(1), wav_spk1.unsqueeze(1)),
                                1)  # bs,topk,time_len
        if 1 and len(opt.gpus) > 1:
            ss_loss, pmt_list, max_snr_idx, *__ = model.module.separation_tas_loss(
                padded_mixture, predict_wav, padded_source, mixture_lengths)
        else:
            ss_loss, pmt_list, max_snr_idx, *__ = model.separation_tas_loss(
                padded_mixture, predict_wav, padded_source, mixture_lengths)

        best_pmt = [
            list(pmt_list[int(mm)].data.cpu().numpy()) for mm in max_snr_idx
        ]
        print('loss for SS,this batch:', ss_loss.cpu().item())
        print('best perms for this batch:', best_pmt)
        writer.add_scalars('scalar/loss', {'ss_loss': ss_loss.cpu().item()},
                           updates)
        lera.log({
            'ss_loss_' + test_or_valid: ss_loss.cpu().item(),
        })
        writer.add_scalars('scalar/loss',
                           {'ss_loss_' + test_or_valid: ss_loss.cpu().item()},
                           updates + batch_idx)
        del ss_loss
        # if batch_idx>10:
        #     break

        if False:  # this part tests the checkpoints sequentially.
            batch_idx += 1
            if batch_idx % 100 == 0:
                updates = updates + 1000
                opt.restore = '/data1/shijing_data/2020-02-14-04:58:17/Transformer_PIT_{}.pt'.format(
                    updates)
                print('loading checkpoint...\n', opt.restore)
                checkpoints = torch.load(opt.restore)
                model.module.load_state_dict(checkpoints['model'])
                break
            continue
        # '''''
        if 1 and batch_idx <= (500 / config.batch_size):
            utils.bss_eval_tas(config,
                               predict_wav,
                               eval_data['multi_spk_fea_list'],
                               raw_tgt,
                               eval_data,
                               dst=log_path + 'batch_output')
            sdr_aver_batch, snri_aver_batch = bss_test.cal(log_path +
                                                           'batch_output/')
            lera.log({'SDR sample': sdr_aver_batch})
            lera.log({'SI-SNRi sample': snri_aver_batch})
            writer.add_scalars('scalar/loss', {
                'SDR_sample': sdr_aver_batch,
                'SDRi_sample': snri_aver_batch
            }, updates)
            SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
            SDRi_SUM = np.append(SDRi_SUM, snri_aver_batch)
            print(('SDR_aver_now:', SDR_SUM.mean()))
            print(('SNRi_aver_now:', SDRi_SUM.mean()))

        batch_idx += 1
        if batch_idx > 100:
            break
        result = utils.eval_metrics(reference, candidate, dict_spk2idx,
                                    log_path)
        print((
            'hamming_loss: %.8f | micro_f1: %.4f |recall: %.4f | precision: %.4f'
            % (
                result['hamming_loss'],
                result['micro_f1'],
                result['micro_recall'],
                result['micro_precision'],
            )))
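The frame-level assignment analysis above searches, per frame, over the two possible speaker permutations for the MSE-optimal one. A compact sketch of that idea with toy shapes (not the repo's exact code):

import torch

def framewise_best_perm(est, tar):
    # est, tar: (topk=2, T, F); returns the best permutation per frame
    perms = [[0, 1], [1, 0]]
    best = []
    for t in range(est.size(1)):
        losses = [torch.mean((est[p, t] - tar[:, t]) ** 2) for p in perms]
        best.append(perms[int(torch.argmin(torch.stack(losses)))])
    return best

est, tar = torch.rand(2, 5, 4), torch.rand(2, 5, 4)
print(framewise_best_perm(est, tar))  # e.g. [[0, 1], [1, 0], ...]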
Example No. 16
def main():
    print('go to model')
    print '*' * 80

    spk_global_gen = prepare_data(mode='global',
                                  train_or_test='train')  # a fake data generator, handy for prototyping the model first
    global_para = spk_global_gen.next()
    print global_para
    spk_all_list, dict_spk2idx, dict_idx2spk, mix_speech_len, speech_fre, total_frames, spk_num_total = global_para
    del spk_global_gen
    num_labels = len(spk_all_list)

    # data_generator=prepare_data('once','train')
    # data_generator=prepare_data_fake(train_or_test='train',num_labels=num_labels) #写一个假的数据生成,可以用来写模型先

    # the order here is: mix_speechs.shape, mix_feas.shape, aim_fea.shape, aim_spkid.shape, query.shape
    # an example: (5, 17040) (5, 134, 129) (5, 134, 129) (5,) (5, 32, 400, 300, 3)
    # datasize=prepare_datasize(data_generator)
    # mix_speech_len,speech_fre,total_frames,spk_num_total,video_size=datasize
    print 'Begin to build the main model for the Multi-Modal Cocktail Problem.'
    # data=data_generator.next()

    # This part is to build the 3D mix speech embedding maps.
    mix_hidden_layer_3d = MIX_SPEECH(speech_fre, mix_speech_len).cuda()
    mix_speech_classifier = MIX_SPEECH_classifier(speech_fre, mix_speech_len,
                                                  num_labels).cuda()
    mix_speech_multiEmbedding = SPEECH_EMBEDDING(
        num_labels, config.EMBEDDING_SIZE,
        spk_num_total + config.UNK_SPK_SUPP).cuda()
    print mix_hidden_layer_3d
    print mix_speech_classifier
    # mix_speech_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(data[1])).cuda())

    # mix_speech_output=mix_speech_classifier(Variable(torch.from_numpy(data[1])).cuda())
    # trick: with alpha=0 this selects the top_k entries; with a very large top_k it selects the entries greater than alpha
    # top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=config.ALPHA,top_k=config.MAX_MIX)
    # top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=config.ALPHA,top_k=3)
    # print top_k_mask_mixspeech
    # mix_speech_multiEmbs=mix_speech_multiEmbedding(top_k_mask_mixspeech) # bs*num_labels(max mixed speakers) x embedding size
    # mix_speech_multiEmbs=mix_speech_multiEmbedding(Variable(torch.from_numpy(top_k_mask_mixspeech),requires_grad=False).cuda()) # bs*num_labels(max mixed speakers) x embedding size

    # need the attention between mix_speech_hidden[bs,len,fre,emb] and mix_mulEmbedding[bs,num_labels,EMB]:
    # expand the former to bs*num_labels,..., likewise the latter, run them through the ATT function, then reshape back
    # mix_speech_hidden_5d=mix_speech_hidden.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
    # mix_speech_hidden_5d=mix_speech_hidden_5d.expand(config.BATCH_SIZE,num_labels,mix_speech_len,speech_fre,config.EMBEDDING_SIZE).contiguous()
    # mix_speech_hidden_5d=mix_speech_hidden_5d.view(-1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
    # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
    # att_multi_speech=att_speech_layer(mix_speech_hidden_5d,mix_speech_multiEmbs.view(-1,config.EMBEDDING_SIZE))
    # print att_multi_speech.size()
    # att_multi_speech=att_multi_speech.view(config.BATCH_SIZE,num_labels,mix_speech_len,speech_fre,-1)
    # print att_multi_speech.size()

    # This part is to conduct the video inputs.
    # query_video_layer=VIDEO_QUERY(total_frames,config.VideoSize,spk_num_total).cuda()
    query_video_layer = None
    # print query_video_layer
    # query_video_output,xx=query_video_layer(Variable(torch.from_numpy(data[4])))

    # This part is to conduct the memory.
    # hidden_size=(config.HIDDEN_UNITS)
    hidden_size = (config.EMBEDDING_SIZE)
    memory = MEMORY(spk_num_total + config.UNK_SPK_SUPP, hidden_size)
    memory.register_spklist(spk_all_list)  # register spk_list into the empty memory

    # Memory function test.
    print 'memory all spkid:', memory.get_all_spkid()
    # print memory.get_image_num('Unknown_id')
    # print memory.get_video_vector('Unknown_id')
    # print memory.add_video('Unknown_id',Variable(torch.ones(300)))

    # This part is to test the ATTENTION methond from query(~) to mix_speech
    # x=torch.arange(0,24).view(2,3,4)
    # y=torch.ones([2,4])
    att_layer = ATTENTION(config.EMBEDDING_SIZE, 'align').cuda()
    att_speech_layer = ATTENTION(config.EMBEDDING_SIZE, 'align').cuda()
    # att=ATTENTION(4,'align')
    # mask=att(x,y)#bs*max_len

    # del data_generator
    # del data

    optimizer = torch.optim.Adam(
        [
            {
                'params': mix_hidden_layer_3d.parameters()
            },
            {
                'params': mix_speech_multiEmbedding.parameters()
            },
            {
                'params': mix_speech_classifier.parameters()
            },
            # {'params':query_video_layer.lstm_layer.parameters()},
            # {'params':query_video_layer.dense.parameters()},
            # {'params':query_video_layer.Linear.parameters()},
            {
                'params': att_layer.parameters()
            },
            {
                'params': att_speech_layer.parameters()
            },
            # ], lr=0.02,momentum=0.9)
        ],
        lr=0.0002)
    if 0 and config.Load_param:
        # query_video_layer.load_state_dict(torch.load('param_video_layer_19'))
        mix_speech_classifier.load_state_dict(
            torch.load('params/param_speech_multilabel_epoch249'))
        mix_hidden_layer_3d.load_state_dict(
            torch.load('params/param_mix_speech_hidden3d_220'))
        mix_speech_multiEmbedding.load_state_dict(
            torch.load('params/param_mix_speech_emblayer_220'))
        att_speech_layer.load_state_dict(
            torch.load('params/param_mix_speech_attlayer_220'))
    loss_func = torch.nn.MSELoss()  # the target label is NOT one-hot encoded
    loss_multi_func = torch.nn.MSELoss(
    )  # the target label is NOT one-hot encoded
    # loss_multi_func = torch.nn.L1Loss()  # the target label is NOT one-hot encoded
    loss_query_class = torch.nn.CrossEntropyLoss()

    print '''Begin to calculate.'''
    for epoch_idx in range(config.MAX_EPOCH):
        if epoch_idx > 0:
            print 'SDR_SUM (len:{}) for epoch {}: {}'.format(
                SDR_SUM.shape, epoch_idx - 1, SDR_SUM.mean())
        SDR_SUM = np.array([])
        # print_memory_state(memory.memory)
        for batch_idx in range(config.EPOCH_SIZE):
            print '*' * 40, epoch_idx, batch_idx, '*' * 40
            train_data_gen = prepare_data('once', 'train')
            # train_data_gen=prepare_data('once','test')
            train_data = train_data_gen.next()
            '''3D (len, fre, emb) representation layer for the mixed speech'''
            mix_speech_hidden = mix_hidden_layer_3d(
                Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # video part temporarily disabled: the video data for s2/s3/s4 is incomplete for now
            '''Speech self-separation part'''
            mix_speech_output = mix_speech_classifier(
                Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # get the ground-truth speaker names and vectors from the data
            y_spk_list = [
                one.keys() for one in train_data['multi_spk_fea_list']
            ]
            y_spk_gtruth, y_map_gtruth = multi_label_vector(
                y_spk_list, dict_spk2idx)
            # optionally use the ground-truth separation results as supervision during training
            if config.Ground_truth:
                mix_speech_output = Variable(
                    torch.from_numpy(y_map_gtruth)).cuda()
                if test_all_outputchannel:  # set the input mask to all ones to test outputting every channel
                    mix_speech_output = Variable(
                        torch.ones(
                            config.BATCH_SIZE,
                            num_labels,
                        ))
                    y_map_gtruth = np.ones([config.BATCH_SIZE, num_labels])

            top_k_mask_mixspeech = top_k_mask(mix_speech_output,
                                              alpha=0.5,
                                              top_k=num_labels)  # a torch.FloatTensor
            mix_speech_multiEmbs = mix_speech_multiEmbedding(
                top_k_mask_mixspeech)  # bs*num_labels (max mixed speakers) x embedding size

            # need the attention between mix_speech_hidden[bs,len,fre,emb] and mix_mulEmbedding[bs,num_labels,EMB]:
            # expand the former to bs*num_labels,..., likewise the latter, run them through the ATT function, then reshape back
            mix_speech_hidden_5d = mix_speech_hidden.view(
                config.BATCH_SIZE, 1, mix_speech_len, speech_fre,
                config.EMBEDDING_SIZE)
            mix_speech_hidden_5d = mix_speech_hidden_5d.expand(
                config.BATCH_SIZE, num_labels, mix_speech_len, speech_fre,
                config.EMBEDDING_SIZE).contiguous()
            mix_speech_hidden_5d_last = mix_speech_hidden_5d.view(
                -1, mix_speech_len, speech_fre, config.EMBEDDING_SIZE)
            # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
            att_speech_layer = ATTENTION(config.EMBEDDING_SIZE, 'dot').cuda()
            att_multi_speech = att_speech_layer(
                mix_speech_hidden_5d_last,
                mix_speech_multiEmbs.view(-1, config.EMBEDDING_SIZE))
            # print att_multi_speech.size()
            att_multi_speech = att_multi_speech.view(
                config.BATCH_SIZE, num_labels, mix_speech_len,
                speech_fre)  # bs,num_labels,len,fre
            # print att_multi_speech.size()
            multi_mask = att_multi_speech
            top_k_mask_mixspeech_multi = top_k_mask_mixspeech.view(
                config.BATCH_SIZE, num_labels, 1,
                1).expand(config.BATCH_SIZE, num_labels, mix_speech_len,
                          speech_fre)
            multi_mask = multi_mask * Variable(
                top_k_mask_mixspeech_multi).cuda()

            x_input_map = Variable(torch.from_numpy(
                train_data['mix_feas'])).cuda()
            # print x_input_map.size()
            x_input_map_multi = x_input_map.view(
                config.BATCH_SIZE, 1, mix_speech_len,
                speech_fre).expand(config.BATCH_SIZE, num_labels,
                                   mix_speech_len, speech_fre)
            predict_multi_map = multi_mask * x_input_map_multi
            if batch_idx % 100 == 0:
                print multi_mask
            # print predict_multi_map

            y_multi_map = np.zeros(
                [config.BATCH_SIZE, num_labels, mix_speech_len, speech_fre],
                dtype=np.float32)
            batch_spk_multi_dict = train_data['multi_spk_fea_list']
            for idx, sample in enumerate(batch_spk_multi_dict):
                for spk in sample.keys():
                    y_multi_map[idx, dict_spk2idx[spk]] = sample[spk]
            y_multi_map = Variable(torch.from_numpy(y_multi_map)).cuda()

            loss_multi_speech = loss_multi_func(predict_multi_map, y_multi_map)

            # loss term encouraging the per-channel masks to sum to 1; it should encourage more separation between channels
            y_sum_map = Variable(
                torch.ones(config.BATCH_SIZE, mix_speech_len,
                           speech_fre)).cuda()
            predict_sum_map = torch.sum(predict_multi_map, 1)
            loss_multi_sum_speech = loss_multi_func(predict_sum_map, y_sum_map)
            loss_multi_speech = loss_multi_speech  # TODO: compare against the sum-to-1 term later; plain MSE already works quite well.
            print 'loss 1, loss_sum: ', loss_multi_speech, loss_multi_sum_speech
            # loss_multi_speech=loss_multi_speech+0.5*loss_multi_sum_speech

            if 1 or batch_idx == config.EPOCH_SIZE - 1:
                bss_eval(predict_multi_map, y_multi_map, y_map_gtruth,
                         dict_idx2spk, train_data)
                SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output/', 2))

            print 'training multi-abs norm this batch:', torch.abs(
                y_multi_map - predict_multi_map).norm().data.cpu().numpy()
            print 'loss:', loss_multi_speech.data.cpu().numpy()
            optimizer.zero_grad()  # clear gradients for next train
            loss_multi_speech.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients

            if 1 and epoch_idx > 20 and epoch_idx % 10 == 0 and batch_idx == config.EPOCH_SIZE - 1:
                torch.save(
                    mix_speech_multiEmbedding.state_dict(),
                    'params/param_mix_{}_emblayer_{}'.format(
                        config.DATASET, epoch_idx))
                torch.save(
                    mix_hidden_layer_3d.state_dict(),
                    'params/param_mix_{}_hidden3d_{}'.format(
                        config.DATASET, epoch_idx))
                torch.save(
                    att_speech_layer.state_dict(),
                    'params/param_mix_{}_attlayer_{}'.format(
                        config.DATASET, epoch_idx))

            # print 'Parameter history:'
            # for pa_gen in [{'params':mix_hidden_layer_3d.parameters()},
            #                                  {'params':mix_speech_multiEmbedding.parameters()},
            #                                  {'params':mix_hidden_layer_3d.parameters()},
            #                                  {'params':att_speech_layer.parameters()},
            #                                  {'params':att_layer.parameters()},
            #                                  {'params':mix_speech_classifier.parameters()},
            #                                  ]:
            #     print pa_gen['params'].next().data.cpu().numpy()[0]

            # continue
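            # NOTE: the `1 / 0` below is a deliberate ZeroDivisionError that
            # stops the loop after the speech-only branch; the video-query
            # branch after it is effectively dead code.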
            1 / 0
            '''Video-stimulus separation part'''
            # try:
            #     query_video_output,query_video_hidden=query_video_layer(Variable(torch.from_numpy(train_data[4])).cuda())
            # except RuntimeError:
            #     print 'RuntimeError here.'+'#'*30
            #     continue

            query_video_output, query_video_hidden = query_video_layer(
                Variable(torch.from_numpy(train_data[4])).cuda())
            if config.Comm_with_Memory:
                # TODO: re-check the query update here; it should be refactored into a function, it is a bit ugly for now.
                aim_idx_FromVideoQuery = torch.max(query_video_output,
                                                   dim=1)[1]  # argmax index
                aim_spk_batch = [
                    dict_idx2spk[int(idx.data.cpu().numpy())]
                    for idx in aim_idx_FromVideoQuery
                ]
                print 'Query class result:', aim_spk_batch, 'p:', query_video_output.data.cpu(
                ).numpy()

                for idx, aim_spk in enumerate(aim_spk_batch):
                    batch_vector = torch.stack(
                        [memory.get_video_vector(aim_spk)])
                    memory.add_video(aim_spk, query_video_hidden[idx])
                query_video_hidden = query_video_hidden + Variable(
                    batch_vector)
                query_video_hidden = query_video_hidden / torch.sum(
                    query_video_hidden * query_video_hidden, 0)
                y_class = Variable(torch.from_numpy(
                    np.array([
                        dict_spk2idx[spk] for spk in train_data['aim_spkname']
                    ])),
                                   requires_grad=False).cuda()
                print y_class
                loss_video_class = loss_query_class(query_video_output,
                                                    y_class)

            mask = att_layer(mix_speech_hidden,
                             query_video_hidden)  #bs*max_len*fre

            predict_map = mask * Variable(
                torch.from_numpy(train_data['mix_feas'])).cuda()
            y_map = Variable(torch.from_numpy(train_data['aim_fea'])).cuda()
            print 'training abs norm this batch:', torch.abs(
                y_map - predict_map).norm().data.cpu().numpy()
            loss_all = loss_func(predict_map, y_map)
            if 0 and config.Save_param:
                torch.save(query_video_layer.state_dict(),
                           'param_video_layer_19_forS1S5')

            if 0 and epoch_idx < 20:
                loss = loss_video_class
                if epoch_idx % 1 == 0 and batch_idx == config.EPOCH_SIZE - 1:
                    torch.save(query_video_layer.state_dict(),
                               'param_video_layer_19_forS1S5')
            else:
                # loss=loss_all+0.1*loss_video_class
                loss = loss_all
            optimizer.zero_grad()  # clear gradients for next train
            loss.backward(
                retain_graph=True)  # backpropagation, compute gradients
            optimizer.step()  # apply gradients
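top_k_mask is used above but never defined in this listing. Going by the commented hint ("with alpha=0 it selects the top_k; with a very large top_k it selects the entries greater than alpha"), a plausible minimal version might look like this sketch; it is an assumption, not the repo's implementation:

import torch

def top_k_mask(scores, alpha, top_k):
    # scores: (bs, num_labels) activations; returns a 0/1 FloatTensor
    above = (scores > alpha).float()
    _, topk_idx = scores.topk(top_k, dim=1)
    keep = torch.zeros_like(scores).scatter_(1, topk_idx, 1.0)
    return above * keep  # above the threshold AND within the top-k

print(top_k_mask(torch.rand(2, 5), alpha=0.5, top_k=3))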
Example No. 17
def train(epoch):
    global e, updates, total_loss, start_time, report_total, total_loss_sgm, total_loss_ss
    e = epoch
    model.train()
    SDR_SUM = np.array([])

    if config.schedule:
        scheduler.step()
        print("Decaying learning rate to %g" % scheduler.get_lr()[0])
        if config.is_dis:
            scheduler_dis.step()
        lera.log({
            'lr': scheduler.get_lr()[0],
        })

    if opt.model == 'gated':
        model.current_epoch = epoch

    if config.MLMSE:
        global Var

    train_data_gen = prepare_data('once', 'train')
    # for raw_src, src, src_len, raw_tgt, tgt, tgt_len in trainloader:
    while True:
        try:
            train_data = train_data_gen.next()
            if train_data == False:
                print 'SDR_aver_epoch:', SDR_SUM.mean()
                break  # the generator for this epoch is exhausted; move on to the next epoch

            src = Variable(torch.from_numpy(train_data['mix_feas']))
            # raw_tgt = [spk.keys() for spk in train_data['multi_spk_fea_list']]
            raw_tgt = [
                sorted(spk.keys()) for spk in train_data['multi_spk_fea_list']
            ]
            feas_tgt = models.rank_feas(
                raw_tgt,
                train_data['multi_spk_fea_list'])  # target spectrograms, aim_size,len,fre

            # make sure the tensors below are LongTensors
            tgt_max_len = config.MAX_MIX + 2  # with bos and eos.
            tgt = Variable(
                torch.from_numpy(
                    np.array([[0] + [dict_spk2idx[spk] for spk in spks] +
                              (tgt_max_len - len(spks) - 1) *
                              [dict_spk2idx['<EOS>']] for spks in raw_tgt],
                             dtype=np.int))).transpose(0, 1)  # map speakers to indices, prepend BOS and pad with EOS.
            src_len = Variable(
                torch.LongTensor(config.batch_size).zero_() +
                mix_speech_len).unsqueeze(0)
            tgt_len = Variable(
                torch.LongTensor([
                    len(one_spk)
                    for one_spk in train_data['multi_spk_fea_list']
                ])).unsqueeze(0)
            if use_cuda:
                src = src.cuda().transpose(0, 1)
                tgt = tgt.cuda()
                src_len = src_len.cuda()
                tgt_len = tgt_len.cuda()
                feas_tgt = feas_tgt.cuda()

            model.zero_grad()
            # optim.optimizer.zero_grad()

            # aim_list holds the flat indices of the positions with real speakers
            aim_list = (tgt[1:-1].transpose(0, 1).contiguous().view(-1) !=
                        dict_spk2idx['<EOS>']).nonzero().squeeze()
            aim_list = aim_list.data.cpu().numpy()

            outputs, targets, multi_mask = model(
                src, src_len, tgt, tgt_len,
                dict_spk2idx)  # the outputs are the hidden_outputs before the final classification layer; they can be used directly
            print 'mask size:', multi_mask.size()

            if 1 and len(opt.gpus) > 1:
                sgm_loss, num_total, num_correct = model.module.compute_loss(
                    outputs, targets, opt.memory)
            else:
                sgm_loss, num_total, num_correct = model.compute_loss(
                    outputs, targets, opt.memory)
            print 'loss for SGM,this batch:', sgm_loss.data[0] / num_total

            src = src.transpose(0, 1)
            # expand the raw mixed-features to topk_max channel.
            siz = src.size()
            assert len(siz) == 3
            topk_max = config.MAX_MIX  # maximum possible topk
            x_input_map_multi = torch.unsqueeze(src, 1).expand(
                siz[0], topk_max, siz[1],
                siz[2]).contiguous().view(-1, siz[1], siz[2])
            x_input_map_multi = x_input_map_multi[aim_list]
            multi_mask = multi_mask.transpose(0, 1)

            if 1 and len(opt.gpus) > 1:
                if config.MLMSE:
                    Var = model.module.update_var(x_input_map_multi,
                                                  multi_mask, feas_tgt)
                    lera.log_image(u'Var weight',
                                   Var.data.cpu().numpy().reshape(
                                       config.speech_fre, config.speech_fre,
                                       1).repeat(3, 2),
                                   clip=(-1, 1))
                    ss_loss = model.module.separation_loss(
                        x_input_map_multi, multi_mask, feas_tgt, Var)
                else:
                    ss_loss = model.module.separation_loss(
                        x_input_map_multi, multi_mask, feas_tgt)
            else:
                ss_loss = model.separation_loss(x_input_map_multi, multi_mask,
                                                feas_tgt)

            loss = sgm_loss + 5 * ss_loss
            # dis_loss model
            if config.is_dis:
                dis_loss = models.loss.dis_loss(config, topk_max, model_dis,
                                                x_input_map_multi, multi_mask,
                                                feas_tgt, func_dis)
                loss = loss + dis_loss
                # print 'dis_para',model_dis.parameters().next()[0]
                # print 'ss_para',model.parameters().next()[0]

            loss.backward()
            # print 'totallllllllllll loss:',loss
            total_loss_sgm += sgm_loss.data[0]
            total_loss_ss += ss_loss.data[0]
            lera.log({
                'sgm_loss': sgm_loss.data[0],
                'ss_loss': ss_loss.data[0],
                'loss:': loss.data[0],
            })

            if (updates % config.eval_interval) in [
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
            ]:
                predicted_maps = multi_mask * x_input_map_multi
                # predicted_maps=Variable(feas_tgt)
                utils.bss_eval(config,
                               predicted_maps,
                               train_data['multi_spk_fea_list'],
                               raw_tgt,
                               train_data,
                               dst='batch_outputjaa')
                del predicted_maps, multi_mask, x_input_map_multi
                # raw_input('wait to continue......')
                sdr_aver_batch = bss_test.cal('batch_outputjaa/')
                lera.log({'SDR sample': sdr_aver_batch})
                SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
                print 'SDR_aver_now:', SDR_SUM.mean()

            total_loss += loss.data[0]
            report_total += num_total
            optim.step()
            if config.is_dis:
                optim_dis.step()

            updates += 1
            if updates % 30 == 0:
                logging(
                    "time: %6.3f, epoch: %3d, updates: %8d, train loss this batch: %6.3f,sgm loss: %6.6f,ss loss: %6.6f\n"
                    % (time.time() - start_time, epoch, updates, loss /
                       num_total, total_loss_sgm / 30.0, total_loss_ss / 30.0))
                total_loss_sgm, total_loss_ss = 0, 0

            # continue

            if 0 or updates % config.eval_interval == 0 and epoch > 1:
                logging(
                    "time: %6.3f, epoch: %3d, updates: %8d, train loss: %6.5f\n"
                    % (time.time() - start_time, epoch, updates,
                       total_loss / report_total))
                print('evaluating after %d updates...\r' % updates)
                score = eval(epoch)
                for metric in config.metric:
                    scores[metric].append(score[metric])
                    lera.log({
                        'sgm_micro_f1': score[metric],
                    })
                    if metric == 'micro_f1' and score[metric] >= max(
                            scores[metric]):
                        save_model(log_path + 'best_' + metric +
                                   '_checkpoint.pt')
                    if metric == 'hamming_loss' and score[metric] <= min(
                            scores[metric]):
                        save_model(log_path + 'best_' + metric +
                                   '_checkpoint.pt')

                model.train()
                total_loss = 0
                start_time = 0
                report_total = 0

        except RuntimeError, eeee:
            print 'Errors here eeee: ', eeee
            continue
        except Exception, dddd:
            print '\n\n\nRare errors: ', dddd
            continue
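A hedged sketch of the masked-MSE separation objective that separation_loss appears to implement above (expand the mixture to topk channels, apply the masks, take the MSE against the target spectrograms); an illustration, not the repo's exact loss:

import torch
import torch.nn.functional as F

def separation_loss_sketch(src, multi_mask, feas_tgt):
    # src: (bs, T, Fr); multi_mask, feas_tgt: (bs, topk, T, Fr)
    bs, topk = multi_mask.size(0), multi_mask.size(1)
    x_multi = src.unsqueeze(1).expand(bs, topk, src.size(1), src.size(2))
    return F.mse_loss(multi_mask * x_multi, feas_tgt)

loss = separation_loss_sketch(torch.rand(2, 10, 8),
                              torch.rand(2, 3, 10, 8),
                              torch.rand(2, 3, 10, 8))
print(loss.item())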
Example No. 18
def eval(epoch,test_or_valid='valid'):
    # config.batch_size=1
    global updates,model
    model.eval()
    # print '\n\nPlease set the batch_size in config to 1 for testing!!!'
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    print(('Test or valid:', test_or_valid))
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX, config.MAX_MIX)
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    batch_idx = 0
    global best_SDR, Var
    # for iii in range(2000):
    while True:
        print(('-' * 30))
        eval_data = next(eval_data_gen)
        if eval_data == False:
            print(('SDR_aver_eval_epoch:', SDR_SUM.mean()))
            print(('SDRi_aver_eval_epoch:', SDRi_SUM.mean()))
            break  # the generator for this epoch is exhausted; move on to the next epoch
        src = Variable(torch.from_numpy(eval_data['mix_feas']))

        # raw_tgt = [sorted(spk.keys()) for spk in eval_data['multi_spk_fea_list']]
        raw_tgt= eval_data['batch_order']
        feas_tgt = models.rank_feas(raw_tgt, eval_data['multi_spk_fea_list'])  # target spectrograms

        top_k = len(raw_tgt[0])
        # make sure the tensors below are LongTensors
        # tgt = Variable(torch.from_numpy(np.array([[0]+[dict_spk2idx[spk] for spk in spks]+[dict_spk2idx['<EOS>']] for spks in raw_tgt],dtype=np.int))).transpose(0,1)  # map speakers to indices, add BOS/EOS.
        tgt = Variable(torch.from_numpy(np.array([[0,1,2,102] for __ in range(config.batch_size)], dtype=np.int))).transpose(0, 1)  # fixed dummy indices with BOS/EOS.

        src_len = Variable(torch.LongTensor(config.batch_size).zero_() + mix_speech_len).unsqueeze(0)
        tgt_len = Variable(torch.LongTensor([len(one_spk) for one_spk in eval_data['multi_spk_fea_list']])).unsqueeze(0)
        # tgt_len = Variable(torch.LongTensor(config.batch_size).zero_()+len(eval_data['multi_spk_fea_list'][0])).unsqueeze(0)
        if config.WFM:
            siz = src.size()  # bs,T,F
            assert len(siz) == 3
            # topk_max = config.MAX_MIX  # maximum possible topk
            topk_max = 2  # maximum possible topk
            x_input_map_multi = torch.unsqueeze(src, 1).expand(siz[0], topk_max, siz[1], siz[2]).contiguous().view(-1, siz[1], siz[2])  # bs,topk,T,F
            feas_tgt_tmp = feas_tgt.view(siz[0], -1, siz[1], siz[2])

            feas_tgt_square = feas_tgt_tmp * feas_tgt_tmp
            feas_tgt_sum_square = torch.sum(feas_tgt_square, dim=1, keepdim=True).expand(siz[0], topk_max, siz[1], siz[2])
            WFM_mask = feas_tgt_square / (feas_tgt_sum_square + 1e-15)
            feas_tgt = x_input_map_multi.view(siz[0], -1, siz[1], siz[2]).data * WFM_mask  # bs,topk,T,F
            feas_tgt = feas_tgt.view(-1, siz[1], siz[2])  # bs*topk,T,F
            WFM_mask = WFM_mask.cuda()
            del x_input_map_multi

        elif config.PSM:
            siz = src.size()  # bs,T,F
            assert len(siz) == 3
            # topk_max = config.MAX_MIX  # maximum possible topk
            topk_max = 2  # maximum possible topk
            x_input_map_multi = torch.unsqueeze(src, 1).expand(siz[0], topk_max, siz[1], siz[2]).contiguous()  # bs,topk,T,F
            feas_tgt_tmp = feas_tgt.view(siz[0], -1, siz[1], siz[2])

            IRM=feas_tgt_tmp/(x_input_map_multi+1e-15)

            angle_tgt=models.rank_feas(raw_tgt, eval_data['multi_spk_angle_list']).view(siz[0],-1,siz[1],siz[2])
            angle_mix=Variable(torch.from_numpy(np.array(eval_data['mix_angle']))).unsqueeze(1).expand(siz[0], topk_max, siz[1], siz[2]).contiguous()
            ang=np.cos(angle_mix-angle_tgt)
            ang=np.clip(ang,0,None)

            # feas_tgt = x_input_map_multi *np.clip(IRM.numpy()*ang,0,1) # bs,topk,T,F
            # feas_tgt = x_input_map_multi *IRM*ang # bs,topk,T,F
            feas_tgt = feas_tgt.view(siz[0],-1,siz[1],siz[2])*ang # bs,topk,T,F
            feas_tgt = feas_tgt.view(-1, siz[1], siz[2])  # bs*topk,T,F
            del x_input_map_multi

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
            if config.WFM:
                WFM_mask = WFM_mask.cuda()

        predicted_masks, enc_attn_list = model(src, src_len, tgt, tgt_len,
                                               dict_spk2idx)  # the outputs are the hidden_outputs before the final classification layer; they can be used directly

        print('predicted mask size:', predicted_masks.size(),'should be topk,bs,T,F') # topk,bs,T,F
        # try:

        # '''
        # expand the raw mixed-features to topk_max channel.
        src = src.transpose(0, 1)
        siz = src.size()
        assert len(siz) == 3
        # if samples[0][-1] != dict_spk2idx['<EOS>']:
        #     print '*'*40+'\nThe model is far from good. End the evaluation.\n'+'*'*40
        #     break
        topk_max = config.MAX_MIX
        x_input_map_multi = torch.unsqueeze(src, 1).expand(siz[0], topk_max, siz[1], siz[2])

        predicted_masks=predicted_masks.transpose(0, 1)
        # if config.WFM:
        #     feas_tgt = x_input_map_multi.data * WFM_mask

        # note: after the transpose above, bs is the first dimension
        assert predicted_masks.shape == x_input_map_multi.shape
        assert predicted_masks.size(0) == config.batch_size

        if 1 and len(opt.gpus) > 1:
            ss_loss,best_pmt = model.module.separation_pit_loss(x_input_map_multi, predicted_masks, feas_tgt, )
        else:
            ss_loss,best_pmt = model.separation_pit_loss(x_input_map_multi, predicted_masks, feas_tgt)
        print(('loss for ss,this batch:', ss_loss.cpu().item()))
        print('best perms for this batch:', best_pmt)
        lera.log({
            'ss_loss_' + test_or_valid: ss_loss.cpu().item(),
        })
        writer.add_scalars('scalar/loss',{'ss_loss_'+test_or_valid:ss_loss.cpu().item()},updates+batch_idx)
        del ss_loss
        if batch_idx>10:
            break

        if False: # this part tests the checkpoints sequentially.
            batch_idx += 1
            if batch_idx%100==0:
                updates=updates+1000
                opt.restore='/data1/shijing_data/2020-02-14-04:58:17/Transformer_PIT_{}.pt'.format(updates)
                print('loading checkpoint...\n', opt.restore)
                checkpoints = torch.load(opt.restore)
                model.module.load_state_dict(checkpoints['model'])
                break
            continue
        # '''''
        if 0 and batch_idx <= (500 / config.batch_size):  # only the first batches count toward the SDR
            predicted_maps = predicted_masks * x_input_map_multi
            predicted_maps = predicted_maps.view(-1,mix_speech_len,speech_fre)
            # predicted_maps=Variable(feas_tgt)
            utils.bss_eval2(config, predicted_maps, eval_data['multi_spk_fea_list'], raw_tgt, eval_data,
                            dst='batch_output_test')
            # utils.bss_eval(config, predicted_maps, eval_data['multi_spk_fea_list'], raw_tgt, eval_data,
            #                 dst='batch_output_test')
            del predicted_maps, predicted_masks, x_input_map_multi
            try:
                sdr_aver_batch, sdri_aver_batch=  bss_test.cal('batch_output_test/')
                SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
                SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
            except(AssertionError):
                print('Errors in calculating the SDR')
            print(('SDR_aver_now:', SDR_SUM.mean()))
            print(('SDRi_aver_now:', SDRi_SUM.mean()))
            lera.log({'SDR sample'+test_or_valid: SDR_SUM.mean()})
            lera.log({'SDRi sample'+test_or_valid: SDRi_SUM.mean()})
            writer.add_scalars('scalar/loss',{'SDR_sample_'+test_or_valid:sdr_aver_batch},updates)
            # raw_input('Press any key to continue......')
        elif batch_idx == (200 / config.batch_size) + 1 and SDR_SUM.mean() > best_SDR:  # only record the best SDR once.
            print(('Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean())))
            best_SDR = SDR_SUM.mean()
            # save_model(log_path+'checkpoint_bestSDR{}.pt'.format(best_SDR))

        # '''
        # candidate += [convertToLabels(dict_idx2spk, s, dict_spk2idx['<EOS>']) for s in samples]
        # source += raw_src
        # reference += raw_tgt
        # print(('samples:', samples))
        # print(('can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:], reference[-1 * config.batch_size:])))
        # alignments += [align for align in alignment]
        batch_idx += 1

        result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path)
        print(('hamming_loss: %.8f | micro_f1: %.4f |recall: %.4f | precision: %.4f'
                   % (result['hamming_loss'], result['micro_f1'], result['micro_recall'], result['micro_precision'], )))
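Aside: `separation_pit_loss` above is the utterance-level PIT (permutation-invariant training) objective. As a rough, hedged illustration only (the name `pit_mse_loss`, the (bs, topk, T, F) shapes, and the batch-level permutation choice are assumptions, not the repository's implementation), a permutation-invariant MSE over masked spectrograms can be sketched as:

import itertools
import torch.nn.functional as F

def pit_mse_loss(mix_multi, masks, targets):
    # mix_multi, masks, targets: (bs, topk, T, F); try every speaker
    # permutation of the estimates and keep the cheapest one.
    topk = masks.size(1)
    estimates = masks * mix_multi
    best_loss, best_perm = None, None
    for perm in itertools.permutations(range(topk)):
        loss = F.mse_loss(estimates[:, list(perm)], targets)
        if best_loss is None or loss.item() < best_loss.item():
            best_loss, best_perm = loss, perm
    return best_loss, best_perm

For brevity this picks one best permutation for the whole batch; a per-utterance version would compute the MSE per sample before taking the minimum.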
Example No. 19
def eval_bss(mix_hidden_layer_3d,adjust_layer,mix_speech_classifier,mix_speech_multiEmbedding,att_speech_layer,
             loss_multi_func,dict_spk2idx,dict_idx2spk,num_labels,mix_speech_len,speech_fre):
    for i in [mix_speech_multiEmbedding,adjust_layer,mix_speech_classifier,mix_hidden_layer_3d,att_speech_layer]:
        i.training=False
    print '#' * 40
    eval_data_gen=prepare_data('once','valid')
    SDR_SUM=np.array([])
    while True:
        print '\n\n'
        eval_data=eval_data_gen.next()
        if eval_data==False:
            break  # if the generator has no data left for this epoch, move straight to the next epoch
        '''3-D representation layer (len, fre, Emb) of the mixed speech'''
        mix_speech_hidden,mix_tmp_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(eval_data['mix_feas'])).cuda())
        # temporarily disable the video part, since the video data for s2/s3/s4 is still incomplete

        '''Speech self-separation part'''
        mix_speech_output=mix_speech_classifier(Variable(torch.from_numpy(eval_data['mix_feas'])).cuda())

        if not test_mode:
            y_spk_list= eval_data['multi_spk_fea_list']
            y_spk_gtruth,y_map_gtruth=multi_label_vector(y_spk_list,dict_spk2idx)
            # if the training stage uses the ground-truth separation results for supervision
            if not test_mode and config.Ground_truth:
                mix_speech_output=Variable(torch.from_numpy(y_map_gtruth)).cuda()
                if test_all_outputchannel: # set the input mask to all ones, to test outputting every channel
                    mix_speech_output=Variable(torch.ones(config.BATCH_SIZE,num_labels,))
                    y_map_gtruth=np.ones([config.BATCH_SIZE,num_labels])

        if test_mode:
            num_labels=2
            alpha0=-0.5
        else:
            alpha0=0.5
        top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=alpha0,top_k=num_labels) # a torch.FloatTensor
        top_k_mask_idx=[np.where(line==1)[0] for line in top_k_mask_mixspeech.numpy()]
        print 'Predict spk list:',print_spk_name(dict_idx2spk,top_k_mask_idx)
        mix_speech_multiEmbs=mix_speech_multiEmbedding(top_k_mask_mixspeech,top_k_mask_idx) # bs*num_labels (max number of mixed speakers) x embedding size
        mix_adjust=adjust_layer(mix_tmp_hidden,mix_speech_multiEmbs)
        mix_speech_multiEmbs=mix_adjust+mix_speech_multiEmbs

        assert len(top_k_mask_idx[0])==len(top_k_mask_idx[-1])
        top_k_num=len(top_k_mask_idx[0])

        # compute the attention between mix_speech_hidden [bs,len,fre,emb] and mix_mulEmbedding [bs,num_labels,EMB]
        # expand the former to bs*num_labels,...; do the same for the latter; run the ATT function and reshape back
        mix_speech_hidden_5d=mix_speech_hidden.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
        mix_speech_hidden_5d=mix_speech_hidden_5d.expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre,config.EMBEDDING_SIZE).contiguous()
        mix_speech_hidden_5d_last=mix_speech_hidden_5d.view(-1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
        # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
        # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'dot').cuda()
        att_multi_speech=att_speech_layer(mix_speech_hidden_5d_last,mix_speech_multiEmbs.view(-1,config.EMBEDDING_SIZE))
        # print att_multi_speech.size()
        att_multi_speech=att_multi_speech.view(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre) # bs,num_labels,len,fre
        # print att_multi_speech.size()
        multi_mask=att_multi_speech
        # top_k_mask_mixspeech_multi=top_k_mask_mixspeech.view(config.BATCH_SIZE,top_k_num,1,1).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
        # multi_mask=multi_mask*Variable(top_k_mask_mixspeech_multi).cuda()

        x_input_map=Variable(torch.from_numpy(eval_data['mix_feas'])).cuda()
        # print x_input_map.size()
        x_input_map_multi=x_input_map.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
        # predict_multi_map=multi_mask*x_input_map_multi
        predict_multi_map=multi_mask*x_input_map_multi


        y_multi_map=np.zeros([config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre],dtype=np.float32)
        batch_spk_multi_dict=eval_data['multi_spk_fea_list']
        if test_mode:
            for iiii in range(config.BATCH_SIZE):
                y_multi_map[iiii]=np.array(batch_spk_multi_dict[iiii].values())
        else:
            for idx,sample in enumerate(batch_spk_multi_dict):
                y_idx=sorted([dict_spk2idx[spk] for spk in sample.keys()])
                if not test_mode:
                    assert y_idx==list(top_k_mask_idx[idx])
                for jdx,oo in enumerate(y_idx):
                    y_multi_map[idx,jdx]=sample[dict_idx2spk[oo]]
        y_multi_map= Variable(torch.from_numpy(y_multi_map)).cuda()

        loss_multi_speech=loss_multi_func(predict_multi_map,y_multi_map)

        # loss term that pushes the per-channel masks to sum to one; it should encourage more contrast between channels
        y_sum_map=Variable(torch.ones(config.BATCH_SIZE,mix_speech_len,speech_fre)).cuda()
        predict_sum_map=torch.sum(multi_mask,1)
        loss_multi_sum_speech=loss_multi_func(predict_sum_map,y_sum_map)
        # loss_multi_speech=loss_multi_speech # TODO: compare the effect of the sum-to-one term later; plain MSE already works quite well for now.
        print 'loss 1 eval, losssum eval : ',loss_multi_speech.data.cpu().numpy(),loss_multi_sum_speech.data.cpu().numpy()
        lrs.send('loss mask eval:',loss_multi_speech.data.cpu()[0])
        lrs.send('loss sum eval:',loss_multi_sum_speech.data.cpu()[0])
        loss_multi_speech=loss_multi_speech+0.5*loss_multi_sum_speech
        print 'eval multi-abs norm this batch:',torch.abs(y_multi_map-predict_multi_map).norm().data.cpu().numpy()
        print 'loss:',loss_multi_speech.data.cpu().numpy()
        bss_eval(predict_multi_map,y_multi_map,top_k_mask_idx,dict_idx2spk,eval_data)
        SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output/', 2))
        print 'SDR_aver_now:',SDR_SUM.mean()

    SDR_aver=SDR_SUM.mean()
    print 'SDR_SUM (len:{}) for epoch eval : {}'.format(SDR_SUM.shape, SDR_aver)
    lrs.send('SDR eval aver',SDR_aver)
    print '#'*40
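Aside: the `loss_multi_sum_speech` term above regularizes the masks so that, at every T-F bin, the per-speaker masks add up to one. A minimal standalone sketch (the name `mask_sum_penalty` and the (bs, k, len, fre) shape are assumptions for illustration):

import torch
import torch.nn.functional as F

def mask_sum_penalty(multi_mask):
    # multi_mask: (bs, k, len, fre); push the sum over the k channels towards 1
    target = torch.ones_like(multi_mask[:, 0])
    return F.mse_loss(multi_mask.sum(dim=1), target)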
Example No. 20
def train_recu(epoch):
    global e, updates, total_loss, start_time, report_total, report_correct, total_loss_sgm, total_loss_ss
    e = epoch
    model.train()
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])

    if updates <= config.warmup:  # past the warmup phase, the normal schedule below takes over
        pass
    elif config.schedule and scheduler.get_lr()[0] > 5e-7:
        scheduler.step()
        print(("Decaying learning rate to %g" % scheduler.get_lr()[0]))
        lera.log({
            'lr': [group['lr'] for group in optim.optimizer.param_groups][0],
        })

    if opt.model == 'gated':
        model.current_epoch = epoch

    train_data_gen = prepare_data('once', 'train')
    while True:
        if updates <= config.warmup:  # still in the warmup phase, so ramp the learning rate
            tmp_lr = config.learning_rate * min(
                max(updates, 1)**(-0.5),
                max(updates, 1) * (config.warmup**(-1.5)))
            for param_group in optim.optimizer.param_groups:
                param_group['lr'] = tmp_lr
            scheduler.base_lrs = list(
                [group['lr'] for group in optim.optimizer.param_groups])
            if updates % 100 == 0:  # log it occasionally
                print(updates)
                print("Warmup learning rate to %g" % tmp_lr)
                lera.log({
                    'lr':
                    [group['lr'] for group in optim.optimizer.param_groups][0],
                })

        train_data = next(train_data_gen)
        if train_data == False:
            print(('SDR_aver_epoch:', SDR_SUM.mean()))
            print(('SDRi_aver_epoch:', SDRi_SUM.mean()))
            break  # if the generator has no data left for this epoch, move straight to the next epoch

        src = Variable(torch.from_numpy(train_data['mix_feas']))
        # raw_tgt = [spk.keys() for spk in train_data['multi_spk_fea_list']]
        # raw_tgt = [sorted(spk.keys()) for spk in train_data['multi_spk_fea_list']]
        raw_tgt = train_data['batch_order']
        feas_tgt = models.rank_feas(
            raw_tgt,
            train_data['multi_spk_fea_list'])  # target spectrograms: aim_size,len,fre
        if 0 and config.WFM:
            tmp_size = feas_tgt.size()
            assert len(tmp_size) == 3
            feas_tgt_square = feas_tgt * feas_tgt
            feas_tgt_sum_square = torch.sum(feas_tgt_square,
                                            dim=0,
                                            keepdim=True).expand(tmp_size)
            WFM_mask = feas_tgt_square / (feas_tgt_sum_square + 1e-15)
            WFM_mask = WFM_mask.cuda()
            feas_tgt = x_input_map_multi.data * WFM_mask

        # make sure the following are all LongTensors
        src_original = src.transpose(0, 1)  # to T,bs,F
        multi_mask_all = None
        for len_idx in range(config.MIN_MIX + 2, 2, -1):  # separate the speakers one at a time
            # len_idx=3
            tgt_max_len = len_idx  # 4,3,2 with bos and eos.
            tgt = Variable(
                torch.from_numpy(
                    np.array([[0] + [
                        dict_spk2idx[spk]
                        for spk in spks[-1 * (tgt_max_len - 2):]
                    ] + 1 * [dict_spk2idx['<EOS>']] for spks in raw_tgt],
                             dtype=np.int))).transpose(
                                 0, 1)  # convert to indices and wrap with BOS/EOS. 4,bs
            src_len = Variable(
                torch.LongTensor(config.batch_size).zero_() +
                mix_speech_len).unsqueeze(0)
            tgt_len = Variable(
                torch.LongTensor([
                    tgt_max_len - 2
                    for one_spk in train_data['multi_spk_fea_list']
                ])).unsqueeze(0)
            if use_cuda:
                src = src.cuda().transpose(0, 1)  # to T,bs,fre
                src_original = src_original.cuda()  # TO T,bs,fre
                tgt = tgt.cuda()
                src_len = src_len.cuda()
                tgt_len = tgt_len.cuda()
                feas_tgt = feas_tgt.cuda()

            model.zero_grad()

            outputs, targets, multi_mask, gamma = model(
                src, src_len, tgt, tgt_len, dict_spk2idx,
                src_original)  # the outputs are the hidden_outputs before the final classification layer, so they can be used directly
            print('mask size:', multi_mask.size())
            # writer.add_histogram('global gamma',gamma, updates)

            if 1 and len(opt.gpus) > 1:
                sgm_loss, num_total, num_correct = model.module.compute_loss(
                    outputs, targets, opt.memory)
            else:
                sgm_loss, num_total, num_correct = model.compute_loss(
                    outputs, targets, opt.memory)
            print(('loss for SGM,this batch:', sgm_loss.cpu().item()))
            writer.add_scalars(
                'scalar/loss',
                {'sgm_loss' + str(len_idx): sgm_loss.cpu().item()}, updates)

            src = src_original.transpose(0, 1)  # make sure separation uses the original speech
            # expand the raw mixed-features to topk_max channel.
            siz = src.size()  #bs,T,F
            assert len(siz) == 3
            # topk_max = config.MAX_MIX  # maximum possible topk
            topk_max = len_idx - 2  # maximum possible topk
            x_input_map_multi = torch.unsqueeze(src, 1).expand(
                siz[0], topk_max, siz[1],
                siz[2]).contiguous().view(-1, siz[1], siz[2])  #bs,topk,T,F
            # x_input_map_multi = x_input_map_multi[aim_list]
            multi_mask = multi_mask.transpose(0, 1)

            if len_idx == 4:
                aim_feas = list(range(0, 2 * config.batch_size,
                                      2))  # take the first speaker of each sample
                multi_mask_all = multi_mask  # bs*topk,T,F
                src = src * (1 - multi_mask[aim_feas])  # batch-first now, bs,T,F
                # src=src.transpose(0,1)*(1-multi_mask[aim_feas])  # batch-first
                src = src.detach()  # the second pass works on the residual spectrogram left by the first pass
            elif len_idx == 3:
                aim_feas = list(range(1, 2 * config.batch_size,
                                      2))  # take the second speaker of each sample
                multi_mask_all[aim_feas] = multi_mask
                feas_tgt = feas_tgt[aim_feas]
            if 1 and len(opt.gpus) > 1:
                ss_loss = model.module.separation_loss(x_input_map_multi,
                                                       multi_mask, feas_tgt)
            else:
                ss_loss = model.separation_loss(x_input_map_multi, multi_mask,
                                                feas_tgt)
            print(('loss for SS,this batch:', ss_loss.cpu().item()))
            writer.add_scalars(
                'scalar/loss',
                {'ss_loss' + str(len_idx): ss_loss.cpu().item()}, updates)

            loss = sgm_loss + 5 * ss_loss
            loss.backward()
            optim.step()
            lera.log({
                'sgm_loss' + str(len_idx): sgm_loss.cpu().item(),
                'ss_loss' + str(len_idx): ss_loss.cpu().item(),
                'loss:' + str(len_idx): loss.cpu().item(),
            })
            total_loss_sgm += sgm_loss.cpu().item()
            total_loss_ss += ss_loss.cpu().item()

        multi_mask = multi_mask_all
        x_input_map_multi = torch.unsqueeze(src, 1).expand(
            siz[0], 2, siz[1], siz[2]).contiguous().view(-1, siz[1], siz[2])
        if updates > 10 and updates % config.eval_interval in [
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
        ]:
            predicted_maps = multi_mask * x_input_map_multi
            # predicted_maps=Variable(feas_tgt)
            utils.bss_eval(config,
                           predicted_maps,
                           train_data['multi_spk_fea_list'],
                           raw_tgt,
                           train_data,
                           dst='batch_output')
            del predicted_maps, multi_mask, x_input_map_multi
            sdr_aver_batch, sdri_aver_batch = bss_test.cal('batch_output/')
            lera.log({'SDR sample': sdr_aver_batch})
            lera.log({'SDRi sample': sdri_aver_batch})
            writer.add_scalars('scalar/loss', {
                'SDR_sample': sdr_aver_batch,
                'SDRi_sample': sdri_aver_batch
            }, updates)
            SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
            SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
            print(('SDR_aver_now:', SDR_SUM.mean()))
            print(('SDRi_aver_now:', SDRi_SUM.mean()))

        total_loss += loss.cpu().item()
        report_correct += num_correct.cpu().item()
        report_total += num_total.cpu().item()

        updates += 1
        if updates % 30 == 0:
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss this batch: %6.3f,sgm loss: %6.6f,ss loss: %6.6f,label acc: %6.6f\n"
                % (time.time() - start_time, epoch, updates, loss / num_total,
                   total_loss_sgm / 30.0, total_loss_ss / 30.0,
                   report_correct / report_total))
            lera.log({'label_acc': report_correct / report_total})
            writer.add_scalars('scalar/loss',
                               {'label_acc': report_correct / report_total},
                               updates)
            total_loss_sgm, total_loss_ss = 0, 0

        # continue

        if 0 and updates % config.eval_interval == 0 and epoch > 3:  # run at least a few epochs before evaluating; otherwise the model has not learned anything yet and many problems appear.
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss: %6.5f\n" %
                (time.time() - start_time, epoch, updates,
                 total_loss / report_total))
            print(('evaluating after %d updates...\r' % updates))
            original_bs = config.batch_size
            score = eval(epoch)  # batch_size becomes 1 during eval
            # print 'Orignal bs:',original_bs
            config.batch_size = original_bs
            # print 'Now bs:',config.batch_size
            for metric in config.metric:
                scores[metric].append(score[metric])
                lera.log({
                    'sgm_micro_f1': score[metric],
                })
                if metric == 'micro_f1' and score[metric] >= max(
                        scores[metric]):
                    save_model(log_path + 'best_' + metric + '_checkpoint.pt')
                if metric == 'hamming_loss' and score[metric] <= min(
                        scores[metric]):
                    save_model(log_path + 'best_' + metric + '_checkpoint.pt')

            model.train()
            total_loss = 0
            start_time = 0
            report_total = 0
            report_correct = 0

        if 1 and updates % config.save_interval == 1:
            save_model(log_path + 'TDAAv3_{}.pt'.format(updates))
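Aside: the warmup branch in train_recu implements the Transformer-style "Noam" schedule, lr = base_lr * min(step^-0.5, step * warmup^-1.5). A minimal sketch, assuming config.learning_rate and config.warmup play the roles of base_lr and warmup_steps:

def warmup_lr(step, base_lr, warmup_steps):
    # linear ramp while step <= warmup_steps, then step**-0.5 decay
    step = max(step, 1)
    return base_lr * min(step ** -0.5, step * warmup_steps ** -1.5)

For example, with base_lr=1e-3 and warmup_steps=4000 the rate rises linearly to 1e-3 * 4000**-0.5 ~= 1.58e-5 at step 4000 and decays as step**-0.5 afterwards.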
Example No. 21
def eval_recu(epoch):
    assert config.batch_size == 1
    model.eval()
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    test_or_valid = 'test'
    test_or_valid = 'valid'
    # test_or_valid = 'train'
    print(('Test or valid:', test_or_valid))
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX,
                                 config.MAX_MIX)
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    batch_idx = 0
    global best_SDR, Var
    while True:
        print(('-' * 30))
        eval_data = next(eval_data_gen)
        if eval_data == False:
            print(('SDR_aver_eval_epoch:', SDR_SUM.mean()))
            print(('SDRi_aver_eval_epoch:', SDRi_SUM.mean()))
            break  # if the generator has no data left for this epoch, move straight to the next epoch
        src = Variable(torch.from_numpy(eval_data['mix_feas']))

        # raw_tgt = [sorted(spk.keys()) for spk in eval_data['multi_spk_fea_list']]
        raw_tgt = eval_data['batch_order']
        feas_tgt = models.rank_feas(
            raw_tgt, eval_data['multi_spk_fea_list'])  # target spectrograms

        src_original = src.transpose(0, 1)  #To T,bs,F
        predict_multi_mask_all = None
        samples_list = []
        for len_idx in range(config.MIN_MIX + 2, 2, -1):  # separate the speakers one at a time
            tgt_max_len = len_idx  # 4,3,2 with bos and eos.
            topk_k = len_idx - 2
            tgt = Variable(torch.ones(
                len_idx, config.batch_size))  # an arbitrary tgt; the speaker names in tgt do not matter at test time
            src_len = Variable(
                torch.LongTensor(config.batch_size).zero_() +
                mix_speech_len).unsqueeze(0)
            tgt_len = Variable(
                torch.LongTensor([
                    tgt_max_len - 2
                    for one_spk in eval_data['multi_spk_fea_list']
                ])).unsqueeze(0)
            if use_cuda:
                src = src.cuda().transpose(0, 1)  # to T,bs,fre
                src_original = src_original.cuda()  # TO T,bs,fre
                tgt = tgt.cuda()
                src_len = src_len.cuda()
                tgt_len = tgt_len.cuda()
                feas_tgt = feas_tgt.cuda()

            # try:
            if len(opt.gpus) > 1:
                samples, alignment, hiddens, predicted_masks = model.module.beam_sample(
                    src, src_len, dict_spk2idx, tgt, config.beam_size,
                    src_original)
            else:
                samples, predicted_masks = model.beam_sample(
                    src, src_len, dict_spk2idx, tgt, config.beam_size,
                    src_original)

            # except:
            #     continue

            # '''
            # expand the raw mixed-features to topk_max channel.
            src = src_original.transpose(0, 1)  # make sure separation uses the original speech
            siz = src.size()
            assert len(siz) == 3
            topk_max = topk_k
            x_input_map_multi = torch.unsqueeze(src, 1).expand(
                siz[0], topk_max, siz[1], siz[2])
            if 0 and config.WFM:
                feas_tgt = x_input_map_multi.data * WFM_mask

            if len_idx == 4:
                aim_feas = list(range(0, 2 * config.batch_size,
                                      2))  # take the first speaker of each sample
                predict_multi_mask_all = predicted_masks  # bs*topk,T,F
                src = src * (1 - predicted_masks[aim_feas]
                             )  # batch-first now, bs,T,F
                samples_list = samples
            elif len_idx == 3:
                aim_feas = list(range(1, 2 * config.batch_size,
                                      2))  # take the second speaker of each sample
                predict_multi_mask_all[aim_feas] = predicted_masks
                feas_tgt = feas_tgt[aim_feas]
                samples_list = [samples_list[:1] + samples]

            if test_or_valid != 'test':
                if 1 and len(opt.gpus) > 1:
                    ss_loss = model.module.separation_loss(
                        x_input_map_multi,
                        predicted_masks,
                        feas_tgt,
                    )
                else:
                    ss_loss = model.separation_loss(x_input_map_multi,
                                                    predicted_masks, feas_tgt)
                print(('loss for ss,this batch:', ss_loss.cpu().item()))
                lera.log({
                    'ss_loss_' + str(len_idx) + test_or_valid:
                    ss_loss.cpu().item(),
                })
                del ss_loss

        predicted_masks = predict_multi_mask_all
        if batch_idx <= (500 / config.batch_size
                         ):  # only the first batches count toward the SDR
            predicted_maps = predicted_masks * x_input_map_multi
            # predicted_maps=Variable(feas_tgt)
            utils.bss_eval2(config,
                            predicted_maps,
                            eval_data['multi_spk_fea_list'],
                            raw_tgt,
                            eval_data,
                            dst='batch_output_test')
            del predicted_maps, predicted_masks, x_input_map_multi
            try:
                sdr_aver_batch, sdri_aver_batch = bss_test.cal(
                    'batch_output_test/')
                SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
                SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
            except (AssertionError):
                print('Errors in calculating the SDR')
            print(('SDR_aver_now:', SDR_SUM.mean()))
            print(('SDRi_aver_now:', SDRi_SUM.mean()))
            lera.log({'SDR sample' + test_or_valid: SDR_SUM.mean()})
            lera.log({'SDRi sample' + test_or_valid: SDRi_SUM.mean()})
            writer.add_scalars('scalar/loss',
                               {'SDR_sample_' + test_or_valid: sdr_aver_batch},
                               updates)
            # raw_input('Press any key to continue......')

        # '''
        candidate += [
            convertToLabels(dict_idx2spk, s, dict_spk2idx['<EOS>'])
            for s in samples_list
        ]
        # source += raw_src
        reference += raw_tgt
        print(('samples:', samples))
        print(('can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:],
                                         reference[-1 * config.batch_size:])))
        # alignments += [align for align in alignment]
        batch_idx += 1
        input('wait to continue......')

        result = utils.eval_metrics(reference, candidate, dict_spk2idx,
                                    log_path)
        print((
            'hamming_loss: %.8f | micro_f1: %.4f |recall: %.4f | precision: %.4f'
            % (
                result['hamming_loss'],
                result['micro_f1'],
                result['micro_recall'],
                result['micro_precision'],
            )))

    score = {}
    result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path)
    logging_csv([e, updates, result['hamming_loss'], \
                 result['micro_f1'], result['micro_precision'], result['micro_recall'],SDR_SUM.mean()])
    print(('hamming_loss: %.8f | micro_f1: %.4f' %
           (result['hamming_loss'], result['micro_f1'])))
    score['hamming_loss'] = result['hamming_loss']
    score['micro_f1'] = result['micro_f1']
    # 1 / 0  # debug stopper disabled so the function can actually return score
    return score
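Aside: train_recu/eval_recu peel speakers off one at a time: after each pass the estimated speaker is masked out and the detached residual (src * (1 - mask)) becomes the next input. A minimal sketch of that loop (`separate_one` is a hypothetical helper returning one mask per call, not a function from this repository):

def recursive_separate(mix, num_spk, separate_one):
    # mix: (bs, T, F) mixture spectrogram
    estimates = []
    residual = mix
    for _ in range(num_spk):
        mask = separate_one(residual)                # (bs, T, F) mask in [0, 1]
        estimates.append(mask * residual)
        residual = (residual * (1 - mask)).detach()  # stop gradients between passes
    return estimates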
Example No. 22
def main():
    print('go to model')
    print '*' * 80

    spk_global_gen = prepare_data(mode='global',
                                  train_or_test='train')  # a stub data generator, handy for sketching the model first
    global_para = spk_global_gen.next()
    print global_para
    spk_all_list, dict_spk2idx, dict_idx2spk, mix_speech_len, speech_fre, total_frames, spk_num_total = global_para
    del spk_global_gen
    num_labels = len(spk_all_list)
    print 'print num_labels:', num_labels

    # data_generator=prepare_data('once','train')
    # data_generator=prepare_data_fake(train_or_test='train',num_labels=num_labels) #写一个假的数据生成,可以用来写模型先

    # the order here is: mix_speechs.shape, mix_feas.shape, aim_fea.shape, aim_spkid.shape, query.shape
    # an example: (5, 17040) (5, 134, 129) (5, 134, 129) (5,) (5, 32, 400, 300, 3)
    # datasize=prepare_datasize(data_generator)
    # mix_speech_len,speech_fre,total_frames,spk_num_total,video_size=datasize
    print 'Begin to build the main model for the Multi-Modal Cocktail Problem.'
    # data=data_generator.next()

    # This part is to build the 3D mix speech embedding maps.
    mix_hidden_layer_3d = MIX_SPEECH(speech_fre, mix_speech_len).cuda()
    mix_speech_classifier = MIX_SPEECH_classifier(speech_fre, mix_speech_len,
                                                  num_labels).cuda()
    mix_speech_multiEmbedding = SPEECH_EMBEDDING(
        num_labels, config.EMBEDDING_SIZE,
        spk_num_total + config.UNK_SPK_SUPP).cuda()
    print mix_hidden_layer_3d
    print mix_speech_classifier
    # mix_speech_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(data[1])).cuda())

    # mix_speech_output=mix_speech_classifier(Variable(torch.from_numpy(data[1])).cuda())
    # trick: with alpha=0 this selects the top_k; with a very large top_k it selects everything above alpha
    # top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=config.ALPHA,top_k=config.MAX_MIX)
    # top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=config.ALPHA,top_k=3)
    # print top_k_mask_mixspeech
    # mix_speech_multiEmbs=mix_speech_multiEmbedding(top_k_mask_mixspeech) # bs*num_labels (max number of mixed speakers) x embedding size
    # mix_speech_multiEmbs=mix_speech_multiEmbedding(Variable(torch.from_numpy(top_k_mask_mixspeech),requires_grad=False).cuda()) # bs*num_labels (max number of mixed speakers) x embedding size

    # compute the attention between mix_speech_hidden [bs,len,fre,emb] and mix_mulEmbedding [bs,num_labels,EMB]
    # expand the former to bs*num_labels,...; do the same for the latter; run the ATT function and reshape back
    # mix_speech_hidden_5d=mix_speech_hidden.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
    # mix_speech_hidden_5d=mix_speech_hidden_5d.expand(config.BATCH_SIZE,num_labels,mix_speech_len,speech_fre,config.EMBEDDING_SIZE).contiguous()
    # mix_speech_hidden_5d=mix_speech_hidden_5d.view(-1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
    # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
    # att_multi_speech=att_speech_layer(mix_speech_hidden_5d,mix_speech_multiEmbs.view(-1,config.EMBEDDING_SIZE))
    # print att_multi_speech.size()
    # att_multi_speech=att_multi_speech.view(config.BATCH_SIZE,num_labels,mix_speech_len,speech_fre,-1)
    # print att_multi_speech.size()

    # This part is to conduct the video inputs.
    # query_video_layer=VIDEO_QUERY(total_frames,config.VideoSize,spk_num_total).cuda()
    query_video_layer = None
    # print query_video_layer
    # query_video_output,xx=query_video_layer(Variable(torch.from_numpy(data[4])))

    # This part is to conduct the memory.
    # hidden_size=(config.HIDDEN_UNITS)
    hidden_size = (config.EMBEDDING_SIZE)
    memory = MEMORY(spk_num_total + config.UNK_SPK_SUPP, hidden_size)
    memory.register_spklist(spk_all_list)  # register spk_list into the empty memory

    # Memory function test.
    print 'memory all spkid:', memory.get_all_spkid()
    # print memory.get_image_num('Unknown_id')
    # print memory.get_video_vector('Unknown_id')
    # print memory.add_video('Unknown_id',Variable(torch.ones(300)))

    # This part is to test the ATTENTION method from query(~) to mix_speech
    # x=torch.arange(0,24).view(2,3,4)
    # y=torch.ones([2,4])
    att_layer = ATTENTION(config.EMBEDDING_SIZE, 'align').cuda()
    att_speech_layer = ATTENTION(config.EMBEDDING_SIZE, 'align').cuda()
    # att=ATTENTION(4,'align')
    # mask=att(x,y)#bs*max_len

    # del data_generator
    # del data

    optimizer = torch.optim.Adam(
        [
            {
                'params': mix_hidden_layer_3d.parameters()
            },
            {
                'params': mix_speech_multiEmbedding.parameters()
            },
            {
                'params': mix_speech_classifier.parameters()
            },
            # {'params':query_video_layer.lstm_layer.parameters()},
            # {'params':query_video_layer.dense.parameters()},
            # {'params':query_video_layer.Linear.parameters()},
            {
                'params': att_layer.parameters()
            },
            {
                'params': att_speech_layer.parameters()
            },
            # ], lr=0.02,momentum=0.9)
        ],
        lr=0.0002)
    if 1 and config.Load_param:
        # query_video_layer.load_state_dict(torch.load('param_video_layer_19'))
        # mix_speech_classifier.load_state_dict(torch.load('params/param_speech_multilabel_epoch249'))
        mix_speech_classifier.load_state_dict(
            torch.load('params/param_speech_WSJ0_multilabel_epoch249'))
        mix_hidden_layer_3d.load_state_dict(
            torch.load('params/param_mix101_WSJ0_hidden3d_180'))
        mix_speech_multiEmbedding.load_state_dict(
            torch.load('params/param_mix101_WSJ0_emblayer_180'))
        att_speech_layer.load_state_dict(
            torch.load('params/param_mix101_WSJ0_attlayer_180'))
    loss_func = torch.nn.MSELoss()  # the target label is NOT one-hot
    loss_multi_func = torch.nn.MSELoss(
    )  # the target label is NOT one-hot
    # loss_multi_func = torch.nn.L1Loss()  # the target label is NOT one-hot
    loss_query_class = torch.nn.CrossEntropyLoss()

    print '''Begin to calculate.'''
    for epoch_idx in range(config.MAX_EPOCH):
        if epoch_idx > 0:
            print 'SDR_SUM (len:{}) for epoch {} : {}'.format(
                SDR_SUM.shape, epoch_idx - 1, SDR_SUM.mean())
        SDR_SUM = np.array([])
        # print_memory_state(memory.memory)
        print 'SDR_SUM for epoch {}:{}'.format(epoch_idx - 1, SDR_SUM.mean())
        for batch_idx in range(config.EPOCH_SIZE):
            print '*' * 40, epoch_idx, batch_idx, '*' * 40
            # train_data_gen=prepare_data('once','train')
            train_data_gen = prepare_data('once', 'test')
            # train_data_gen=prepare_data('once','eval_test')
            train_data = train_data_gen.next()
            # test_data_gen=prepare_data('once','test')
            # test_data=train_data_gen.next()
            '''3-D representation layer (len, fre, Emb) of the mixed speech'''
            mix_speech_hidden = mix_hidden_layer_3d(
                Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # temporarily disable the video part, since the video data for s2/s3/s4 is still incomplete
            '''Speech self-separation part'''
            mix_speech_output = mix_speech_classifier(
                Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # get the ground-truth speaker names and vectors from the data
            # y_spk_list=[one.keys() for one in train_data['multi_spk_fea_list']]
            # y_spk_gtruth,y_map_gtruth=multi_label_vector(y_spk_list,dict_spk2idx)
            # if the training stage uses the ground-truth separation results for supervision
            if 0 and config.Ground_truth:
                mix_speech_output = Variable(
                    torch.from_numpy(y_map_gtruth)).cuda()
                if test_all_outputchannel:  # set the input mask to all ones, to test outputting every channel
                    mix_speech_output = Variable(
                        torch.ones(
                            config.BATCH_SIZE,
                            num_labels,
                        ))
                    y_map_gtruth = np.ones([config.BATCH_SIZE, num_labels])

            # top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=0.5,top_k=num_labels)  # a torch.FloatTensor
            max_num_labels = 2
            top_k_mask_mixspeech = top_k_mask(
                mix_speech_output, alpha=-1,
                top_k=max_num_labels)  # a torch.FloatTensor
            mix_speech_multiEmbs = mix_speech_multiEmbedding(
                top_k_mask_mixspeech)  # bs*num_labels (max number of mixed speakers) x embedding size

            # compute the attention between mix_speech_hidden [bs,len,fre,emb] and mix_mulEmbedding [bs,num_labels,EMB]
            # expand the former to bs*num_labels,...; do the same for the latter; run the ATT function and reshape back
            mix_speech_hidden_5d = mix_speech_hidden.view(
                config.BATCH_SIZE, 1, mix_speech_len, speech_fre,
                config.EMBEDDING_SIZE)
            mix_speech_hidden_5d = mix_speech_hidden_5d.expand(
                config.BATCH_SIZE, num_labels, mix_speech_len, speech_fre,
                config.EMBEDDING_SIZE).contiguous()
            mix_speech_hidden_5d_last = mix_speech_hidden_5d.view(
                -1, mix_speech_len, speech_fre, config.EMBEDDING_SIZE)
            # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
            att_speech_layer = ATTENTION(config.EMBEDDING_SIZE, 'dot').cuda()
            att_multi_speech = att_speech_layer(
                mix_speech_hidden_5d_last,
                mix_speech_multiEmbs.view(-1, config.EMBEDDING_SIZE))
            # print att_multi_speech.size()
            att_multi_speech = att_multi_speech.view(
                config.BATCH_SIZE, num_labels, mix_speech_len,
                speech_fre)  # bs,num_labels,len,fre
            # print att_multi_speech.size()
            multi_mask = att_multi_speech
            top_k_mask_mixspeech_multi = top_k_mask_mixspeech.view(
                config.BATCH_SIZE, num_labels, 1,
                1).expand(config.BATCH_SIZE, num_labels, mix_speech_len,
                          speech_fre)
            multi_mask = multi_mask * Variable(
                top_k_mask_mixspeech_multi).cuda()

            x_input_map = Variable(torch.from_numpy(
                train_data['mix_feas'])).cuda()
            # print x_input_map.size()
            x_input_map_multi = x_input_map.view(
                config.BATCH_SIZE, 1, mix_speech_len,
                speech_fre).expand(config.BATCH_SIZE, num_labels,
                                   mix_speech_len, speech_fre)
            predict_multi_map = multi_mask * x_input_map_multi
            if batch_idx % 100 == 0:
                print multi_mask
            # print predict_multi_map

            bss_eval_fromGenMap(predict_multi_map, top_k_mask_mixspeech,
                                dict_idx2spk, train_data)
            SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output/', 2))

            continue

            optimizer.zero_grad()  # clear gradients for next train
            loss_multi_speech.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients

            if 1 and epoch_idx > 20 and epoch_idx % 10 == 0 and batch_idx == config.EPOCH_SIZE - 1:
                torch.save(
                    mix_speech_multiEmbedding.state_dict(),
                    'params/param_mix_{}_emblayer_{}'.format(
                        config.DATASET, epoch_idx))
                torch.save(
                    mix_hidden_layer_3d.state_dict(),
                    'params/param_mix_{}_hidden3d_{}'.format(
                        config.DATASET, epoch_idx))
                torch.save(
                    att_speech_layer.state_dict(),
                    'params/param_mix_{}_attlayer_{}'.format(
                        config.DATASET, epoch_idx))
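Aside: top_k_mask above turns the classifier's per-speaker activations into a binary selection; as the "trick" comment notes, alpha=0 amounts to a pure top-k pick and a very large top_k amounts to thresholding at alpha. A minimal sketch under those assumptions (not the repository's exact helper):

import torch

def top_k_mask_sketch(scores, alpha=0.5, top_k=2):
    # scores: (bs, num_labels); keep entries that are among the row's top_k
    # values AND above alpha; alpha=-1, as used above, degenerates to pure top-k.
    mask = torch.zeros_like(scores)
    _, topk_idx = scores.topk(top_k, dim=1)
    mask.scatter_(1, topk_idx, 1.0)
    return mask * (scores > alpha).float()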
Example No. 23
def eval(epoch):
    # config.batch_size=1
    model.eval()
    print '\n\nPlease set the batch_size in config to 1 when testing!!!'
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    test_or_valid = 'test'
    print('Test or valid:', test_or_valid)
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX,
                                 config.MAX_MIX)
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    batch_idx = 0
    global best_SDR, Var
    while True:
        print('-' * 30)
        eval_data = eval_data_gen.next()
        if eval_data == False:
            print('SDR_aver_eval_epoch:', SDR_SUM.mean())
            print('SDRi_aver_eval_epoch:', SDRi_SUM.mean())
            break  # if the generator has no data left for this epoch, move straight to the next epoch
        src = Variable(torch.from_numpy(eval_data['mix_feas']))

        raw_tgt = [
            sorted(spk.keys()) for spk in eval_data['multi_spk_fea_list']
        ]
        feas_tgt = models.rank_feas(
            raw_tgt, eval_data['multi_spk_fea_list'])  # target spectrograms

        top_k = len(raw_tgt[0])
        # make sure the following are all LongTensors
        # tgt = Variable(torch.from_numpy(np.array([[0]+[dict_spk2idx[spk] for spk in spks]+[dict_spk2idx['<EOS>']] for spks in raw_tgt],dtype=np.int))).transpose(0,1)  # convert to indices and wrap with BOS/EOS.
        tgt = Variable(torch.ones(
            top_k + 2, config.batch_size))  # an arbitrary tgt; the speaker names in tgt do not matter at test time

        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor([
                len(one_spk) for one_spk in eval_data['multi_spk_fea_list']
            ])).unsqueeze(0)
        # tgt_len = Variable(torch.LongTensor(config.batch_size).zero_()+len(eval_data['multi_spk_fea_list'][0])).unsqueeze(0)
        if config.WFM:
            tmp_size = feas_tgt.size()
            assert len(tmp_size) == 4
            feas_tgt_sum = torch.sum(feas_tgt, dim=1, keepdim=True)
            feas_tgt_sum_square = (feas_tgt_sum *
                                   feas_tgt_sum).expand(tmp_size)
            feas_tgt_square = feas_tgt * feas_tgt
            WFM_mask = feas_tgt_square / feas_tgt_sum_square

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
            if config.WFM:
                WFM_mask = WFM_mask.cuda()

        if 1 and len(opt.gpus) > 1:
            samples, alignment, hiddens, predicted_masks = model.module.beam_sample(
                src, src_len, dict_spk2idx, tgt, beam_size=config.beam_size)
        else:
            samples, alignment, hiddens, predicted_masks = model.beam_sample(
                src, src_len, dict_spk2idx, tgt, beam_size=config.beam_size)

        # '''
        # expand the raw mixed-features to topk_max channel.
        src = src.transpose(0, 1)
        siz = src.size()
        assert len(siz) == 3
        # if samples[0][-1] != dict_spk2idx['<EOS>']:
        #     print '*'*40+'\nThe model is far from good. End the evaluation.\n'+'*'*40
        #     break
        topk_max = len(samples[0]) - 1
        x_input_map_multi = torch.unsqueeze(src,
                                            1).expand(siz[0], topk_max, siz[1],
                                                      siz[2])
        if config.WFM:
            feas_tgt = x_input_map_multi.data * WFM_mask

        if 0 and test_or_valid == 'valid':
            if 1 and len(opt.gpus) > 1:
                ss_loss = model.module.separation_loss(
                    x_input_map_multi,
                    predicted_masks,
                    feas_tgt,
                )
            else:
                ss_loss = model.separation_loss(x_input_map_multi,
                                                predicted_masks, feas_tgt)
            print('loss for ss,this batch:', ss_loss.cpu().item())
            lera.log({
                'ss_loss_' + test_or_valid: ss_loss.cpu().item(),
            })
            del ss_loss, hiddens

        # '''''
        if batch_idx <= (500 / config.batch_size
                         ):  # only the former batches counts the SDR
            predicted_maps = predicted_masks * x_input_map_multi
            # predicted_maps=Variable(feas_tgt)
            utils.bss_eval2(config,
                            predicted_maps,
                            eval_data['multi_spk_fea_list'],
                            raw_tgt,
                            eval_data,
                            dst='batch_output')
            del predicted_maps, predicted_masks, x_input_map_multi
            try:
                sdr_aver_batch, sdri_aver_batch = bss_test.cal('batch_output/')
                SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
                SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
            except AssertionError, wrong_info:
                print 'Errors in calculating the SDR', wrong_info
            print('SDR_aver_now:', SDR_SUM.mean())
            print('SDRi_aver_now:', SDRi_SUM.mean())
            lera.log({'SDR sample' + test_or_valid: SDR_SUM.mean()})
            lera.log({'SDRi sample' + test_or_valid: SDRi_SUM.mean()})
            # raw_input('Press any key to continue......')
        elif batch_idx == (500 / config.batch_size) + 1 and SDR_SUM.mean(
        ) > best_SDR:  # only record the best SDR once.
            print('Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean()))
            best_SDR = SDR_SUM.mean()
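Aside: the WFM branch above builds a Wiener-filter-like magnitude mask, |S_i|^2 / (sum_j |S_j|)^2 per speaker, and uses it to re-weight the mixture as the regression target. A minimal sketch mirroring those lines (the function name and eps are assumptions for illustration):

def wfm_mask(feas_tgt, eps=1e-15):
    # feas_tgt: (bs, topk, T, F) per-speaker magnitudes
    num = feas_tgt * feas_tgt                     # |S_i|^2
    den = feas_tgt.sum(dim=1, keepdim=True) ** 2  # (sum_j |S_j|)^2
    return num / (den + eps)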
Example No. 24
def eval(epoch):
    # config.batch_size=1
    model.eval()
    # print '\n\nPlease set the batch_size in config to 1 when testing!!!'
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    # test_or_valid = 'test'
    test_or_valid = 'valid'
    print('Test or valid:', test_or_valid)
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX,
                                 config.MAX_MIX)
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    batch_idx = 0
    global best_SDR, Var
    while True:
        print('-' * 30)
        eval_data = next(eval_data_gen)
        if eval_data == False:
            print('SDR_aver_eval_epoch:', SDR_SUM.mean())
            print('SDRi_aver_eval_epoch:', SDRi_SUM.mean())
            break  # if the generator has no data left for this epoch, move straight to the next epoch
        src = Variable(torch.from_numpy(eval_data['mix_feas']))

        raw_tgt = [
            sorted(spk.keys()) for spk in eval_data['multi_spk_fea_list']
        ]
        feas_tgt = models.rank_feas(
            raw_tgt, eval_data['multi_spk_fea_list'])  # target spectrograms
        padded_mixture, mixture_lengths, padded_source = eval_data['tas_zip']
        padded_mixture = torch.from_numpy(padded_mixture).float()
        mixture_lengths = torch.from_numpy(mixture_lengths)
        padded_source = torch.from_numpy(padded_source).float()

        padded_mixture = padded_mixture.cuda().transpose(0, 1)
        mixture_lengths = mixture_lengths.cuda()
        padded_source = padded_source.cuda()

        top_k = len(raw_tgt[0])
        # make sure the following are all LongTensors
        # tgt = Variable(torch.from_numpy(np.array([[0]+[dict_spk2idx[spk] for spk in spks]+[dict_spk2idx['<EOS>']] for spks in raw_tgt],dtype=np.int))).transpose(0,1)  # convert to indices and wrap with BOS/EOS.
        tgt = Variable(torch.ones(
            top_k + 2, config.batch_size))  # an arbitrary tgt; the speaker names in tgt do not matter at test time

        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor([
                len(one_spk) for one_spk in eval_data['multi_spk_fea_list']
            ])).unsqueeze(0)
        # tgt_len = Variable(torch.LongTensor(config.batch_size).zero_()+len(eval_data['multi_spk_fea_list'][0])).unsqueeze(0)
        if config.WFM:
            tmp_size = feas_tgt.size()
            #assert len(tmp_size) == 4
            feas_tgt_sum = torch.sum(feas_tgt, dim=1, keepdim=True)
            feas_tgt_sum_square = (feas_tgt_sum *
                                   feas_tgt_sum).expand(tmp_size)
            feas_tgt_square = feas_tgt * feas_tgt
            WFM_mask = feas_tgt_square / feas_tgt_sum_square

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
            if config.WFM:
                WFM_mask = WFM_mask.cuda()

        if 1 and len(opt.gpus) > 1:
            samples, alignment, hiddens, predicted_masks = model.module.beam_sample(
                src, src_len, dict_spk2idx, tgt, config.beam_size,
                padded_mixture)
        else:
            samples, alignment, hiddens, predicted_masks = model.beam_sample(
                src, src_len, dict_spk2idx, tgt, config.beam_size,
                padded_mixture)

        # '''
        # expand the raw mixed-features to topk_max channel.
        src = src.transpose(0, 1)
        siz = src.size()
        assert len(siz) == 3
        # if samples[0][-1] != dict_spk2idx['<EOS>']:
        #     print '*'*40+'\nThe model is far from good. End the evaluation.\n'+'*'*40
        #     break
        topk_max = len(samples[0]) - 1
        x_input_map_multi = torch.unsqueeze(src,
                                            1).expand(siz[0], topk_max, siz[1],
                                                      siz[2])
        if config.WFM:
            feas_tgt = x_input_map_multi.data * WFM_mask

        if not config.use_tas and test_or_valid == 'valid':
            if 1 and len(opt.gpus) > 1:
                ss_loss = model.module.separation_loss(
                    x_input_map_multi,
                    predicted_masks,
                    feas_tgt,
                )
            else:
                ss_loss = model.separation_loss(x_input_map_multi,
                                                predicted_masks, feas_tgt)
            print('loss for ss,this batch:', ss_loss.cpu().item())
            # lera.log({
            #     'ss_loss_' + test_or_valid: ss_loss.cpu().item(),
            # })
            del ss_loss, hiddens

        # '''''
        if batch_idx <= (100 / config.batch_size
                         ):  # only the first batches count toward the SDR
            if config.use_tas:
                utils.bss_eval_tas(config,
                                   predicted_masks,
                                   eval_data['multi_spk_fea_list'],
                                   raw_tgt,
                                   eval_data,
                                   dst='batch_output1')
            else:
                predicted_maps = predicted_masks * x_input_map_multi
                utils.bss_eval2(config,
                                predicted_maps,
                                eval_data['multi_spk_fea_list'],
                                raw_tgt,
                                eval_data,
                                dst='batch_output1')
                del predicted_maps
            del predicted_masks, x_input_map_multi
            try:
                #SDR_SUM,SDRi_SUM = np.append(SDR_SUM, bss_test.cal('batch_output1/'))
                sdr_aver_batch, sdri_aver_batch = bss_test.cal(
                    'batch_output1/')
                SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
                SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
            except AssertionError as wrong_info:
                print('Errors in calculating the SDR', wrong_info)
            print('SDR_aver_now:', SDR_SUM.mean())
            print('SDRi_aver_now:', SDRi_SUM.mean())
            # lera.log({'SDR sample'+test_or_valid: SDR_SUM.mean()})
            # lera.log({'SDRi sample'+test_or_valid: SDRi_SUM.mean()})
            # raw_input('Press any key to continue......')
        elif batch_idx == (500 / config.batch_size) + 1 and SDR_SUM.mean(
        ) > best_SDR:  # only record the best SDR once.
            print('Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean()))
            best_SDR = SDR_SUM.mean()
            # save_model(log_path+'checkpoint_bestSDR{}.pt'.format(best_SDR))

        # '''
        candidate += [
            convertToLabels(dict_idx2spk, s, dict_spk2idx['<EOS>'])
            for s in samples
        ]
        # source += raw_src
        reference += raw_tgt
        print('samples:', samples)
        print('can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:],
                                        reference[-1 * config.batch_size:]))
        alignments += [align for align in alignment]
        batch_idx += 1

    score = {}
    result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path)
    logging_csv([e, updates, result['hamming_loss'], \
                 result['micro_f1'], result['micro_precision'], result['micro_recall']])
    print('hamming_loss: %.8f | micro_f1: %.4f' %
          (result['hamming_loss'], result['micro_f1']))
    score['hamming_loss'] = result['hamming_loss']
    score['micro_f1'] = result['micro_f1']
    return score
Example No. 25
def main():
    print('go to model')
    print '*' * 80

    spk_global_gen=prepare_data(mode='global',train_or_test='train') # a stub data generator, handy for sketching the model first
    global_para=spk_global_gen.next()
    print global_para
    spk_all_list,dict_spk2idx,dict_idx2spk,mix_speech_len,speech_fre,total_frames,spk_num_total=global_para
    del spk_global_gen
    num_labels=len(spk_all_list)

    # data_generator=prepare_data('once','train')
    # data_generator=prepare_data_fake(train_or_test='train',num_labels=num_labels) #写一个假的数据生成,可以用来写模型先

    # the order here is: mix_speechs.shape, mix_feas.shape, aim_fea.shape, aim_spkid.shape, query.shape
    # an example: (5, 17040) (5, 134, 129) (5, 134, 129) (5,) (5, 32, 400, 300, 3)
    # datasize=prepare_datasize(data_generator)
    # mix_speech_len,speech_fre,total_frames,spk_num_total,video_size=datasize
    print 'Begin to build the main model for the Multi-Modal Cocktail Problem.'
    # data=data_generator.next()

    # This part is to build the 3D mix speech embedding maps.
    mix_hidden_layer_3d=MIX_SPEECH(speech_fre,mix_speech_len).cuda()
    mix_speech_classifier=MIX_SPEECH_classifier(speech_fre,mix_speech_len,num_labels).cuda()
    mix_speech_multiEmbedding=SPEECH_EMBEDDING(num_labels,config.EMBEDDING_SIZE,spk_num_total+config.UNK_SPK_SUPP).cuda()
    print mix_hidden_layer_3d
    print mix_speech_classifier
    # mix_speech_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(data[1])).cuda())

    # mix_speech_output=mix_speech_classifier(Variable(torch.from_numpy(data[1])).cuda())
    # trick: with alpha=0 this selects the top_k; with a very large top_k it selects everything above alpha
    # top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=config.ALPHA,top_k=config.MAX_MIX)
    # top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=config.ALPHA,top_k=3)
    # print top_k_mask_mixspeech
    # mix_speech_multiEmbs=mix_speech_multiEmbedding(top_k_mask_mixspeech) # bs*num_labels (max number of mixed speakers) x embedding size
    # mix_speech_multiEmbs=mix_speech_multiEmbedding(Variable(torch.from_numpy(top_k_mask_mixspeech),requires_grad=False).cuda()) # bs*num_labels (max number of mixed speakers) x embedding size

    # compute the attention between mix_speech_hidden [bs,len,fre,emb] and mix_mulEmbedding [bs,num_labels,EMB]
    # expand the former to bs*num_labels,...; do the same for the latter; run the ATT function and reshape back
    # mix_speech_hidden_5d=mix_speech_hidden.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
    # mix_speech_hidden_5d=mix_speech_hidden_5d.expand(config.BATCH_SIZE,num_labels,mix_speech_len,speech_fre,config.EMBEDDING_SIZE).contiguous()
    # mix_speech_hidden_5d=mix_speech_hidden_5d.view(-1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
    # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
    # att_multi_speech=att_speech_layer(mix_speech_hidden_5d,mix_speech_multiEmbs.view(-1,config.EMBEDDING_SIZE))
    # print att_multi_speech.size()
    # att_multi_speech=att_multi_speech.view(config.BATCH_SIZE,num_labels,mix_speech_len,speech_fre,-1)
    # print att_multi_speech.size()


    # This part is to conduct the video inputs.
    # query_video_layer=VIDEO_QUERY(total_frames,config.VideoSize,spk_num_total).cuda()
    query_video_layer=None
    # print query_video_layer
    # query_video_output,xx=query_video_layer(Variable(torch.from_numpy(data[4])))

    # This part is to conduct the memory.
    # hidden_size=(config.HIDDEN_UNITS)
    hidden_size=(config.EMBEDDING_SIZE)
    # x=torch.arange(0,24).view(2,3,4)
    # y=torch.ones([2,4])
    att_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
    att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
    # att=ATTENTION(4,'align')
    # mask=att(x,y)#bs*max_len

    # del data_generator
    # del data

    optimizer = torch.optim.Adam([{'params':mix_hidden_layer_3d.parameters()},
                                 {'params':mix_speech_multiEmbedding.parameters()},
                                 {'params':mix_speech_classifier.parameters()},
                                 # {'params':query_video_layer.lstm_layer.parameters()},
                                 # {'params':query_video_layer.dense.parameters()},
                                 # {'params':query_video_layer.Linear.parameters()},
                                 {'params':att_layer.parameters()},
                                 {'params':att_speech_layer.parameters()},
                                 # ], lr=0.02,momentum=0.9)
                                 ], lr=0.0002)
    if 0 and config.Load_param:
        # query_video_layer.load_state_dict(torch.load('param_video_layer_19'))
        # mix_speech_classifier.load_state_dict(torch.load('params/param_speech_multilabel_epoch249'))
        mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_WSJ0_hidden3d_180'))
        mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_WSJ0_emblayer_180'))
        att_speech_layer.load_state_dict(torch.load('params/param_mix101_WSJ0_attlayer_180'))
    loss_func = torch.nn.MSELoss()  # the target label is NOT one-hot
    loss_multi_func = torch.nn.MSELoss()  # the target label is NOT one-hot
    # loss_multi_func = torch.nn.L1Loss()  # the target label is NOT one-hot
    loss_query_class=torch.nn.CrossEntropyLoss()

    print '''Begin to calculate.'''
    for epoch_idx in range(config.MAX_EPOCH):
        if epoch_idx%50==0:
            for ee in optimizer.param_groups:
                ee['lr']/=2
        if epoch_idx>0:
            print 'SDR_SUM (len:{}) for epoch {} : {}'.format(SDR_SUM.shape,epoch_idx-1,SDR_SUM.mean())
        SDR_SUM=np.array([])
        # print_memory_state(memory.memory)
        print 'SDR_SUM for epoch {}:{}'.format(epoch_idx - 1, SDR_SUM.mean())
        for batch_idx in range(config.EPOCH_SIZE):
            print '*' * 40,epoch_idx,batch_idx,'*'*40
            train_data_gen=prepare_data('once','train')
            # train_data_gen=prepare_data('once','test')
            train_data=train_data_gen.next()

            '''3D representation layer (len, fre, Emb) for the mixture speech'''
            mix_speech_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # The video part is disabled for now, because the video data for s2/s3/s4 is still incomplete

            '''Speech self-separation part'''
            mix_speech_output=mix_speech_classifier(Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # Get the ground-truth speaker names and vectors from the data
            # y_spk_list=[one.keys() for one in train_data['multi_spk_fea_list']]
            y_spk_list= train_data['multi_spk_fea_list']
            y_spk_gtruth,y_map_gtruth=multi_label_vector(y_spk_list,dict_spk2idx)
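            # multi_label_vector presumably converts the per-sample speaker dicts into
            # index and multi-hot representations; a minimal sketch of that conversion
            # (hypothetical helper, not the project's actual implementation):
            #   y_map = np.zeros([len(y_spk_list), num_labels], dtype=np.float32)
            #   for i, sample in enumerate(y_spk_list):
            #       for spk in sample.keys():
            #           y_map[i, dict_spk2idx[spk]] = 1.0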
            # If the training stage uses the ground-truth separation results for supervision
            if config.Ground_truth:
                mix_speech_output=Variable(torch.from_numpy(y_map_gtruth)).cuda()
                if test_all_outputchannel: # set the input mask to all ones, to test outputting every channel
                    mix_speech_output=Variable(torch.ones(config.BATCH_SIZE,num_labels,))
                    y_map_gtruth=np.ones([config.BATCH_SIZE,num_labels])

            top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=0.5,top_k=num_labels) # a torch.FloatTensor
            top_k_mask_idx=[np.where(line==1)[0] for line in top_k_mask_mixspeech.numpy()]
            mix_speech_multiEmbs=mix_speech_multiEmbedding(top_k_mask_mixspeech,top_k_mask_idx) # bs x num_labels (max number of mixed speakers) x embedding size

            assert len(top_k_mask_idx[0])==len(top_k_mask_idx[-1])
            top_k_num=len(top_k_mask_idx[0])

            # Needed: the attention between mix_speech_hidden [bs,len,fre,emb] and mix_mulEmbedding [bs,num_labels,EMB].
            # Expand the former to bs*num_labels x ..., do the same with the latter, run the ATT function, then reshape back.
            mix_speech_hidden_5d=mix_speech_hidden.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
            mix_speech_hidden_5d=mix_speech_hidden_5d.expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre,config.EMBEDDING_SIZE).contiguous()
            mix_speech_hidden_5d_last=mix_speech_hidden_5d.view(-1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
            # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
            att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'dot').cuda()
            att_multi_speech=att_speech_layer(mix_speech_hidden_5d_last,mix_speech_multiEmbs.view(-1,config.EMBEDDING_SIZE))
            # print att_multi_speech.size()
            att_multi_speech=att_multi_speech.view(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre) # shape: bs,num_labels,len,fre
            # print att_multi_speech.size()
            multi_mask=att_multi_speech
            # top_k_mask_mixspeech_multi=top_k_mask_mixspeech.view(config.BATCH_SIZE,top_k_num,1,1).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
            # multi_mask=multi_mask*Variable(top_k_mask_mixspeech_multi).cuda()

            x_input_map=Variable(torch.from_numpy(train_data['mix_feas'])).cuda()
            # print x_input_map.size()
            x_input_map_multi=x_input_map.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
            # predict_multi_map=multi_mask*x_input_map_multi
            predict_multi_map=multi_mask*x_input_map_multi
            if 0 and batch_idx%100==0:
                print multi_mask
            # print predict_multi_map

            y_multi_map=np.zeros([config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre],dtype=np.float32)
            batch_spk_multi_dict=train_data['multi_spk_fea_list']
            for idx,sample in enumerate(batch_spk_multi_dict):
                y_idx=sorted([dict_spk2idx[spk] for spk in sample.keys()])
                assert y_idx==list(top_k_mask_idx[idx])
                for jdx,oo in enumerate(y_idx):
                    y_multi_map[idx,jdx]=sample[dict_idx2spk[oo]]
            y_multi_map= Variable(torch.from_numpy(y_multi_map)).cuda()

            loss_multi_speech=loss_multi_func(predict_multi_map,y_multi_map)

            # Loss term encouraging the per-channel masks to sum to 1; this should push the channels apart more
            y_sum_map=Variable(torch.ones(config.BATCH_SIZE,mix_speech_len,speech_fre)).cuda()
            predict_sum_map=torch.sum(multi_mask,1)
            loss_multi_sum_speech=loss_multi_func(predict_sum_map,y_sum_map)
            # loss_multi_speech=loss_multi_speech #TODO: compare the effect of this sum-to-1 constraint later; plain MSE alone already works quite well.
            print 'loss1, loss_sum : ',loss_multi_speech.data.cpu().numpy(),loss_multi_sum_speech.data.cpu().numpy()
            loss_multi_speech=loss_multi_speech+0.5*loss_multi_sum_speech
            print 'training multi-abs norm this batch:',torch.abs(y_multi_map-predict_multi_map).norm().data.cpu().numpy()
            print 'loss:',loss_multi_speech.data.cpu().numpy()

            if 1 or batch_idx==config.EPOCH_SIZE-1:
                bss_eval(predict_multi_map,y_multi_map,top_k_mask_idx,dict_idx2spk,train_data)
                SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output/', 2))


            optimizer.zero_grad()   # clear gradients for next train
            loss_multi_speech.backward()         # backpropagation, compute gradients
            optimizer.step()        # apply gradients

            if 1 and epoch_idx>80 and epoch_idx%10==0 and batch_idx==config.EPOCH_SIZE-1:
                torch.save(mix_speech_multiEmbedding.state_dict(),'params/param_mix_{}_emblayer_{}'.format(config.DATASET,epoch_idx))
                torch.save(mix_hidden_layer_3d.state_dict(),'params/param_mix_{}_hidden3d_{}'.format(config.DATASET,epoch_idx))
                torch.save(att_speech_layer.state_dict(),'params/param_mix_{}_attlayer_{}'.format(config.DATASET,epoch_idx))
Example #26
def train(epoch):
    global e, updates, total_loss, start_time, report_total, report_correct, total_loss_sgm, total_loss_ss
    e = epoch
    model.train()
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])

    if updates <= config.warmup:  # if past the warmup stage, let the scheduler below decay the LR normally
        pass
    elif config.schedule and scheduler.get_lr()[0] > 5e-7:
        scheduler.step()
        print(("Decaying learning rate to %g" % scheduler.get_lr()[0]))
        lera.log({
            'lr': [group['lr'] for group in optim.optimizer.param_groups][0],
        })

    if opt.model == 'gated':
        model.current_epoch = epoch

    # train_data_gen = prepare_data('once', 'train')
    train_data_gen = musdb.DB(root="~/MUSDB18/",
                              subsets='train',
                              split='train')
    train_data_gen = batch_generator(
        list(train_data_gen),
        config.batch_size,
    )
    # while 1:
    #     mix,ref=next(train_data_gen)
    #     import soundfile as sf
    #     sf.write('mix.wav',mix[0,0],44100)
    #     sf.write('vocal.wav',ref[0,0,0],44100)
    #     sf.write('drum.wav',ref[0,1,0],44100)
    #     sf.write('bass.wav',ref[0,2,0],44100)
    #     sf.write('other.wav',ref[0,3,0],44100)
    #     pass
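    # musdb.DB yields Track objects carrying stereo audio plus per-stem targets; a
    # minimal access sketch using the standard musdb API (values illustrative):
    #   tracks = list(musdb.DB(root="~/MUSDB18/", subsets='train', split='train'))
    #   track = tracks[0]
    #   mix = track.audio                        # (T, 2) stereo mixture at track.rate
    #   vocals = track.targets['vocals'].audio   # (T, 2) isolated vocal stem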

    while True:
        if updates <= config.warmup:  # still in the warmup stage: apply the warmup LR schedule
            tmp_lr = config.learning_rate * min(
                max(updates, 1)**(-0.5),
                max(updates, 1) * (config.warmup**(-1.5)))
            for param_group in optim.optimizer.param_groups:
                param_group['lr'] = tmp_lr
            scheduler.base_lrs = list(
                [group['lr'] for group in optim.optimizer.param_groups])
            if updates % 100 == 0:  # log the warmup progress occasionally
                print(updates)
                print("Warmup learning rate to %g" % tmp_lr)
                lera.log({
                    'lr':
                    [group['lr'] for group in optim.optimizer.param_groups][0],
                })
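        # The expression above is the Noam-style inverse-square-root schedule,
        #   lr(t) = learning_rate * min(t**-0.5, t * warmup**-1.5),
        # which rises linearly for t <= warmup and decays as t**-0.5 afterwards;
        # the two branches meet at t == warmup with lr == learning_rate * warmup**-0.5.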

        train_data = next(train_data_gen)
        if train_data == False:
            print(('SDR_aver_epoch:', SDR_SUM.mean()))
            print(('SDRi_aver_epoch:', SDRi_SUM.mean()))
            break  # this epoch's generator is exhausted; move straight to the next epoch

        padded_mixture, mixture_lengths, padded_source = train_data
        # source:bs,2channel,T  target:bs,4(vocals,drums,bass,other),2channel,T
        padded_mixture = torch.from_numpy(padded_mixture).float()
        topk_this_batch = padded_source.shape[1]
        mixture_lengths = torch.from_numpy(mixture_lengths)
        padded_source = torch.from_numpy(padded_source).float()

        # make sure the tensors below are LongTensors (64-bit integers)

        if use_cuda:
            padded_mixture = padded_mixture.cuda().transpose(0, 1)
            mixture_lengths = mixture_lengths.cuda()
            padded_source = padded_source.cuda()
            # src = src.cuda().transpose(0, 1)
            # tgt = tgt.cuda()
            # src_len = src_len.cuda()
            # tgt_len = tgt_len.cuda()
            # feas_tgt = feas_tgt.cuda()

        if 0 and loss < -5:
            import soundfile as sf
            idx_in_batch = 0
            sf.write(
                str(idx_in_batch) + '_mix.wav',
                padded_mixture.transpose(
                    0, 1).data.cpu().numpy()[idx_in_batch].transpose(), 44100)
            sf.write(
                str(idx_in_batch) + '_ref_vocal.wav',
                padded_source.data.cpu().numpy()[idx_in_batch, 0].transpose(),
                44100)
            sf.write(
                str(idx_in_batch) + '_ref_drum.wav',
                padded_source.data.cpu().numpy()[idx_in_batch, 1].transpose(),
                44100)
            sf.write(
                str(idx_in_batch) + '_ref_bass.wav',
                padded_source.data.cpu().numpy()[idx_in_batch, 2].transpose(),
                44100)
            sf.write(
                str(idx_in_batch) + '_ref_other.wav',
                padded_source.data.cpu().numpy()[idx_in_batch, 3].transpose(),
                44100)

        model.zero_grad()
        outputs, pred, spks_ordre_list, multi_mask, y_map = model(
            None,
            None,
            None,
            None,
            dict_spk2idx,
            None,
            mix_wav=padded_mixture,
            clean_wavs=padded_source.transpose(
                0, 1))  # outputs here are the hidden_outputs: hidden states before the final classification layer, usable directly
        print('mask size:', multi_mask.size())
        print('y map size:', y_map.size())
        # print('spk order:', spks_ordre_list) # bs,topk
        # writer.add_histogram('global gamma',gamma, updates)
        multi_mask = multi_mask.transpose(0, 1)
        y_map = y_map.transpose(0, 1)
        spks_ordre_list = spks_ordre_list.transpose(0, 1)

        # expand the raw mixed-features to topk_max channel.
        topk_max = topk_this_batch  # the maximum possible number of topk channels

        if config.greddy_tf and config.add_last_silence:
            multi_mask, silence_channel = torch.split(multi_mask,
                                                      [topk_this_batch, 1],
                                                      dim=1)
            silence_channel = silence_channel[:, 0]
            assert len(padded_source.shape) == 3
            # padded_source = torch.cat([padded_source,torch.zeros(padded_source.size(0),1,padded_source.size(2))],1)
            if 1 and len(opt.gpus) > 1:
                ss_loss_silence = model.module.silence_loss(silence_channel)
            else:
                ss_loss_silence = model.silence_loss(silence_channel)
            print('loss for SS silence,this batch:',
                  ss_loss_silence.cpu().item())
            writer.add_scalars(
                'scalar/loss',
                {'ss_loss_silence': ss_loss_silence.cpu().item()}, updates)
            lera.log({'ss_loss_silence': ss_loss_silence.cpu().item()})
            if torch.isnan(ss_loss_silence):
                ss_loss_silence = 0

        if config.use_tas:
            # print('source',padded_source)
            # print('est', multi_mask)
            if 1 and len(opt.gpus) > 1:
                ss_loss = model.module.separation_tas_sdr_order_loss(
                    padded_mixture.transpose(0, 1), multi_mask, y_map,
                    mixture_lengths)
            else:
                ss_loss = model.separation_tas_sdr_order_loss(
                    padded_mixture, multi_mask, y_map, mixture_lengths)
            # best_pmt=[list(pmt_list[int(mm)].data.cpu().numpy()) for mm in max_snr_idx]
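        # separation_tas_sdr_order_loss is project-specific; conceptually it is a
        # negative (scale-invariant) SDR objective on the time-domain estimates.
        # A minimal per-source SI-SNR sketch (hypothetical, not the actual loss):
        #   def neg_sisnr(est, ref, eps=1e-8):
        #       est = est - est.mean(dim=-1, keepdim=True)
        #       ref = ref - ref.mean(dim=-1, keepdim=True)
        #       s_t = (est * ref).sum(-1, keepdim=True) * ref / (ref.pow(2).sum(-1, keepdim=True) + eps)
        #       e_n = est - s_t
        #       return -10 * torch.log10(s_t.pow(2).sum(-1) / (e_n.pow(2).sum(-1) + eps) + eps)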

        print('loss for SS,this batch:', ss_loss.cpu().item())
        # print('best perms for this batch:', best_pmt)
        print('greedy perms for this batch:',
              [ii for ii in spks_ordre_list.data.cpu().numpy()])
        writer.add_scalars('scalar/loss', {'ss_loss': ss_loss.cpu().item()},
                           updates)

        loss = ss_loss
        if config.add_last_silence:
            loss = loss + 0.1 * ss_loss_silence
        loss.backward()

        # print 'totallllllllllll loss:',loss
        total_loss_ss += ss_loss.cpu().item()
        lera.log({
            'ss_loss_' + str(topk_this_batch): ss_loss.cpu().item(),
            'loss:': loss.cpu().item(),
            'pre_min': multi_mask.data.cpu().numpy().min(),
            'pre_max': multi_mask.data.cpu().numpy().max(),
        })

        if 1 or loss < -5:
            import soundfile as sf
            idx_in_batch = 0
            y0 = multi_mask.data.cpu().numpy()[idx_in_batch, 0]
            y1 = multi_mask.data.cpu().numpy()[idx_in_batch, 1]
            y2 = multi_mask.data.cpu().numpy()[idx_in_batch, 2]
            y3 = multi_mask.data.cpu().numpy()[idx_in_batch, 3]
            # sf.write(str(idx_in_batch)+'_pre_0.wav',multi_mask.data.cpu().numpy()[idx_in_batch,0].transpose(),44100)
            # sf.write(str(idx_in_batch)+'_pre_1.wav',multi_mask.data.cpu().numpy()[idx_in_batch,1].transpose(),44100)
            # sf.write(str(idx_in_batch)+'_pre_2.wav',multi_mask.data.cpu().numpy()[idx_in_batch,2].transpose(),44100)
            # sf.write(str(idx_in_batch)+'_pre_3.wav',multi_mask.data.cpu().numpy()[idx_in_batch,3].transpose(),44100)
            print('y0 range:', y0.min(), y0.max())
            print('y1 range:', y1.min(), y1.max())
            print('y2 range:', y2.min(), y2.max())
            print('y3 range:', y3.min(), y3.max())
            # input('wait')
            print('*' * 50)

        if 0 and updates > 10 and updates % config.eval_interval in [
                0, 1, 2, 3, 4, 5
        ]:
            utils.bss_eval_tas(config,
                               multi_mask,
                               train_data['multi_spk_fea_list'],
                               raw_tgt,
                               train_data,
                               dst=log_path + '/batch_output1')
            sdr_aver_batch, sdri_aver_batch = bss_test.cal(log_path +
                                                           '/batch_output1/')

            lera.log({'SDR sample': sdr_aver_batch})
            lera.log({'SDRi sample': sdri_aver_batch})
            writer.add_scalars('scalar/loss', {
                'SDR_sample': sdr_aver_batch,
                'SDRi_sample': sdri_aver_batch
            }, updates)
            SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
            SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
            print(('SDR_aver_now:', SDR_SUM.mean()))
            print(('SDRi_aver_now:', SDRi_SUM.mean()))

        total_loss += loss.cpu().item()
        optim.step()

        updates += 1
        if updates % 30 == 0:
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss this batch: %6.3f,sgm loss: %6.6f,ss loss: %6.6f,label acc: %6.6f\n"
                % (time.time() - start_time, epoch, updates, 0,
                   total_loss_sgm / 30.0, total_loss_ss / 30.0, 0))
            # lera.log({'label_acc':report_correct/report_total})
            # writer.add_scalars('scalar/loss',{'label_acc':report_correct/report_total},updates)
            total_loss_sgm, total_loss_ss = 0, 0

        # continue

        if 0 and updates % config.eval_interval == 0 and epoch > 3:  # run at least a few epochs before evaluating, otherwise the model has learned too little and many issues show up
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss: %6.5f\n" %
                (time.time() - start_time, epoch, updates, 0))
            print(('evaluating after %d updates...\r' % updates))
            original_bs = config.batch_size
            score = eval(epoch)  # batch_size becomes 1 during eval
            # print 'Orignal bs:',original_bs
            config.batch_size = original_bs
            # print 'Now bs:',config.batch_size
            for metric in config.metric:
                scores[metric].append(score[metric])
                lera.log({
                    'sgm_micro_f1': score[metric],
                })
                if metric == 'micro_f1' and score[metric] >= max(
                        scores[metric]):
                    save_model(log_path + 'best_' + metric + '_checkpoint.pt')
                if metric == 'hamming_loss' and score[metric] <= min(
                        scores[metric]):
                    save_model(log_path + 'best_' + metric + '_checkpoint.pt')

            model.train()
            total_loss = 0
            start_time = 0
            report_total = 0
            report_correct = 0

        if updates > 10 and updates % config.save_interval == 1:
            save_model(log_path + 'TDAAv4_conditional_{}.pt'.format(updates))
Example #27
                # save_model(log_path+'checkpoint_bestSDR{}.pt'.format(best_SDR))
            print '#' * 30 + 'ReID part ' + '#' * 30

        # '''''
        elif batch_idx <= (5000 / config.batch_size
                           ):  # only the first batches count toward the SDR
            predicted_maps = predicted_masks * x_input_map_multi
            # predicted_maps=Variable(feas_tgt)
            utils.bss_eval(config,
                           predicted_maps,
                           eval_data['multi_spk_fea_list'],
                           raw_tgt,
                           eval_data,
                           dst='batch_output23jo')
            del predicted_maps, predicted_masks, x_input_map_multi
            SDR, SDRi = bss_test.cal('batch_output23jo/')
            # SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output23jo/'))
            SDR_SUM = np.append(SDR_SUM, SDR)
            SDRi_SUM = np.append(SDRi_SUM, SDRi)
            print 'SDR_aver_now:', SDR_SUM.mean()
            print 'SDRi_aver_now:', SDRi_SUM.mean()
            lera.log({'SDR sample': SDR_SUM.mean()})
            lera.log({'SDRi sample': SDRi_SUM.mean()})
        elif batch_idx == (5000 / config.batch_size) + 1 and SDR_SUM.mean(
        ) > best_SDR:  #only record the best SDR once.
            print 'Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean())
            best_SDR = SDR_SUM.mean()
            # save_model(log_path+'checkpoint_bestSDR{}.pt'.format(best_SDR))

        # '''
        candidate += [
Example #28
def eval(epoch):
    model.eval()
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    test_or_valid = 'test'
    print 'Test or valid:', test_or_valid
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX,
                                 config.MAX_MIX)
    # for raw_src, src, src_len, raw_tgt, tgt, tgt_len in validloader:
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    batch_idx = 0
    global best_SDR, Var
    while True:
        # for ___ in range(2):
        print '-' * 30
        eval_data = eval_data_gen.next()
        if eval_data == False:
            break  # this epoch's generator is exhausted; move straight to the next epoch
        src = Variable(torch.from_numpy(eval_data['mix_feas']))

        raw_tgt = [
            sorted(spk.keys()) for spk in eval_data['multi_spk_fea_list']
        ]
        top_k = len(raw_tgt[0])
        # make sure the tensors below are LongTensors (64-bit integers)
        # tgt = Variable(torch.from_numpy(np.array([[0]+[dict_spk2idx[spk] for spk in spks]+[dict_spk2idx['<EOS>']] for spks in raw_tgt],dtype=np.int))).transpose(0,1) # convert to indices, then add start and end tokens
        tgt = Variable(torch.ones(
            top_k + 2, config.batch_size))  # an arbitrary tgt; at test time the target names do not actually matter

        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            len(eval_data['multi_spk_fea_list'][0])).unsqueeze(0)
        feas_tgt = models.rank_feas(raw_tgt,
                                    eval_data['multi_spk_fea_list'])  # the target spectrograms
        if config.WFM:
            tmp_size = feas_tgt.size()
            assert len(tmp_size) == 4
            feas_tgt_square = feas_tgt * feas_tgt
            feas_tgt_square_sum = torch.sum(feas_tgt_square,
                                            dim=1,
                                            keepdim=True).expand(tmp_size)
            WFM_mask = feas_tgt_square / (feas_tgt_square_sum + 1e-10)
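            # i.e. the Wiener-filter-like mask M_i = |S_i|^2 / (sum_j |S_j|^2 + eps):
            # each source's energy over the total source energy in every T-F bin.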

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
            if config.WFM:
                WFM_mask = WFM_mask.cuda()
        try:
            if 1 and len(opt.gpus) > 1:
                # samples, alignment = model.module.sample(src, src_len)
                samples, alignment, hiddens, predicted_masks = model.module.beam_sample(
                    src,
                    src_len,
                    dict_spk2idx,
                    tgt,
                    beam_size=config.beam_size)
            else:
                samples, alignment, hiddens, predicted_masks = model.beam_sample(
                    src,
                    src_len,
                    dict_spk2idx,
                    tgt,
                    beam_size=config.beam_size)
                # samples, alignment, hiddens, predicted_masks = model.beam_sample(src, src_len, dict_spk2idx, tgt, beam_size=config.beam_size)
        except Exception, info:
            print '**************Error eval occurs here************:', info
            continue
        if len(samples[0]) != 3:
            print 'Wrong num of mixtures, passed.'
            continue

        if config.top1:
            predicted_masks = torch.cat([predicted_masks, 1 - predicted_masks],
                                        1)

        # '''
        # expand the raw mixed-features to topk channel.
        src = src.transpose(0, 1)
        siz = src.size()
        assert len(siz) == 3
        topk = feas_tgt.size()[1]
        x_input_map_multi = torch.unsqueeze(src,
                                            1).expand(siz[0], topk, siz[1],
                                                      siz[2])
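        # shape walkthrough: src (bs, len, fre) -> unsqueeze(1) -> (bs, 1, len, fre)
        # -> expand -> (bs, topk, len, fre), so every output channel sees the same mixture.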
        if config.WFM:
            feas_tgt = x_input_map_multi.data * WFM_mask
        if 1 and len(opt.gpus) > 1:
            ss_loss = model.module.separation_loss(x_input_map_multi,
                                                   predicted_masks, feas_tgt,
                                                   None)
        else:
            ss_loss = model.separation_loss(x_input_map_multi, predicted_masks,
                                            feas_tgt, None)
        print 'loss for ss,this batch:', ss_loss.data[0]
        lera.log({
            'ss_loss_' + test_or_valid: ss_loss.data[0],
        })

        del ss_loss, hiddens
        if 0 and config.reID:
            print '#' * 30 + 'ReID part ' + '#' * 30
            predict_multi_map = predicted_masks * x_input_map_multi
            predict_multi_map = predict_multi_map.view(-1, mix_speech_len,
                                                       speech_fre).transpose(
                                                           0, 1)
            tgt_reID = Variable(torch.ones(
                3, top_k * config.batch_size))  # an arbitrary tgt; at test time the target names do not actually matter
            src_len_reID = Variable(
                torch.LongTensor(topk * config.batch_size).zero_() +
                mix_speech_len).unsqueeze(0).cuda()
            try:
                if 1 and len(opt.gpus) > 1:
                    # samples, alignment = model.module.sample(src, src_len)
                    samples, alignment, hiddens, predicted_masks = model.module.beam_sample(
                        predict_multi_map,
                        src_len_reID,
                        dict_spk2idx,
                        tgt_reID,
                        beam_size=config.beam_size)
                else:
                    samples, alignment, hiddens, predicted_masks = model.beam_sample(
                        predict_multi_map,
                        src_len_reID,
                        dict_spk2idx,
                        tgt_reID,
                        beam_size=config.beam_size)
                    # samples, alignment, hiddens, predicted_masks = model.beam_sample(src, src_len, dict_spk2idx, tgt, beam_size=config.beam_size)
            except Exception, info:
                print '**************Error eval occurs here************:', info
            # outputs_reID, targets_reID, multi_mask_reID = model(predict_multi_map, src_len_reID, tgt_reID, tgt_len_reID) # outputs here are the hidden_outputs: hidden states before the final classification layer, usable directly
            if batch_idx <= (500 / config.batch_size
                             ):  # only the first batches count toward the SDR
                # predicted_maps=predicted_masks*x_input_map_multi
                predicted_maps = predicted_masks * predict_multi_map.transpose(
                    0, 1).unsqueeze(1)
                predicted_maps = predicted_maps.transpose(0, 1)
                # predicted_maps=Variable(feas_tgt)
                utils.bss_eval(config,
                               predicted_maps,
                               eval_data['multi_spk_fea_list'],
                               raw_tgt,
                               eval_data,
                               dst='batch_output23jo')
                del predicted_maps, predicted_masks, x_input_map_multi, predict_multi_map
                SDR, SDRi = bss_test.cal('batch_output23jo/')
                # SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output23jo/'))
                SDR_SUM = np.append(SDR_SUM, SDR)
                SDRi_SUM = np.append(SDRi_SUM, SDRi)
                print 'SDR_aver_now:', SDR_SUM.mean()
                print 'SDRi_aver_now:', SDRi_SUM.mean()
                lera.log({'SDR sample': SDR_SUM.mean()})
                lera.log({'SDRi sample': SDRi_SUM.mean()})
            elif batch_idx == (500 / config.batch_size) + 1 and SDR_SUM.mean(
            ) > best_SDR:  #only record the best SDR once.
                print 'Best SDR from {}---->{}'.format(best_SDR,
                                                       SDR_SUM.mean())
                best_SDR = SDR_SUM.mean()
                # save_model(log_path+'checkpoint_bestSDR{}.pt'.format(best_SDR))
            print '#' * 30 + 'ReID part ' + '#' * 30
Example #29
def main():
    print('go to model')
    print '*' * 80

    spk_global_gen = prepare_data(mode='global',
                                  train_or_test='train')  # a fake data generator, handy for drafting the model first
    global_para = spk_global_gen.next()
    print global_para
    spk_all_list, dict_spk2idx, dict_idx2spk, mix_speech_len, speech_fre, total_frames, spk_num_total = global_para
    del spk_global_gen
    num_labels = len(spk_all_list)

    # data_generator=prepare_data('once','train')
    # data_generator=prepare_data_fake(train_or_test='train',num_labels=num_labels) #写一个假的数据生成,可以用来写模型先

    # Order here: mix_speechs.shape, mix_feas.shape, aim_fea.shape, aim_spkid.shape, query.shape
    # An example: (5, 17040) (5, 134, 129) (5, 134, 129) (5,) (5, 32, 400, 300, 3)
    # datasize=prepare_datasize(data_generator)
    # mix_speech_len,speech_fre,total_frames,spk_num_total,video_size=datasize
    print 'Begin to build the main model for the Multi-Modal Cocktail Problem.'
    # data=data_generator.next()

    # This part is to build the 3D mix speech embedding maps.
    mix_hidden_layer_3d = MIX_SPEECH(speech_fre, mix_speech_len).cuda()
    mix_speech_classifier = MIX_SPEECH_classifier(speech_fre, mix_speech_len,
                                                  num_labels).cuda()
    mix_speech_multiEmbedding = SPEECH_EMBEDDING(
        num_labels, config.EMBEDDING_SIZE,
        spk_num_total + config.UNK_SPK_SUPP).cuda()
    print mix_hidden_layer_3d
    print mix_speech_classifier
    # mix_speech_hidden=mix_hidden_layer_3d(Variable(torch.from_numpy(data[1])).cuda())

    # mix_speech_output=mix_speech_classifier(Variable(torch.from_numpy(data[1])).cuda())
    # Trick: with alpha=0 this selects the top_k entries; with a very large top_k it selects the entries greater than alpha
    # top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=config.ALPHA,top_k=config.MAX_MIX)
    # top_k_mask_mixspeech=top_k_mask(mix_speech_output,alpha=config.ALPHA,top_k=3)
    # print top_k_mask_mixspeech
    # mix_speech_multiEmbs=mix_speech_multiEmbedding(top_k_mask_mixspeech) # bs x num_labels (max number of mixed speakers) x embedding size
    # mix_speech_multiEmbs=mix_speech_multiEmbedding(Variable(torch.from_numpy(top_k_mask_mixspeech),requires_grad=False).cuda()) # bs x num_labels (max number of mixed speakers) x embedding size
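    # A minimal sketch of such a threshold/top-k mask (hypothetical helper; the
    # actual top_k_mask used below is defined elsewhere in this project):
    #   def top_k_mask_sketch(scores, alpha, top_k):
    #       mask = torch.zeros_like(scores)           # scores: (bs, num_labels)
    #       _, idx = scores.topk(top_k, dim=1)
    #       mask.scatter_(1, idx, 1.0)                # keep the top-k entries...
    #       return mask * (scores > alpha).float()    # ...that also exceed alpha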

    # Needed: the attention between mix_speech_hidden[bs,len,fre,emb] and mix_mulEmbedding[bs,num_labels,EMB].
    # Expand the former to bs*num_labels x ..., do the same with the latter, run the ATT function, then reshape back.
    # mix_speech_hidden_5d=mix_speech_hidden.view(config.BATCH_SIZE,1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
    # mix_speech_hidden_5d=mix_speech_hidden_5d.expand(config.BATCH_SIZE,num_labels,mix_speech_len,speech_fre,config.EMBEDDING_SIZE).contiguous()
    # mix_speech_hidden_5d=mix_speech_hidden_5d.view(-1,mix_speech_len,speech_fre,config.EMBEDDING_SIZE)
    # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
    # att_multi_speech=att_speech_layer(mix_speech_hidden_5d,mix_speech_multiEmbs.view(-1,config.EMBEDDING_SIZE))
    # print att_multi_speech.size()
    # att_multi_speech=att_multi_speech.view(config.BATCH_SIZE,num_labels,mix_speech_len,speech_fre,-1)
    # print att_multi_speech.size()

    # This part handles the video inputs.
    # query_video_layer=VIDEO_QUERY(total_frames,config.VideoSize,spk_num_total).cuda()
    query_video_layer = None
    # print query_video_layer
    # query_video_output,xx=query_video_layer(Variable(torch.from_numpy(data[4])))

    # This part handles the memory.
    # hidden_size=(config.HIDDEN_UNITS)
    hidden_size = (config.EMBEDDING_SIZE)
    # x=torch.arange(0,24).view(2,3,4)
    # y=torch.ones([2,4])
    att_layer = ATTENTION(config.EMBEDDING_SIZE, 'align').cuda()
    att_speech_layer = ATTENTION(config.EMBEDDING_SIZE, 'align').cuda()
    # att=ATTENTION(4,'align')
    # mask=att(x,y)#bs*max_len

    # del data_generator
    # del data

    optimizer = torch.optim.Adam(
        [
            {
                'params': mix_hidden_layer_3d.parameters()
            },
            {
                'params': mix_speech_multiEmbedding.parameters()
            },
            {
                'params': mix_speech_classifier.parameters()
            },
            # {'params':query_video_layer.lstm_layer.parameters()},
            # {'params':query_video_layer.dense.parameters()},
            # {'params':query_video_layer.Linear.parameters()},
            {
                'params': att_layer.parameters()
            },
            {
                'params': att_speech_layer.parameters()
            },
            # ], lr=0.02,momentum=0.9)
        ],
        lr=0.00005)
    if 1 and config.Load_param:
        # query_video_layer.load_state_dict(torch.load('param_video_layer_19'))
        # mix_speech_classifier.load_state_dict(torch.load('params/param_speech_123onezeroag3_WSJ0_multilabel_epoch40'))
        mix_speech_classifier.load_state_dict(
            torch.load(
                'params/param_speech_123onezeroag4_WSJ0_multilabel_epoch70'))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_WSJ0_hidden3d_180'))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_WSJ0_emblayer_180'))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix101_WSJ0_attlayer_180'))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_dbag1nosum_WSJ0_hidden3d_350',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_dbag1nosum_WSJ0_emblayer_350',map_location={'cuda:1':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix101_dbag1nosum_WSJ0_attlayer_350',map_location={'cuda:1':'cuda:0'}))

        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix2or3_db_WSJ0_hidden3d_560',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix2or3_db_WSJ0_emblayer_560',map_location={'cuda:1':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix2or3_db_WSJ0_attlayer_560',map_location={'cuda:1':'cuda:0'}))

        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_dbag2sum_WSJ0_hidden3d_460',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_dbag2sum_WSJ0_emblayer_460',map_location={'cuda:1':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix101_dbag2sum_WSJ0_attlayer_460',map_location={'cuda:1':'cuda:0'}))

        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_dbdropout_WSJ0_hidden3d_370',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_dbdropout_WSJ0_emblayer_370',map_location={'cuda:1':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix101_dbdropout_WSJ0_attlayer_370',map_location={'cuda:1':'cuda:0'}))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_4db_WSJ0_hidden3d_110',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_4db_WSJ0_emblayer_110',map_location={'cuda:1':'cuda:0'}))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix101_4db_WSJ0_attlayer_110',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_classifier.load_state_dict(torch.load('params/param_speech_4lstm_multilabelloss30map_epoch440'))
        # att_speech_layer.load_state_dict(torch.load('params/param_mix101_dbdropoutag_WSJ0_attlayer_220',map_location={'cuda:1':'cuda:0'}))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix101_dbdropoutag_WSJ0_hidden3d_220',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix101_dbdropoutag_WSJ0_emblayer_220',map_location={'cuda:1':'cuda:0'}))

        # att_speech_layer.load_state_dict(torch.load('params/param_mix2_db2dropout_WSJ0_attlayer_495',map_location={'cuda:1':'cuda:0'}))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix2_db2dropout_WSJ0_hidden3d_495',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix2_db2dropout_WSJ0_emblayer_495',map_location={'cuda:1':'cuda:0'}))
        att_speech_layer.load_state_dict(
            torch.load('params/param_mix2_db2dropout_WSJ0_attlayer_90',
                       map_location={'cuda:1': 'cuda:0'}))
        mix_hidden_layer_3d.load_state_dict(
            torch.load('params/param_mix2_db2dropout_WSJ0_hidden3d_90',
                       map_location={'cuda:1': 'cuda:0'}))
        mix_speech_multiEmbedding.load_state_dict(
            torch.load('params/param_mix2_db2dropout_WSJ0_emblayer_90',
                       map_location={'cuda:1': 'cuda:0'}))

        # att_speech_layer.load_state_dict(torch.load('params/param_mix1to3_dbdropoutag1_WSJ0_attlayer_430',map_location={'cuda:1':'cuda:0'}))
        # mix_hidden_layer_3d.load_state_dict(torch.load('params/param_mix1to3_dbdropoutag1_WSJ0_hidden3d_430',map_location={'cuda:1':'cuda:0'}))
        # mix_speech_multiEmbedding.load_state_dict(torch.load('params/param_mix1to3_dbdropoutag1_WSJ0_emblayer_430',map_location={'cuda:1':'cuda:0'}))
    loss_func = torch.nn.MSELoss()  # the target label is NOT one-hot encoded
    loss_multi_func = torch.nn.MSELoss(
    )  # the target label is NOT one-hot encoded
    # loss_multi_func = torch.nn.L1Loss()  # the target label is NOT one-hot encoded
    loss_query_class = torch.nn.CrossEntropyLoss()

    print '''Begin to calculate.'''
    for epoch_idx in range(config.MAX_EPOCH):
        if epoch_idx > 0:
            print 'SDR_SUM (len:{}) for epoch {} : {}'.format(
                SDR_SUM.shape, epoch_idx - 1, SDR_SUM.mean())
        SDR_SUM = np.array([])
        # print_memory_state(memory.memory)
        for batch_idx in range(config.EPOCH_SIZE):
            print '*' * 40, epoch_idx, batch_idx, '*' * 40
            train_data_gen = prepare_data('once', 'train')
            # train_data_gen=prepare_data('once','test')
            # train_data_gen=prepare_data('once','eval_test')
            train_data = train_data_gen.next()
            '''3D representation layer (len, fre, Emb) for the mixture speech'''
            mix_speech_hidden = mix_hidden_layer_3d(
                Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # The video part is disabled for now, because the video data for s2/s3/s4 is still incomplete
            '''Speech self-separation part'''
            mix_speech_output = mix_speech_classifier(
                Variable(torch.from_numpy(train_data['mix_feas'])).cuda())
            # Get the ground-truth speaker names and vectors from the data
            y_spk_list = [
                one.keys() for one in train_data['multi_spk_fea_list']
            ]
            y_spk_list = train_data['multi_spk_fea_list']
            y_spk_gtruth, y_map_gtruth = multi_label_vector(
                y_spk_list, dict_spk2idx)
            # If the training stage uses the ground-truth separation results for supervision
            if 1 and config.Ground_truth:
                mix_speech_output = Variable(
                    torch.from_numpy(y_map_gtruth)).cuda()
                if 0 and test_all_outputchannel:  # set the input mask to all ones, to test outputting every channel
                    mix_speech_output = Variable(
                        torch.ones(
                            config.BATCH_SIZE,
                            num_labels,
                        ))
                    y_map_gtruth = np.ones([config.BATCH_SIZE, num_labels])

            max_num_labels = 2
            top_k_mask_mixspeech, top_k_sort_index = top_k_mask(
                mix_speech_output, alpha=-0.5,
                top_k=max_num_labels)  # a torch.FloatTensor
            top_k_mask_idx = [
                np.where(line == 1)[0]
                for line in top_k_mask_mixspeech.numpy()
            ]
            mix_speech_multiEmbs = mix_speech_multiEmbedding(
                top_k_mask_mixspeech,
                top_k_mask_idx)  # bs x num_labels (max number of mixed speakers) x embedding size

            assert len(top_k_mask_idx[0]) == len(top_k_mask_idx[-1])
            top_k_num = len(top_k_mask_idx[0])

            # Needed: the attention between mix_speech_hidden[bs,len,fre,emb] and mix_mulEmbedding[bs,num_labels,EMB].
            # Expand the former to bs*num_labels x ..., do the same with the latter, run the ATT function, then reshape back.
            mix_speech_hidden_5d = mix_speech_hidden.view(
                config.BATCH_SIZE, 1, mix_speech_len, speech_fre,
                config.EMBEDDING_SIZE)
            mix_speech_hidden_5d = mix_speech_hidden_5d.expand(
                config.BATCH_SIZE, top_k_num, mix_speech_len, speech_fre,
                config.EMBEDDING_SIZE).contiguous()
            mix_speech_hidden_5d_last = mix_speech_hidden_5d.view(
                -1, mix_speech_len, speech_fre, config.EMBEDDING_SIZE)
            # att_speech_layer=ATTENTION(config.EMBEDDING_SIZE,'align').cuda()
            att_speech_layer = ATTENTION(config.EMBEDDING_SIZE, 'dot').cuda()
            att_multi_speech = att_speech_layer(
                mix_speech_hidden_5d_last,
                mix_speech_multiEmbs.view(-1, config.EMBEDDING_SIZE))
            # print att_multi_speech.size()
            att_multi_speech = att_multi_speech.view(
                config.BATCH_SIZE, top_k_num, mix_speech_len,
                speech_fre)  # shape: bs,num_labels,len,fre
            # print att_multi_speech.size()
            multi_mask = att_multi_speech
            # top_k_mask_mixspeech_multi=top_k_mask_mixspeech.view(config.BATCH_SIZE,top_k_num,1,1).expand(config.BATCH_SIZE,top_k_num,mix_speech_len,speech_fre)
            # multi_mask=multi_mask*Variable(top_k_mask_mixspeech_multi).cuda()

            x_input_map = Variable(torch.from_numpy(
                train_data['mix_feas'])).cuda()
            # print x_input_map.size()
            x_input_map_multi = x_input_map.view(
                config.BATCH_SIZE, 1, mix_speech_len,
                speech_fre).expand(config.BATCH_SIZE, top_k_num,
                                   mix_speech_len, speech_fre)
            # predict_multi_map=multi_mask*x_input_map_multi
            predict_multi_map = multi_mask * x_input_map_multi

            bss_eval_fromGenMap(multi_mask, x_input_map, top_k_mask_mixspeech,
                                dict_idx2spk, train_data, top_k_sort_index)
            SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output/', 2))
            print 'SDR_SUM (len:{}) for epoch {} : {}'.format(
                SDR_SUM.shape, epoch_idx, SDR_SUM.mean())
Example #30
def eval(epoch):
    model.eval()
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    test_or_valid = 'valid'
    print 'Test or valid:', test_or_valid
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX,
                                 config.MAX_MIX)
    # for raw_src, src, src_len, raw_tgt, tgt, tgt_len in validloader:
    SDR_SUM = np.array([])
    batch_idx = 0
    global best_SDR
    while True:
        # for ___ in range(2):
        print '-' * 30
        eval_data = eval_data_gen.next()
        if eval_data == False:
            break  # this epoch's generator is exhausted; move straight to the next epoch
        src = Variable(torch.from_numpy(eval_data['mix_feas']))

        raw_tgt = [
            sorted(spk.keys()) for spk in eval_data['multi_spk_fea_list']
        ]
        top_k = len(raw_tgt[0])
        # make sure the tensors below are LongTensors (64-bit integers)
        # tgt = Variable(torch.from_numpy(np.array([[0]+[dict_spk2idx[spk] for spk in spks]+[dict_spk2idx['<EOS>']] for spks in raw_tgt],dtype=np.int))).transpose(0,1) # convert to indices, then add start and end tokens
        tgt = Variable(torch.ones(
            top_k + 2, config.batch_size))  # an arbitrary tgt; at test time the target names do not actually matter

        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            len(eval_data['multi_spk_fea_list'][0])).unsqueeze(0)
        feas_tgt = models.rank_feas(raw_tgt,
                                    eval_data['multi_spk_fea_list'])  # the target spectrograms
        relitu(mix_speech_len, speech_fre, feas_tgt.numpy()[0, 0].transpose())
        relitu(mix_speech_len, speech_fre, feas_tgt.numpy()[0, 1].transpose())
        # 1/0

        if config.WFM:
            tmp_size = feas_tgt.size()
            assert len(tmp_size) == 4
            feas_tgt_sum = torch.sum(feas_tgt, dim=1, keepdim=True)
            feas_tgt_sum_square = (feas_tgt_sum *
                                   feas_tgt_sum).expand(tmp_size)
            feas_tgt_square = feas_tgt * feas_tgt
            WFM_mask = feas_tgt_square / feas_tgt_sum_square
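            # Note this variant computes M_i = |S_i|^2 / (sum_j |S_j|)^2 -- the squared
            # sum of magnitudes in the denominator -- unlike the per-bin energy ratio
            # (and eps guard) used in the WFM branch of the earlier eval example.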

        if use_cuda:
            src = src.cuda()
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
            if config.WFM:
                WFM_mask = WFM_mask.cuda()
        if len(opt.gpus) > 1:
            samples, alignment = model.module.sample(src, src_len)
        else:
            samples, alignment, hiddens, predicted_masks = model.beam_sample(
                src, src_len, dict_spk2idx, tgt, beam_size=config.beam_size)
            # try:
            #     samples, alignment, hiddens, predicted_masks = model.beam_sample(src, src_len, dict_spk2idx, tgt, beam_size=config.beam_size)
            # except Exception,info:
            #     print '**************Error occurs here************:', info
            #     continue

        if config.top1:
            predicted_masks = torch.cat([predicted_masks, 1 - predicted_masks],
                                        1)

        # '''
        # expand the raw mixed-features to topk channel.
        siz = src.size()
        assert len(siz) == 3
        topk = feas_tgt.size()[1]
        x_input_map_multi = torch.unsqueeze(src,
                                            1).expand(siz[0], topk, siz[1],
                                                      siz[2])
        if config.WFM:
            feas_tgt = x_input_map_multi.data * WFM_mask
        ss_loss = model.separation_loss(x_input_map_multi, predicted_masks,
                                        feas_tgt)
        print 'loss for ss,this batch:', ss_loss.data[0]
        del ss_loss, hiddens

        # '''''
        if batch_idx <= (3000 / config.batch_size
                         ):  # only the first batches count toward the SDR
            predicted_maps = predicted_masks * x_input_map_multi
            # predicted_maps=Variable(feas_tgt)
            utils.bss_eval(config,
                           predicted_maps,
                           eval_data['multi_spk_fea_list'],
                           raw_tgt,
                           eval_data,
                           dst='batch_output23jo')
            del predicted_maps, predicted_masks, x_input_map_multi
            SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output23jo/'))
            print 'SDR_aver_now:', SDR_SUM.mean()
            # 1/0
            raw_input('Press any key to continue......')
            continue
        elif batch_idx == (3000 / config.batch_size) + 1 and SDR_SUM.mean(
        ) > best_SDR:  #only record the best SDR once.
            print 'Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean())
            best_SDR = SDR_SUM.mean()
            # save_model(log_path+'checkpoint_bestSDR{}.pt'.format(best_SDR))

        # '''
        candidate += [
            convertToLabels(dict_idx2spk, s, dict_spk2idx['<EOS>'])
            for s in samples
        ]
        # source += raw_src
        reference += raw_tgt
        print 'samples:', samples
        print 'can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:],
                                        reference[-1 * config.batch_size:])
        alignments += [align for align in alignment]
        batch_idx += 1

    if opt.unk:
        cands = []
        for s, c, align in zip(source, candidate, alignments):
            cand = []
            for word, idx in zip(c, align):
                if word == dict.UNK_WORD and idx < len(s):
                    try:
                        cand.append(s[idx])
                    except:
                        cand.append(word)
                        print("%d %d\n" % (len(s), idx))
                else:
                    cand.append(word)
            cands.append(cand)
        candidate = cands

    score = {}
    result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path)
    logging_csv([e, updates, result['hamming_loss'], \
                result['micro_f1'], result['micro_precision'], result['micro_recall']])
    print('hamming_loss: %.8f | micro_f1: %.4f' %
          (result['hamming_loss'], result['micro_f1']))
    score['hamming_loss'] = result['hamming_loss']
    score['micro_f1'] = result['micro_f1']
    return score