Ejemplo n.º 1
0
def eval(epoch):
    """Decode the test set, optionally restore UNK tokens, and log metrics.

    Runs ``model`` over every batch of the module-level ``testloader`` and
    converts the predicted ids to labels with ``tgt_vocab.convertToLabels``.
    When ``opt.unk`` is set, every predicted ``dict.UNK_WORD`` is replaced by
    the source token at the position given by the corresponding alignment
    entry (if that position lies inside the source). Scores from
    ``utils.eval_metrics`` are written with ``logging_csv`` and printed.

    Args:
        epoch: Current epoch number; not used inside this function.

    Returns:
        None. Results are only logged and printed.
    """
    model.eval()
    reference, candidate, source, alignments = [], [], [], []
    for raw_src, src, src_len, raw_tgt, tgt, tgt_len in testloader:
        if len(opt.gpus) > 1:
            # Multi-GPU: `model` is presumably an nn.DataParallel wrapper, so
            # the custom `sample` method lives on `model.module`.
            samples, alignment = model.module.sample(src, src_len)
        else:
            samples, alignment = model.beam_sample(src, src_len, beam_size=config.beam_size)

        candidate += [tgt_vocab.convertToLabels(s, dict.EOS) for s in samples]
        source += raw_src
        reference += raw_tgt
        alignments.extend(alignment)

    if opt.unk:
        cands = []
        for s, c, align in zip(source, candidate, alignments):
            cand = []
            for word, idx in zip(c, align):
                if word == dict.UNK_WORD and idx < len(s):
                    try:
                        cand.append(s[idx])
                    except (IndexError, TypeError):
                        # Unusable position (e.g. negative or non-integer):
                        # keep the UNK and report the offending pair.
                        cand.append(word)
                        print("%d %d\n" % (len(s), idx))
                else:
                    cand.append(word)
            cands.append(cand)
        candidate = cands

    result = utils.eval_metrics(reference, candidate, label_dict, log_path)
    logging_csv([result['hamming_loss'], result['micro_f1'], result['micro_precision'], result['micro_recall']])
    print('hamming_loss: %.8f | micro_f1: %.4f'
          % (result['hamming_loss'], result['micro_f1']))
Ejemplo n.º 2
0
def eval(epoch):
    """Decode the validation set, restore UNKs, save candidates and log metrics.

    Runs ``model`` over ``validloader`` (``model.sample`` when
    ``config.beam_size == 1``, otherwise ``model.beam_sample``) and converts
    the predicted ids to labels. When ``opt.unk`` is set, every predicted
    ``dict.UNK_WORD`` is replaced by the source token at its aligned position
    (if that position lies inside the source). The candidates are written to
    ``log_path + 'candidate.txt'``, one space-joined line each, and the scores
    from ``utils.eval_metrics`` are logged and printed.

    Args:
        epoch: Current epoch; not used directly (the CSV row uses the
            module-level ``e`` and ``updates``).

    Returns:
        dict with ``'hamming_loss'``, ``'macro_f1'`` and ``'micro_f1'``.
    """
    model.eval()
    reference, candidate, source, alignments = [], [], [], []
    for raw_src, src, src_len, raw_tgt, tgt, tgt_len in validloader:
        if config.beam_size == 1:
            # `sample` presumably returns (samples, alignment) like
            # `beam_sample`; the alignment is needed below and was
            # previously left unbound on this path.
            samples, alignment = model.sample(src, src_len)
        else:
            samples, alignment = model.beam_sample(src, src_len, beam_size=config.beam_size)

        # `candidate` is a list of predicted samples, each a list of labels
        # (words, not ids). Per the original author, <BOS>, <EOS> and <PAD>
        # are not included. `source` has the same nested-list layout.
        candidate += [tgt_vocab.convertToLabels(s, dict.EOS) for s in samples]
        source += raw_src
        reference += raw_tgt
        alignments.extend(alignment)

    # Replace each predicted UNK with the source token at its aligned
    # position (usually enabled).
    if opt.unk:
        cands = []
        for s, c, align in zip(source, candidate, alignments):
            cand = []
            for word, idx in zip(c, align):
                if word == dict.UNK_WORD and idx < len(s):
                    try:
                        cand.append(s[idx])
                    except (IndexError, TypeError):
                        # Unusable position: keep the UNK and report it.
                        cand.append(word)
                        print("%d %d\n" % (len(s), idx))
                else:
                    cand.append(word)
            cands.append(cand)
        candidate = cands

    with codecs.open(log_path+'candidate.txt','w+','utf-8') as f:
        for cand in candidate:
            f.write(" ".join(cand)+'\n')

    score = {}
    result = utils.eval_metrics(reference, candidate, label_dict, log_path)
    logging_csv([e, updates, result['hamming_loss'],
                 result['macro_f1'], result['macro_precision'], result['macro_recall'],
                 result['micro_f1'], result['micro_precision'], result['micro_recall']])
    print('hamming_loss: %.8f | macro_f1: %.4f | micro_f1: %.4f'
          % (result['hamming_loss'], result['macro_f1'], result['micro_f1']))
    score['hamming_loss'] = result['hamming_loss']
    score['macro_f1'] = result['macro_f1']
    score['micro_f1'] = result['micro_f1']
    return score
Ejemplo n.º 3
0
def eval():
    """Evaluate the classifier on ``valloader`` and log the scores.

    Predictions are the arg-max over dim 1 of ``model(bx)``; targets are the
    arg-max over dim 1 of ``y`` (presumably one-hot labels — confirm). Scores
    from ``utils.eval_metrics`` are written with ``logging_csv`` together with
    the module-level ``e``, ``updates`` and ``loss``.

    Returns:
        dict with ``'accuracy'`` and ``'f1'``.
    """
    model.eval()
    y_true, y_pred = [], []
    # Evaluation needs no autograd graph.
    with torch.no_grad():
        for x_list, y in valloader:
            bx = [Variable(x).type(torch.LongTensor) for x in x_list]
            if use_cuda:
                bx = [x.cuda() for x in bx]
            y_pre = model(bx)
            y_label = torch.max(y_pre, 1)[1].data
            # Targets are read from the CPU tensor `y` directly.
            y_true.extend(torch.max(y, 1)[1].tolist())
            y_pred.extend(y_label.tolist())

    # `loss` is the module-level training loss. `.item()` replaces
    # `.data[0]`, which raises on the 0-dim loss tensors of PyTorch >= 0.4.
    train_loss = loss.item()

    score = {}
    result = utils.eval_metrics(y_pred, y_true)
    logging_csv([e, updates, train_loss, result['accuracy'], result['f1'], result['precision'], result['recall']])
    print('Epoch: %d | Updates: %d | Train loss: %.4f | Accuracy: %.4f | F1: %.4f | Precision: %.4f | Recall: %.4f'
          % (e, updates, train_loss, result['accuracy'], result['f1'], result['precision'], result['recall']))
    score['accuracy'] = result['accuracy']
    score['f1'] = result['f1']

    return score
Ejemplo n.º 4
0
def eval(epoch):
    """Evaluate speaker prediction and separation quality on one data split.

    Pulls batches from ``prepare_data('once', test_or_valid, ...)`` until it
    yields ``False``. For each batch, ``model(...)`` is run once with a
    placeholder target, and the arg-max over the last dim of ``pred`` gives
    the predicted speaker ids (``samples``).

    * For the first ``500 / config.batch_size`` batches, ``utils.bss_eval_tas``
      writes the separated outputs to ``log_path + 'batch_output/'`` and
      ``bss_test.cal`` scores them (SDR / SDRi).
    * On the batch right after that window, the module-global ``best_SDR`` is
      raised if the running mean SDR beats it.
    * When ``test_or_valid != 'test'`` and ``config.use_tas`` is set, the
      separation loss is logged.

    Running ``utils.eval_metrics`` scores are printed after every batch.

    Args:
        epoch: Epoch number, written to the CSV log.

    Returns:
        dict with ``'hamming_loss'`` and ``'micro_f1'``.
    """
    global best_SDR, Var
    model.eval()
    # The original author recommends config.batch_size == 1 for testing.
    reference, candidate = [], []
    e = epoch
    test_or_valid = 'test'
    # test_or_valid = 'valid'
    print(('Test or valid:', test_or_valid))
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX,
                                 config.MAX_MIX)
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    batch_idx = 0
    # Only the first `sdr_batches` batches are scored for SDR. The best-SDR
    # check runs on the batch right after that window. It previously used
    # 200 / batch_size, which fell inside the scoring window for batch sizes
    # up to 300, so it could never run.
    sdr_batches = 500 / config.batch_size
    while True:
        print(('-' * 30))
        eval_data = next(eval_data_gen)
        if eval_data == False:
            print(('SDR_aver_eval_epoch:', SDR_SUM.mean()))
            print(('SDRi_aver_eval_epoch:', SDRi_SUM.mean()))
            break  # The generator has no more data for this epoch.
        src = Variable(torch.from_numpy(eval_data['mix_feas']))

        raw_tgt = eval_data['batch_order']
        # Target spectrograms (per the original author's note).
        feas_tgt = models.rank_feas(
            raw_tgt, eval_data['multi_spk_fea_list'])

        top_k = len(raw_tgt[0])
        # Placeholder target ids [0, 1, ..., top_k, 102] per sample, transposed
        # to (seq_len, batch). Per the original author these are speaker
        # indices wrapped in start/end symbols; 102 is presumably the end id
        # (TODO confirm). `dtype=int` is what the removed alias `np.int` meant.
        tgt = Variable(
            torch.from_numpy(
                np.array([
                    list(range(top_k + 1)) + [102]
                    for __ in range(config.batch_size)
                ],
                         dtype=int))).transpose(0, 1)

        padded_mixture, mixture_lengths, padded_source = eval_data['tas_zip']
        padded_mixture = torch.from_numpy(padded_mixture).float()
        mixture_lengths = torch.from_numpy(mixture_lengths)
        padded_source = torch.from_numpy(padded_source).float()

        padded_mixture = padded_mixture.cuda().transpose(0, 1)
        mixture_lengths = mixture_lengths.cuda()
        padded_source = padded_source.cuda()

        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor([
                len(one_spk) for one_spk in eval_data['multi_spk_fea_list']
            ])).unsqueeze(0)
        if config.WFM:
            tmp_size = feas_tgt.size()
            assert len(tmp_size) == 3
            feas_tgt_square = feas_tgt * feas_tgt
            feas_tgt_sum_square = torch.sum(feas_tgt_square,
                                            dim=0,
                                            keepdim=True).expand(tmp_size)
            WFM_mask = feas_tgt_square / (feas_tgt_sum_square + 1e-15)

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
            if config.WFM:
                WFM_mask = WFM_mask.cuda()

        # Single- and multi-GPU runs call the model identically. Per the
        # original author, `outputs` are the hidden outputs before the final
        # classification layer.
        outputs, pred, targets, multi_mask, dec_enc_attn_list = model(
            src,
            src_len,
            tgt,
            tgt_len,
            dict_spk2idx,
            None,
            mix_wav=padded_mixture)
        samples = list(
            pred.view(config.batch_size, top_k + 1,
                      -1).max(2)[1].data.cpu().numpy())

        # Expand the raw mixed features to topk_max channels.
        src = src.transpose(0, 1)
        siz = src.size()
        assert len(siz) == 3
        topk_max = top_k
        x_input_map_multi = torch.unsqueeze(src,
                                            1).expand(siz[0], topk_max, siz[1],
                                                      siz[2])
        multi_mask = multi_mask.transpose(0, 1)

        if test_or_valid != 'test' and config.use_tas:
            if len(opt.gpus) > 1:
                ss_loss, pmt_list, max_snr_idx, *__ = model.module.separation_tas_loss(
                    padded_mixture, multi_mask, padded_source,
                    mixture_lengths)
            else:
                ss_loss, pmt_list, max_snr_idx, *__ = model.separation_tas_loss(
                    padded_mixture, multi_mask, padded_source,
                    mixture_lengths)
            # Previously these ran even without use_tas, where ss_loss was unbound.
            print(('loss for ss,this batch:', ss_loss.cpu().item()))
            lera.log({
                'ss_loss_' + test_or_valid: ss_loss.cpu().item(),
            })
            del ss_loss

        if batch_idx <= sdr_batches:  # only the former batches count toward SDR
            utils.bss_eval_tas(config,
                               multi_mask,
                               eval_data['multi_spk_fea_list'],
                               raw_tgt,
                               eval_data,
                               dst=log_path + 'batch_output/')
            del multi_mask, x_input_map_multi
            try:
                sdr_aver_batch, sdri_aver_batch = bss_test.cal(log_path +
                                                               'batch_output/')
                SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
                SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
            except AssertionError:
                print('Errors in calculating the SDR')
                # Nothing to plot for this batch (was a NameError on batch 0).
                sdr_aver_batch = None
            print(('SDR_aver_now:', SDR_SUM.mean()))
            print(('SDRi_aver_now:', SDRi_SUM.mean()))
            lera.log({'SDR sample' + test_or_valid: SDR_SUM.mean()})
            lera.log({'SDRi sample' + test_or_valid: SDRi_SUM.mean()})
            if sdr_aver_batch is not None:
                writer.add_scalars('scalar/loss',
                                   {'SDR_sample_' + test_or_valid: sdr_aver_batch},
                                   updates)
        elif batch_idx == sdr_batches + 1 and SDR_SUM.mean(
        ) > best_SDR:  # only record the best SDR once.
            print(('Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean())))
            best_SDR = SDR_SUM.mean()

        candidate += [
            convertToLabels(dict_idx2spk, s, dict_spk2idx['<EOS>'])
            for s in samples
        ]
        reference += raw_tgt
        print(('samples:', samples))
        print(('can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:],
                                         reference[-1 * config.batch_size:])))
        batch_idx += 1

        # Running scores over everything decoded so far.
        result = utils.eval_metrics(reference, candidate, dict_spk2idx,
                                    log_path)
        print((
            'hamming_loss: %.8f | micro_f1: %.4f |recall: %.4f | precision: %.4f'
            % (
                result['hamming_loss'],
                result['micro_f1'],
                result['micro_recall'],
                result['micro_precision'],
            )))

    score = {}
    result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path)
    logging_csv([e, updates, result['hamming_loss'],
                 result['micro_f1'], result['micro_precision'], result['micro_recall'], SDR_SUM.mean()])
    print(('hamming_loss: %.8f | micro_f1: %.4f' %
           (result['hamming_loss'], result['micro_f1'])))
    score['hamming_loss'] = result['hamming_loss']
    score['micro_f1'] = result['micro_f1']
    # The unconditional `1 / 0` debugging stop that made this return
    # unreachable has been removed.
    return score
Ejemplo n.º 5
0
        # '''
        candidate += [
            convertToLabels(dict_idx2spk, s, dict_spk2idx['<EOS>'])
            for s in samples
        ]
        # source += raw_src
        reference += raw_tgt
        print('samples:', samples)
        print('can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:],
                                        reference[-1 * config.batch_size:]))
        alignments += [align for align in alignment]
        batch_idx += 1

    score = {}
    result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path)
    logging_csv([e, updates, result['hamming_loss'], \
                 result['micro_f1'], result['micro_precision'], result['micro_recall']])
    print('hamming_loss: %.8f | micro_f1: %.4f' %
          (result['hamming_loss'], result['micro_f1']))
    score['hamming_loss'] = result['hamming_loss']
    score['micro_f1'] = result['micro_f1']
    return score


# Convert `idx` to labels. If index `stop` is reached, convert it and return.
def convertToLabels(dict, idx, stop):
    labels = []

    for i in idx:
        i = int(i)
Ejemplo n.º 6
0
def eval_recu(epoch):
    """Evaluate by extracting speakers one at a time over several decoder passes.

    Requires ``config.batch_size == 1`` (asserted). For each batch from
    ``prepare_data('once', ...)``, ``beam_sample`` is called once per value of
    ``len_idx`` in ``range(config.MIN_MIX + 2, 2, -1)``. In the
    ``len_idx == 4`` pass, the first speaker's predicted mask is removed from
    the mixture features (``src * (1 - mask)``) before the next pass. Speaker
    predictions are accumulated for ``utils.eval_metrics``. For the first
    ``500 / config.batch_size`` batches, SDR/SDRi are computed via
    ``utils.bss_eval2`` and ``bss_test.cal``.

    NOTE(review): the function ends with an unconditional ``1 / 0``, so it
    always raises ZeroDivisionError and ``return score`` is unreachable. It
    also blocks on ``input()`` after every batch. Both look like leftover
    debugging; confirm before calling this from a training loop.

    Args:
        epoch: Epoch number, written to the CSV log.
    """
    assert config.batch_size == 1
    model.eval()
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    test_or_valid = 'test'
    # NOTE(review): this overrides the line above; 'valid' is what gets evaluated.
    test_or_valid = 'valid'
    # test_or_valid = 'train'
    print(('Test or valid:', test_or_valid))
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX,
                                 config.MAX_MIX)
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    batch_idx = 0
    global best_SDR, Var
    while True:
        print(('-' * 30))
        eval_data = next(eval_data_gen)
        if eval_data == False:
            print(('SDR_aver_eval_epoch:', SDR_SUM.mean()))
            print(('SDRi_aver_eval_epoch:', SDRi_SUM.mean()))
            break  # The generator has no more data for this epoch.
        src = Variable(torch.from_numpy(eval_data['mix_feas']))

        # raw_tgt = [sorted(spk.keys()) for spk in eval_data['multi_spk_fea_list']]
        raw_tgt = eval_data['batch_order']
        feas_tgt = models.rank_feas(
            raw_tgt, eval_data['multi_spk_fea_list'])  # the target spectrograms

        src_original = src.transpose(0, 1)  # to T,bs,F (per original note)
        predict_multi_mask_all = None
        samples_list = []
        for len_idx in range(config.MIN_MIX + 2, 2, -1):  # separate speakers one at a time
            # NOTE(review): range() stops before 2, so for MIN_MIX == 2 this
            # loop runs with len_idx = 4 and 3 only.
            tgt_max_len = len_idx  # 4,3,2 with bos and eos.
            topk_k = len_idx - 2
            tgt = Variable(torch.ones(
                len_idx, config.batch_size))  # placeholder tgt; per the original author its contents do not matter at test time
            src_len = Variable(
                torch.LongTensor(config.batch_size).zero_() +
                mix_speech_len).unsqueeze(0)
            tgt_len = Variable(
                torch.LongTensor([
                    tgt_max_len - 2
                    for one_spk in eval_data['multi_spk_fea_list']
                ])).unsqueeze(0)
            if use_cuda:
                src = src.cuda().transpose(0, 1)  # to T,bs,fre
                src_original = src_original.cuda()  # TO T,bs,fre
                tgt = tgt.cuda()
                src_len = src_len.cuda()
                tgt_len = tgt_len.cuda()
                feas_tgt = feas_tgt.cuda()

            # try:
            # NOTE(review): the multi-GPU branch unpacks four values from
            # beam_sample while the single-GPU branch unpacks two. `alignment`
            # and `hiddens` are only bound on the multi-GPU path. At least one
            # of these unpackings probably does not match beam_sample's
            # return value; confirm against the model.
            if len(opt.gpus) > 1:
                samples, alignment, hiddens, predicted_masks = model.module.beam_sample(
                    src, src_len, dict_spk2idx, tgt, config.beam_size,
                    src_original)
            else:
                samples, predicted_masks = model.beam_sample(
                    src, src_len, dict_spk2idx, tgt, config.beam_size,
                    src_original)

            # except:
            #     continue

            # '''
            # expand the raw mixed-features to topk_max channel.
            src = src_original.transpose(0, 1)  # make sure separation uses the original mixture
            siz = src.size()
            assert len(siz) == 3
            topk_max = topk_k
            x_input_map_multi = torch.unsqueeze(src, 1).expand(
                siz[0], topk_max, siz[1], siz[2])
            # Disabled via `0 and`; WFM_mask is not defined in this function.
            if 0 and config.WFM:
                feas_tgt = x_input_map_multi.data * WFM_mask

            # NOTE(review): only len_idx 4 and 3 are handled below, which
            # matches MIN_MIX == 2; other values skip the mask bookkeeping.
            if len_idx == 4:
                aim_feas = list(range(0, 2 * config.batch_size,
                                      2))  # index of each sample's first speaker
                predict_multi_mask_all = predicted_masks  #bs*topk,T,F
                src = src * (1 - predicted_masks[aim_feas]
                             )  # remove the first speaker's mask; bs first: bs,T,F
                samples_list = samples
            elif len_idx == 3:
                aim_feas = list(range(1, 2 * config.batch_size,
                                      2))  # index of each sample's second speaker
                predict_multi_mask_all[aim_feas] = predicted_masks
                feas_tgt = feas_tgt[aim_feas]
                # NOTE(review): combines the first pass's leading entry with
                # this pass's samples into one nested list; confirm the
                # intended nesting for convertToLabels below.
                samples_list = [samples_list[:1] + samples]

            if test_or_valid != 'test':
                if 1 and len(opt.gpus) > 1:
                    ss_loss = model.module.separation_loss(
                        x_input_map_multi,
                        predicted_masks,
                        feas_tgt,
                    )
                else:
                    ss_loss = model.separation_loss(x_input_map_multi,
                                                    predicted_masks, feas_tgt)
                print(('loss for ss,this batch:', ss_loss.cpu().item()))
                lera.log({
                    'ss_loss_' + str(len_idx) + test_or_valid:
                    ss_loss.cpu().item(),
                })
                del ss_loss

        predicted_masks = predict_multi_mask_all
        if batch_idx <= (500 / config.batch_size
                         ):  # only the former batches counts the SDR
            # NOTE(review): x_input_map_multi comes from the last inner pass
            # (topk_max = len_idx - 2), while predicted_masks holds the masks
            # merged across passes. This relies on broadcasting; confirm shapes.
            predicted_maps = predicted_masks * x_input_map_multi
            # predicted_maps=Variable(feas_tgt)
            utils.bss_eval2(config,
                            predicted_maps,
                            eval_data['multi_spk_fea_list'],
                            raw_tgt,
                            eval_data,
                            dst='batch_output_test')
            del predicted_maps, predicted_masks, x_input_map_multi
            try:
                sdr_aver_batch, sdri_aver_batch = bss_test.cal(
                    'batch_output_test/')
                SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
                SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
            except (AssertionError):
                print('Errors in calculating the SDR')
            print(('SDR_aver_now:', SDR_SUM.mean()))
            print(('SDRi_aver_now:', SDRi_SUM.mean()))
            lera.log({'SDR sample' + test_or_valid: SDR_SUM.mean()})
            lera.log({'SDRi sample' + test_or_valid: SDRi_SUM.mean()})
            # NOTE(review): sdr_aver_batch is unbound here if bss_test.cal
            # failed on the very first scored batch.
            writer.add_scalars('scalar/loss',
                               {'SDR_sample_' + test_or_valid: sdr_aver_batch},
                               updates)
            # raw_input('Press any key to continue......')

        # '''
        candidate += [
            convertToLabels(dict_idx2spk, s, dict_spk2idx['<EOS>'])
            for s in samples_list
        ]
        # source += raw_src
        reference += raw_tgt
        # NOTE(review): prints the last inner pass's `samples`, not samples_list.
        print(('samples:', samples))
        print(('can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:],
                                         reference[-1 * config.batch_size:])))
        # alignments += [align for align in alignment]
        batch_idx += 1
        # NOTE(review): blocks on stdin after every batch (interactive debugging).
        input('wait to continue......')

        # Running scores over everything decoded so far.
        result = utils.eval_metrics(reference, candidate, dict_spk2idx,
                                    log_path)
        print((
            'hamming_loss: %.8f | micro_f1: %.4f |recall: %.4f | precision: %.4f'
            % (
                result['hamming_loss'],
                result['micro_f1'],
                result['micro_recall'],
                result['micro_precision'],
            )))

    score = {}
    result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path)
    logging_csv([e, updates, result['hamming_loss'], \
                 result['micro_f1'], result['micro_precision'], result['micro_recall'],SDR_SUM.mean()])
    print(('hamming_loss: %.8f | micro_f1: %.4f' %
           (result['hamming_loss'], result['micro_f1'])))
    score['hamming_loss'] = result['hamming_loss']
    score['micro_f1'] = result['micro_f1']
    # NOTE(review): always raises ZeroDivisionError; the return below is unreachable.
    1 / 0
    return score
Ejemplo n.º 7
0
def eval(epoch):
    """Evaluate speaker prediction and TasNet-style separation on one split.

    Pulls batches from ``prepare_data('once', test_or_valid, ...)`` until it
    yields ``False`` and decodes each with ``beam_sample``. The predicted
    masks are made zero-mean and scaled by their max absolute value. For the
    first ``3000 / config.batch_size`` batches they are scored against the
    sources with ``bss_test.cal_SISNRi_PIT`` (SI-SNR / SI-SNRi),
    ``bss_test.cal_SDRi`` (SDR / SDRi) and ``ss_tas_loss``. On the batch right
    after that window, the module-global ``best_SDR`` is raised if the
    running mean SDR beats it.

    Args:
        epoch: Epoch number, written to the CSV log.

    Returns:
        dict with ``'hamming_loss'`` and ``'micro_f1'``.
    """
    import contextlib

    global best_SDR, Var
    model.eval()
    # The original author recommends config.batch_size == 1 for testing.
    reference, candidate, alignments = [], [], []
    e = epoch
    test_or_valid = 'test'
    # test_or_valid = 'valid'
    print('Test or valid:', test_or_valid)
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX,
                                 config.MAX_MIX)
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    SISNR_SUM = np.array([])
    SISNRI_SUM = np.array([])
    SS_SUM = np.array([])
    batch_idx = 0
    # Only the first `sdr_batches` batches are scored; the best-SDR check
    # runs once, on the batch right after that window.
    sdr_batches = 3000 / config.batch_size
    # These result files are opened in append mode (creating them if missing)
    # but nothing here writes to them. The ExitStack closes all eight, even
    # if an exception escapes the loop; previously six were never closed.
    result_files = (
        './results/spk2.txt',
        './results/dir2.txt',
        './results/spk_bk.txt',
        './results/dir_bk.txt',
        './results/spk_emb.txt',
        './results/dir_emb.txt',
        './results/spk_hidden.txt',
        './results/dir_hidden.txt',
    )
    with contextlib.ExitStack() as stack:
        for path in result_files:
            stack.enter_context(open(path, 'a'))

        while True:
            print('-' * 30)
            eval_data = next(eval_data_gen)
            if eval_data == False:
                print('SDR_aver_eval_epoch:', SDR_SUM.mean())
                print('SDRi_aver_eval_epoch:', SDRi_SUM.mean())
                print('SISNR_aver_eval_epoch:', SISNR_SUM.mean())
                print('SISNRI_aver_eval_epoch:', SISNRI_SUM.mean())
                print('SS_aver_eval_epoch:', SS_SUM.mean())
                break  # The generator has no more data for this epoch.

            raw_tgt = eval_data['batch_order']

            padded_mixture, mixture_lengths, padded_source = eval_data['tas_zip']
            padded_mixture = torch.from_numpy(padded_mixture).float()
            mixture_lengths = torch.from_numpy(mixture_lengths)
            padded_source = torch.from_numpy(padded_source).float()

            padded_mixture = padded_mixture.cuda().transpose(0, 1)
            mixture_lengths = mixture_lengths.cuda()
            padded_source = padded_source.cuda()

            top_k = len(raw_tgt[0])

            # Multi-GPU: `model` is presumably an nn.DataParallel wrapper.
            beam_model = model.module if len(opt.gpus) > 1 else model
            (samples, samples_dir, alignment, hiddens, predicted_masks,
             output_list, output_dir_list, output_bk_list, output_dir_bk_list,
             hidden_list, hidden_dir_list, emb_list,
             emb_dir_list) = beam_model.beam_sample(
                 padded_mixture, dict_spk2idx, dict_dir2idx, config.beam_size)

            # Keep only the first top_k predicted masks per sample.
            predicted_masks = predicted_masks.transpose(0, 1)
            predicted_masks = predicted_masks[:, 0:top_k, :]
            # First channel of the mixture, used as the SI-SNRi/SDRi reference.
            mixture = torch.chunk(padded_mixture, 2, dim=-1)
            padded_mixture_c0 = mixture[0].squeeze()

            # Tensor copies for ss_tas_loss; numpy copies for the bss_test metrics.
            padded_source1 = padded_source.data.cpu()
            predicted_masks1 = predicted_masks.data.cpu()

            padded_source = padded_source.squeeze().data.cpu().numpy()
            padded_mixture = padded_mixture.squeeze().data.cpu().numpy()
            predicted_masks = predicted_masks.squeeze().data.cpu().numpy()
            padded_mixture_c0 = padded_mixture_c0.squeeze().data.cpu().numpy()
            mixture_lengths = mixture_lengths.cpu()

            # Zero-mean and peak-normalize (produces NaN if all values are equal).
            predicted_masks = predicted_masks - np.mean(predicted_masks)
            predicted_masks /= np.max(np.abs(predicted_masks))

            if batch_idx <= sdr_batches:  # only the former batches count toward SDR
                sisnr, sisnri = bss_test.cal_SISNRi_PIT(padded_source,
                                                        predicted_masks,
                                                        padded_mixture_c0)
                sdr, sdri = bss_test.cal_SDRi(padded_source, predicted_masks,
                                              padded_mixture_c0)
                loss = ss_tas_loss(config, predicted_masks1, padded_source1,
                                   mixture_lengths, True)
                loss = loss.numpy()
                try:
                    SDR_SUM = np.append(SDR_SUM, sdr)
                    SDRi_SUM = np.append(SDRi_SUM, sdri)

                    SISNR_SUM = np.append(SISNR_SUM, sisnr)
                    SISNRI_SUM = np.append(SISNRI_SUM, sisnri)
                    SS_SUM = np.append(SS_SUM, loss)
                except Exception as wrong_info:
                    # Previously a bare `except:` whose handler referenced the
                    # unbound name `wrong_info` and so raised NameError itself.
                    print('Errors in calculating the SDR', wrong_info)
                print('SDR_aver_now:', SDR_SUM.mean())
                print('SDRi_aver_now:', SDRi_SUM.mean())
                print('SISNR_aver_now:', SISNR_SUM.mean())
                print('SISNRI_aver_now:', SISNRI_SUM.mean())
                print('SS_aver_now:', SS_SUM.mean())

            elif batch_idx == sdr_batches + 1 and SDR_SUM.mean(
            ) > best_SDR:  # only record the best SDR once.
                print('Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean()))
                best_SDR = SDR_SUM.mean()

            candidate += [
                convertToLabels(dict_idx2spk, s, dict_spk2idx['<EOS>'])
                for s in samples
            ]
            reference += raw_tgt
            print('samples:', samples)
            print('can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:],
                                            reference[-1 * config.batch_size:]))
            alignments.extend(alignment)
            batch_idx += 1

    score = {}
    result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path)
    logging_csv([e, updates, result['hamming_loss'],
                 result['micro_f1'], result['micro_precision'], result['micro_recall']])
    print('hamming_loss: %.8f | micro_f1: %.4f' %
          (result['hamming_loss'], result['micro_f1']))
    score['hamming_loss'] = result['hamming_loss']
    score['micro_f1'] = result['micro_f1']
    return score
Ejemplo n.º 8
0
def eval(epoch):
    """Decode the test set, restore UNKs, dump an alignment report, log metrics.

    Runs ``model`` over ``testloader`` and converts the predicted ids to
    labels. When ``opt.unk`` is set, every predicted ``dict.UNK_WORD`` is
    replaced by the source token at its aligned position (if that position
    lies inside the source). Then ``alignment_analysis.txt`` is written in the
    current working directory: one line per candidate of
    ``(predicted_word, aligned_source_word)`` pairs, where the aligned word is
    ``'None'`` when the alignment index is past the end of the source. Scores
    from ``utils.eval_metrics`` are logged and printed.

    Args:
        epoch: Current epoch number; not used inside this function.

    Returns:
        None. Results are only written, logged and printed.
    """
    model.eval()
    reference, candidate, source, alignments = [], [], [], []
    for raw_src, src, src_len, raw_tgt, tgt, tgt_len, from_known in testloader:
        if len(opt.gpus) > 1:
            # NOTE(review): other variants of this routine call
            # `model.module.sample` on multi-GPU runs; confirm `model`
            # exposes `sample` directly here.
            samples, alignment = model.sample(src, src_len)
        else:
            # With beam search each sample and its alignment have equal length
            # (per the original author).
            samples, alignment = model.beam_sample(src, src_len, beam_size=config.beam_size)

        candidate += [tgt_vocab.convertToLabels(s, dict.EOS) for s in samples]
        source += raw_src
        reference += raw_tgt
        alignments.extend(alignment)

    if opt.unk:
        cands = []
        for s, c, align in zip(source, candidate, alignments):
            cand = []
            for word, idx in zip(c, align):
                if word == dict.UNK_WORD and idx < len(s):
                    try:
                        cand.append(s[idx])
                    except (IndexError, TypeError):
                        # Unusable position: keep the UNK and report it.
                        cand.append(word)
                        print("%d %d\n" % (len(s), idx))
                else:
                    cand.append(word)
            cands.append(cand)
        candidate = cands

    # Alignment analysis: map each alignment index to a readable source word.
    with open("alignment_analysis.txt", 'w', encoding='utf-8') as f:
        lines = []
        for i, align in enumerate(alignments):
            aligned_words = [
                source[i][pos] if pos < len(source[i]) else 'None'
                for pos in align
            ]
            pairs = ''.join('({}, {}) '.format(word, aligned)
                            for word, aligned in zip(candidate[i], aligned_words))
            # append (not +=) so each report line stays one list element.
            lines.append(pairs + '\n')
        f.writelines(lines)

    result = utils.eval_metrics(reference, candidate, label_dict, log_path)
    logging_csv([result['hamming_loss'], result['micro_f1'], result['micro_precision'], result['micro_recall']])
    print('hamming_loss: %.8f | micro_f1: %.4f'
          % (result['hamming_loss'], result['micro_f1']))
Ejemplo n.º 9
0
def eval(epoch):
    """Run one evaluation pass (beam search + separation) on the 'valid' split.

    Batches are pulled from ``prepare_data('once', 'valid', ...)`` until the
    generator yields ``False``. For each batch the model beam-searches the
    speaker labels and predicts separation masks. When not in TasNet mode the
    separation loss is printed. For the first ``100 / config.batch_size``
    batches the SDR / SDRi are computed with ``bss_test.cal``. Decoded speaker
    labels are accumulated, and multi-label metrics against the reference
    speaker lists are computed once the loop ends.

    Args:
        epoch: current epoch index; only written to the CSV log.

    Returns:
        dict with keys ``'hamming_loss'`` and ``'micro_f1'``.

    Side effects:
        May update the module-level ``best_SDR``; passes
        ``dst='batch_output1'`` to ``utils.bss_eval*`` and then reads
        ``'batch_output1/'`` with ``bss_test.cal``.
    """
    model.eval()
    # NOTE: the original author asked to set config.batch_size = 1 when testing.
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    # Switch this to 'test' to evaluate on the test split instead.
    test_or_valid = 'valid'
    print('Test or valid:', test_or_valid)
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX,
                                 config.MAX_MIX)
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    batch_idx = 0
    global best_SDR, Var
    while True:
        print('-' * 30)
        eval_data = next(eval_data_gen)
        if eval_data == False:
            print('SDR_aver_eval_epoch:', SDR_SUM.mean())
            print('SDRi_aver_eval_epoch:', SDRi_SUM.mean())
            break  # generator exhausted for this epoch: stop evaluating
        src = Variable(torch.from_numpy(eval_data['mix_feas']))

        raw_tgt = [
            sorted(spk.keys()) for spk in eval_data['multi_spk_fea_list']
        ]
        feas_tgt = models.rank_feas(
            raw_tgt, eval_data['multi_spk_fea_list'])  # target spectrograms
        padded_mixture, mixture_lengths, padded_source = eval_data['tas_zip']
        padded_mixture = torch.from_numpy(padded_mixture).float()
        mixture_lengths = torch.from_numpy(mixture_lengths)
        padded_source = torch.from_numpy(padded_source).float()

        padded_mixture = padded_mixture.cuda().transpose(0, 1)
        mixture_lengths = mixture_lengths.cuda()
        padded_source = padded_source.cuda()

        top_k = len(raw_tgt[0])
        # Arbitrary placeholder target of shape (top_k + 2, batch_size); per
        # the original author the target labels do not matter at test time.
        tgt = Variable(torch.ones(top_k + 2, config.batch_size))

        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor([
                len(one_spk) for one_spk in eval_data['multi_spk_fea_list']
            ])).unsqueeze(0)
        if config.WFM:
            # mask = target^2 / (sum of targets over dim 1)^2
            tmp_size = feas_tgt.size()
            feas_tgt_sum = torch.sum(feas_tgt, dim=1, keepdim=True)
            feas_tgt_sum_square = (feas_tgt_sum *
                                   feas_tgt_sum).expand(tmp_size)
            feas_tgt_square = feas_tgt * feas_tgt
            WFM_mask = feas_tgt_square / feas_tgt_sum_square

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
            if config.WFM:
                WFM_mask = WFM_mask.cuda()

        # With several GPUs the model is wrapped and exposes its API via .module.
        if len(opt.gpus) > 1:
            samples, alignment, hiddens, predicted_masks = model.module.beam_sample(
                src, src_len, dict_spk2idx, tgt, config.beam_size,
                padded_mixture)
        else:
            samples, alignment, hiddens, predicted_masks = model.beam_sample(
                src, src_len, dict_spk2idx, tgt, config.beam_size,
                padded_mixture)

        # Expand the raw mixed features to topk_max channels.
        src = src.transpose(0, 1)
        siz = src.size()
        assert len(siz) == 3
        # One channel per decoded speaker; the last sampled token is excluded.
        topk_max = len(samples[0]) - 1
        x_input_map_multi = torch.unsqueeze(src,
                                            1).expand(siz[0], topk_max, siz[1],
                                                      siz[2])
        if config.WFM:
            feas_tgt = x_input_map_multi.data * WFM_mask

        if not config.use_tas and test_or_valid == 'valid':
            if len(opt.gpus) > 1:
                ss_loss = model.module.separation_loss(
                    x_input_map_multi,
                    predicted_masks,
                    feas_tgt,
                )
            else:
                ss_loss = model.separation_loss(x_input_map_multi,
                                                predicted_masks, feas_tgt)
            print('loss for ss,this batch:', ss_loss.cpu().item())
            del ss_loss, hiddens

        if batch_idx <= (100 / config.batch_size):
            # Only the first batches contribute to the SDR statistics.
            if config.use_tas:
                utils.bss_eval_tas(config,
                                   predicted_masks,
                                   eval_data['multi_spk_fea_list'],
                                   raw_tgt,
                                   eval_data,
                                   dst='batch_output1')
            else:
                predicted_maps = predicted_masks * x_input_map_multi
                utils.bss_eval2(config,
                                predicted_maps,
                                eval_data['multi_spk_fea_list'],
                                raw_tgt,
                                eval_data,
                                dst='batch_output1')
                del predicted_maps
            del predicted_masks, x_input_map_multi
            try:
                sdr_aver_batch, sdri_aver_batch = bss_test.cal(
                    'batch_output1/')
                SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
                SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
            except Exception as wrong_info:
                # Best effort: a failed SDR computation must not abort the
                # evaluation. The former bare ``except:`` printed an unbound
                # ``wrong_info`` and so raised NameError here.
                print('Errors in calculating the SDR', wrong_info)
            print('SDR_aver_now:', SDR_SUM.mean())
            print('SDRi_aver_now:', SDRi_SUM.mean())
        elif batch_idx == (500 // config.batch_size) + 1 and SDR_SUM.mean(
        ) > best_SDR:
            # Only record the best SDR once. Integer division keeps this
            # reachable when batch_size does not divide 500 (``500 / bs + 1``
            # was then a non-integer that batch_idx could never equal).
            print('Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean()))
            best_SDR = SDR_SUM.mean()

        candidate += [
            convertToLabels(dict_idx2spk, s, dict_spk2idx['<EOS>'])
            for s in samples
        ]
        reference += raw_tgt
        print('samples:', samples)
        print('can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:],
                                        reference[-1 * config.batch_size:]))
        alignments.extend(alignment)
        batch_idx += 1

    score = {}
    result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path)
    logging_csv([e, updates, result['hamming_loss'],
                 result['micro_f1'], result['micro_precision'],
                 result['micro_recall']])
    print('hamming_loss: %.8f | micro_f1: %.4f' %
          (result['hamming_loss'], result['micro_f1']))
    score['hamming_loss'] = result['hamming_loss']
    score['micro_f1'] = result['micro_f1']
    return score
Ejemplo n.º 10
0
def eval(epoch, test_or_valid='valid'):
    """Log the PIT separation loss of the mask model on ``test_or_valid`` data.

    Batches are pulled from ``prepare_data('once', test_or_valid, ...)``. The
    model runs with a fixed dummy target sequence, and the
    permutation-invariant separation loss is reported to stdout, ``lera`` and
    ``writer``. Evaluation stops when the generator yields ``False`` or once
    ``batch_idx`` exceeds 10. The SDR evaluation and checkpoint-sweep
    branches below are disabled (``if 0 and ...`` / ``if False:``).

    Args:
        epoch: current epoch index (stored in ``e``; otherwise unused here).
        test_or_valid: data split name passed to ``prepare_data``.

    Returns:
        None.
    """
    global updates, model
    model.eval()
    # NOTE: the original author asked to set config.batch_size = 1 when testing.
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    print(('Test or valid:', test_or_valid))
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX, config.MAX_MIX)
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    batch_idx = 0
    global best_SDR, Var
    while True:
        print(('-' * 30))
        eval_data = next(eval_data_gen)
        if eval_data == False:
            print(('SDR_aver_eval_epoch:', SDR_SUM.mean()))
            print(('SDRi_aver_eval_epoch:', SDRi_SUM.mean()))
            break  # generator exhausted for this epoch: stop evaluating
        src = Variable(torch.from_numpy(eval_data['mix_feas']))

        raw_tgt = eval_data['batch_order']
        feas_tgt = models.rank_feas(raw_tgt, eval_data['multi_spk_fea_list'])  # target spectrograms

        top_k = len(raw_tgt[0])
        # Fixed dummy decoder input [0, 1, 2, 102] per sample, shape (4, bs).
        # Builtin `int` replaces the `np.int` alias that NumPy 1.24 removed.
        tgt = Variable(torch.from_numpy(np.array(
            [[0, 1, 2, 102] for __ in range(config.batch_size)],
            dtype=int))).transpose(0, 1)

        src_len = Variable(torch.LongTensor(config.batch_size).zero_() + mix_speech_len).unsqueeze(0)
        tgt_len = Variable(torch.LongTensor([len(one_spk) for one_spk in eval_data['multi_spk_fea_list']])).unsqueeze(0)
        if config.WFM:
            # Targets become mixture * mask, mask = target^2 / sum over dim 1 of target^2.
            siz = src.size()  # bs,T,F
            assert len(siz) == 3
            topk_max = 2  # maximum number of speakers (hard-coded, not config.MAX_MIX)
            x_input_map_multi = torch.unsqueeze(src, 1).expand(siz[0], topk_max, siz[1], siz[2]).contiguous().view(-1, siz[1], siz[2])  # bs*topk,T,F
            feas_tgt_tmp = feas_tgt.view(siz[0], -1, siz[1], siz[2])

            feas_tgt_square = feas_tgt_tmp * feas_tgt_tmp
            feas_tgt_sum_square = torch.sum(feas_tgt_square, dim=1, keepdim=True).expand(siz[0], topk_max, siz[1], siz[2])
            WFM_mask = feas_tgt_square / (feas_tgt_sum_square + 1e-15)
            feas_tgt = x_input_map_multi.view(siz[0], -1, siz[1], siz[2]).data * WFM_mask  # bs,topk,T,F
            feas_tgt = feas_tgt.view(-1, siz[1], siz[2])  # bs*topk,T,F
            # WFM_mask is moved to the GPU in the `use_cuda` block below; an
            # unconditional .cuda() here would fail when CUDA is unavailable.
            del x_input_map_multi

        elif config.PSM:
            # Phase-sensitive targets: target * max(cos(mix_angle - target_angle), 0).
            siz = src.size()  # bs,T,F
            assert len(siz) == 3
            topk_max = 2  # maximum number of speakers (hard-coded, not config.MAX_MIX)
            x_input_map_multi = torch.unsqueeze(src, 1).expand(siz[0], topk_max, siz[1], siz[2]).contiguous()  # bs,topk,T,F
            feas_tgt_tmp = feas_tgt.view(siz[0], -1, siz[1], siz[2])

            IRM = feas_tgt_tmp / (x_input_map_multi + 1e-15)  # NOTE(review): computed but unused

            angle_tgt = models.rank_feas(raw_tgt, eval_data['multi_spk_angle_list']).view(siz[0], -1, siz[1], siz[2])
            angle_mix = Variable(torch.from_numpy(np.array(eval_data['mix_angle']))).unsqueeze(1).expand(siz[0], topk_max, siz[1], siz[2]).contiguous()
            ang = np.cos(angle_mix - angle_tgt)
            ang = np.clip(ang, 0, None)

            feas_tgt = feas_tgt.view(siz[0], -1, siz[1], siz[2]) * ang  # bs,topk,T,F
            feas_tgt = feas_tgt.view(-1, siz[1], siz[2])  # bs*topk,T,F
            del x_input_map_multi

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
            if config.WFM:
                WFM_mask = WFM_mask.cuda()

        # Forward pass with the dummy target; the masks come back speaker-first.
        predicted_masks, enc_attn_list = model(src, src_len, tgt, tgt_len,
                                               dict_spk2idx)

        print('predicted mask size:', predicted_masks.size(),'should be topk,bs,T,F') # topk,bs,T,F

        # Expand the raw mixed features to topk_max channels.
        src = src.transpose(0, 1)
        siz = src.size()
        assert len(siz) == 3
        topk_max = config.MAX_MIX
        x_input_map_multi = torch.unsqueeze(src, 1).expand(siz[0], topk_max, siz[1], siz[2])

        # Make the masks batch-first so they line up with x_input_map_multi.
        predicted_masks = predicted_masks.transpose(0, 1)

        assert predicted_masks.shape == x_input_map_multi.shape
        assert predicted_masks.size(0) == config.batch_size

        # With several GPUs the model is wrapped and exposes its API via .module.
        if len(opt.gpus) > 1:
            ss_loss, best_pmt = model.module.separation_pit_loss(x_input_map_multi, predicted_masks, feas_tgt)
        else:
            ss_loss, best_pmt = model.separation_pit_loss(x_input_map_multi, predicted_masks, feas_tgt)
        print(('loss for ss,this batch:', ss_loss.cpu().item()))
        print('best perms for this batch:', best_pmt)
        lera.log({
            'ss_loss_' + test_or_valid: ss_loss.cpu().item(),
        })
        writer.add_scalars('scalar/loss', {'ss_loss_' + test_or_valid: ss_loss.cpu().item()}, updates + batch_idx)
        del ss_loss
        # Debug limit: stop after the batch with batch_idx == 11.
        if batch_idx > 10:
            break

        if False:  # Disabled: evaluate saved checkpoints one after another.
            batch_idx += 1
            if batch_idx % 100 == 0:
                updates = updates + 1000
                opt.restore = '/data1/shijing_data/2020-02-14-04:58:17/Transformer_PIT_{}.pt'.format(updates)
                print('loading checkpoint...\n', opt.restore)
                checkpoints = torch.load(opt.restore)
                model.module.load_state_dict(checkpoints['model'])
                break
            continue

        if 0 and batch_idx <= (500 / config.batch_size):  # Disabled SDR evaluation.
            predicted_maps = predicted_masks * x_input_map_multi
            predicted_maps = predicted_maps.view(-1, mix_speech_len, speech_fre)
            utils.bss_eval2(config, predicted_maps, eval_data['multi_spk_fea_list'], raw_tgt, eval_data,
                            dst='batch_output_test')
            del predicted_maps, predicted_masks, x_input_map_multi
            try:
                sdr_aver_batch, sdri_aver_batch = bss_test.cal('batch_output_test/')
                SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
                SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
            except(AssertionError):
                print('Errors in calculating the SDR')
            print(('SDR_aver_now:', SDR_SUM.mean()))
            print(('SRi_aver_now:', SDRi_SUM.mean()))
            lera.log({'SDR sample'+test_or_valid: SDR_SUM.mean()})
            lera.log({'SDRi sample'+test_or_valid: SDRi_SUM.mean()})
            writer.add_scalars('scalar/loss', {'SDR_sample_'+test_or_valid: sdr_aver_batch}, updates)
        elif batch_idx == (200 / config.batch_size) + 1 and SDR_SUM.mean() > best_SDR:  # only record the best SDR once.
            print(('Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean())))
            best_SDR = SDR_SUM.mean()

        batch_idx += 1

        # NOTE(review): this runs once per batch inside the loop, and
        # `reference` / `candidate` are never filled in this function, so the
        # metrics are computed on empty lists. Confirm whether it belongs
        # after the loop.
        result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path)
        print(('hamming_loss: %.8f | micro_f1: %.4f |recall: %.4f | precision: %.4f'
               % (result['hamming_loss'], result['micro_f1'], result['micro_recall'], result['micro_precision'], )))
Ejemplo n.º 11
0
def eval(epoch, test_or_valid='valid'):
    """Evaluate the complex-mask model by reconstructing waveforms.

    For each batch from ``prepare_data('once', test_or_valid, ...)``:
      * the model predicts real and imaginary masks for two speakers;
      * the masks are applied to the complex mixture by complex
        multiplication;
      * waveforms are rebuilt with ``models.istft_irfft``;
      * the result is scored with ``separation_tas_loss``.
    SDR / SI-SNRi are computed for batches with
    ``batch_idx <= 500 / config.batch_size``. Evaluation stops when the
    generator yields ``False`` or after 101 batches.

    Args:
        epoch: current epoch index (stored in ``e``; otherwise unused here).
        test_or_valid: data split name passed to ``prepare_data``.

    Returns:
        None.
    """
    global updates, model
    model.eval()
    # NOTE: the original author asked to set config.batch_size = 1 when testing.
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    print(('Test or valid:', test_or_valid))
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX,
                                 config.MAX_MIX)
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    batch_idx = 0
    global best_SDR, Var
    while True:
        print(('-' * 30))
        eval_data = next(eval_data_gen)
        if eval_data == False:
            print(('SDR_aver_eval_epoch:', SDR_SUM.mean()))
            print(('SDRi_aver_eval_epoch:', SDRi_SUM.mean()))
            break  # generator exhausted for this epoch: stop evaluating
        # bs,T,F,2: both real and imaginary values of the mixture.
        src = Variable(torch.from_numpy(eval_data['mix_complex_two_channel']))
        raw_tgt = eval_data['batch_order']
        # Target waveforms ranked by batch order: bs*topk, time_len.
        feas_tgt = models.rank_feas(raw_tgt, eval_data['multi_spk_wav_list'])

        padded_mixture, mixture_lengths, padded_source = eval_data['tas_zip']
        padded_mixture = torch.from_numpy(padded_mixture).float()
        mixture_lengths = torch.from_numpy(mixture_lengths)
        padded_source = torch.from_numpy(padded_source).float()

        padded_mixture = padded_mixture.cuda().transpose(0, 1)
        mixture_lengths = mixture_lengths.cuda()
        padded_source = padded_source.cuda()

        # Fixed dummy decoder input [0, 1, 2, 102] per sample, shape (4, bs).
        # Builtin `int` replaces the `np.int` alias that NumPy 1.24 removed.
        tgt = Variable(
            torch.from_numpy(
                np.array([[0, 1, 2, 102] for __ in range(config.batch_size)],
                         dtype=int))).transpose(0, 1)

        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor([
                len(one_spk) for one_spk in eval_data['multi_spk_fea_list']
            ])).unsqueeze(0)

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()

        model.zero_grad()
        if config.use_center_loss:
            center_loss.zero_grad()

        multi_mask_real, multi_mask_imag, enc_attn_list = model(
            src, src_len, tgt, tgt_len, dict_spk2idx)
        multi_mask_real = multi_mask_real.transpose(0, 1)
        multi_mask_imag = multi_mask_imag.transpose(0, 1)
        src_real = src[:, :, :, 0].transpose(0, 1)  # bs,T,F
        src_imag = src[:, :, :, 1].transpose(0, 1)  # bs,T,F
        print('mask size for real/imag:',
              multi_mask_real.size())  # bs,topk,T,F
        print('mixture size for real/imag:', src_real.size())  # bs,T,F

        # Complex product mask * mixture:
        # (mr + i*mi)(sr + i*si) = (mr*sr - mi*si) + i*(mr*si + mi*sr)
        predicted_maps0_real = multi_mask_real[:, 0] * src_real - multi_mask_imag[:, 0] * src_imag  # bs,T,F
        predicted_maps0_imag = multi_mask_real[:, 0] * src_imag + multi_mask_imag[:, 0] * src_real  # bs,T,F
        predicted_maps1_real = multi_mask_real[:, 1] * src_real - multi_mask_imag[:, 1] * src_imag  # bs,T,F
        predicted_maps1_imag = multi_mask_real[:, 1] * src_imag + multi_mask_imag[:, 1] * src_real  # bs,T,F

        stft_matrix_spk0 = torch.cat((predicted_maps0_real.unsqueeze(-1),
                                      predicted_maps0_imag.unsqueeze(-1)),
                                     3).transpose(1, 2)  # bs,F,T,2
        stft_matrix_spk1 = torch.cat((predicted_maps1_real.unsqueeze(-1),
                                      predicted_maps1_imag.unsqueeze(-1)),
                                     3).transpose(1, 2)  # bs,F,T,2
        wav_spk0 = models.istft_irfft(stft_matrix_spk0,
                                      length=config.MAX_LEN,
                                      hop_length=config.FRAME_SHIFT,
                                      win_length=config.FRAME_LENGTH,
                                      window='hann')
        wav_spk1 = models.istft_irfft(stft_matrix_spk1,
                                      length=config.MAX_LEN,
                                      hop_length=config.FRAME_SHIFT,
                                      win_length=config.FRAME_LENGTH,
                                      window='hann')
        predict_wav = torch.cat((wav_spk0.unsqueeze(1), wav_spk1.unsqueeze(1)),
                                1)  # bs,topk,time_len
        # With several GPUs the model is wrapped and exposes its API via .module.
        if len(opt.gpus) > 1:
            ss_loss, pmt_list, max_snr_idx, *__ = model.module.separation_tas_loss(
                padded_mixture, predict_wav, padded_source, mixture_lengths)
        else:
            ss_loss, pmt_list, max_snr_idx, *__ = model.separation_tas_loss(
                padded_mixture, predict_wav, padded_source, mixture_lengths)

        best_pmt = [
            list(pmt_list[int(mm)].data.cpu().numpy()) for mm in max_snr_idx
        ]
        print('loss for SS,this batch:', ss_loss.cpu().item())
        print('best perms for this batch:', best_pmt)
        # NOTE(review): this logs under the generic 'ss_loss' tag at step
        # `updates` for every batch, in addition to the split-specific tag below.
        writer.add_scalars('scalar/loss', {'ss_loss': ss_loss.cpu().item()},
                           updates)
        lera.log({
            'ss_loss_' + test_or_valid: ss_loss.cpu().item(),
        })
        writer.add_scalars('scalar/loss',
                           {'ss_loss_' + test_or_valid: ss_loss.cpu().item()},
                           updates + batch_idx)
        del ss_loss

        if False:  # Disabled: evaluate saved checkpoints one after another.
            batch_idx += 1
            if batch_idx % 100 == 0:
                updates = updates + 1000
                opt.restore = '/data1/shijing_data/2020-02-14-04:58:17/Transformer_PIT_{}.pt'.format(
                    updates)
                print('loading checkpoint...\n', opt.restore)
                checkpoints = torch.load(opt.restore)
                model.module.load_state_dict(checkpoints['model'])
                break
            continue

        if batch_idx <= (500 / config.batch_size):
            utils.bss_eval_tas(config,
                               predict_wav,
                               eval_data['multi_spk_fea_list'],
                               raw_tgt,
                               eval_data,
                               dst=log_path + 'batch_output')
            sdr_aver_batch, snri_aver_batch = bss_test.cal(log_path +
                                                           'batch_output/')
            lera.log({'SDR sample': sdr_aver_batch})
            lera.log({'SI-SNRi sample': snri_aver_batch})
            writer.add_scalars('scalar/loss', {
                'SDR_sample': sdr_aver_batch,
                'SDRi_sample': snri_aver_batch
            }, updates)
            SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
            SDRi_SUM = np.append(SDRi_SUM, snri_aver_batch)
            print(('SDR_aver_now:', SDR_SUM.mean()))
            print(('SNRi_aver_now:', SDRi_SUM.mean()))

        batch_idx += 1
        if batch_idx > 100:
            break
        # NOTE(review): computed once per batch inside the loop on
        # `reference` / `candidate`, which this function never fills.
        result = utils.eval_metrics(reference, candidate, dict_spk2idx,
                                    log_path)
        print((
            'hamming_loss: %.8f | micro_f1: %.4f |recall: %.4f | precision: %.4f'
            % (
                result['hamming_loss'],
                result['micro_f1'],
                result['micro_recall'],
                result['micro_precision'],
            )))
Ejemplo n.º 12
0
def eval(epoch, test_or_valid='train'):
    """Evaluate the spectrogram-mask model and analyse frame-level permutations.

    Defaults to the 'train' split. For each batch the model predicts two
    masks that are applied to the mixture features (``mix_feas``). The masked
    spectrograms are combined with the mixture angle, inverted with
    ``models.istft_irfft``, and scored with ``separation_tas_loss``.

    An always-on analysis block works frame by frame. For each frame it picks
    which of the two speaker orderings gives the lower MSE against
    phase-sensitive targets. It then prints, averaged over the batch, the
    fraction of frames that take the minority ordering.

    SDR / SI-SNRi are computed for batches with
    ``batch_idx <= 500 / config.batch_size``. Evaluation stops when the
    generator yields ``False`` or after 101 batches.

    Args:
        epoch: current epoch index (stored in ``e``; otherwise unused here).
        test_or_valid: data split name passed to ``prepare_data``.

    Returns:
        None.
    """
    global updates, model
    model.eval()
    # NOTE: the original author asked to set config.batch_size = 1 when testing.
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    print(('Test or valid:', test_or_valid))
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX,
                                 config.MAX_MIX)
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    batch_idx = 0
    global best_SDR, Var
    while True:
        print(('-' * 30))
        eval_data = next(eval_data_gen)
        if eval_data == False:
            print(('SDR_aver_eval_epoch:', SDR_SUM.mean()))
            print(('SDRi_aver_eval_epoch:', SDRi_SUM.mean()))
            break  # generator exhausted for this epoch: stop evaluating

        src = Variable(torch.from_numpy(eval_data['mix_feas']))
        raw_tgt = eval_data['batch_order']
        # Target waveforms ranked by batch order: bs*topk, time_len.
        feas_tgt = models.rank_feas(raw_tgt, eval_data['multi_spk_wav_list'])

        padded_mixture, mixture_lengths, padded_source = eval_data['tas_zip']
        padded_mixture = torch.from_numpy(padded_mixture).float()
        mixture_lengths = torch.from_numpy(mixture_lengths)
        padded_source = torch.from_numpy(padded_source).float()

        padded_mixture = padded_mixture.cuda().transpose(0, 1)
        mixture_lengths = mixture_lengths.cuda()
        padded_source = padded_source.cuda()

        # Fixed dummy decoder input [0, 1, 2, 102] per sample, shape (4, bs).
        # Builtin `int` replaces the `np.int` alias that NumPy 1.24 removed.
        tgt = Variable(
            torch.from_numpy(
                np.array([[0, 1, 2, 102] for __ in range(config.batch_size)],
                         dtype=int))).transpose(0, 1)
        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor([
                len(one_spk) for one_spk in eval_data['multi_spk_fea_list']
            ])).unsqueeze(0)

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()

        model.zero_grad()
        if config.use_center_loss:
            center_loss.zero_grad()

        multi_mask, enc_attn_list = model(src, src_len, tgt, tgt_len,
                                          dict_spk2idx)
        multi_mask = multi_mask.transpose(0, 1)
        print('mask size:', multi_mask.size())  # bs,topk,T,F

        predicted_maps0_spectrogram = multi_mask[:, 0] * src.transpose(
            0, 1)  # bs,T,F
        predicted_maps1_spectrogram = multi_mask[:, 1] * src.transpose(
            0, 1)  # bs,T,F

        if True:  # Analyze the optimal assignments
            predicted_spectrogram = torch.cat([
                predicted_maps0_spectrogram.unsqueeze(1),
                predicted_maps1_spectrogram.unsqueeze(1)
            ], 1)
            # Target spectrograms ranked by batch order: bs*topk, T, F.
            feas_tgt_tmp = models.rank_feas(raw_tgt,
                                            eval_data['multi_spk_fea_list'])
            src = src.transpose(0, 1)
            siz = src.size()  # bs,T,F
            assert len(siz) == 3
            topk_max = 2  # maximum number of speakers (hard-coded, not config.MAX_MIX)
            x_input_map_multi = torch.unsqueeze(src, 1).expand(
                siz[0], topk_max, siz[1], siz[2]).contiguous()  # bs,topk,T,F
            feas_tgt_tmp = feas_tgt_tmp.view(siz[0], -1, siz[1], siz[2])

            # Phase-sensitive weighting: max(cos(mix_angle - target_angle), 0).
            angle_tgt = models.rank_feas(
                raw_tgt, eval_data['multi_spk_angle_list']).view(
                    siz[0], -1, siz[1], siz[2])  # bs,topk,T,F
            angle_mix = Variable(
                torch.from_numpy(np.array(
                    eval_data['mix_angle']))).unsqueeze(1).expand(
                        siz[0], topk_max, siz[1], siz[2]).contiguous()
            ang = np.cos(angle_mix - angle_tgt)
            ang = np.clip(ang, 0, None)

            feas_tgt_tmp = feas_tgt_tmp.view(siz[0], -1, siz[1],
                                             siz[2]) * ang  # bs,topk,T,F
            feas_tgt_tmp = feas_tgt_tmp.cuda()
            del x_input_map_multi
            src = src.transpose(0, 1)
            MSE_func = nn.MSELoss().cuda()
            best_perms_this_batch = []
            for bs_idx in range(siz[0]):
                best_perms_this_sample = []
                for tt in range(siz[1]):  # for each frame
                    tar = feas_tgt_tmp[bs_idx, :, tt]  # topk,F
                    # NOTE: `est` is a view into predicted_spectrogram; est[per]
                    # is evaluated before each write, so the writes below
                    # reorder the frame in place to the best permutation.
                    est = predicted_spectrogram[bs_idx, :, tt]  # topk,F
                    best_loss_mse_this_batch = -1
                    for idx, per in enumerate([[0, 1], [1, 0]]):
                        if idx == 0:
                            best_loss_mse_this_batch = MSE_func(est[per], tar)
                            perm_this_frame = per
                            predicted_spectrogram[bs_idx, :, tt] = est[per]
                        else:
                            loss = MSE_func(est[per], tar)
                            if loss <= best_loss_mse_this_batch:
                                best_loss_mse_this_batch = loss
                                perm_this_frame = per
                                predicted_spectrogram[bs_idx, :, tt] = est[per]

                    best_perms_this_sample.append(perm_this_frame)
                best_perms_this_batch.append(best_perms_this_sample)
            # best_perms_this_batch has shape (bs, T, 2). Summing over frames
            # and taking the min gives the minority ordering's frame count.
            # Divide by the actual frame count siz[1] (formerly a hard-coded
            # 751) to get a ratio.
            print(
                'different assignment ratio:',
                np.mean(np.min(
                    np.array(best_perms_this_batch).sum(1) / siz[1], 1)))

        # NOTE(review): the original comment says mix_phase is bs,T,F,2, but
        # np.angle only yields a phase for complex input — confirm its dtype.
        _mix_spec = eval_data['mix_phase']
        angle_mix = np.angle(_mix_spec)
        # magnitude * e^(i*angle) = magnitude * (cos(angle) + i*sin(angle))
        predicted_maps0_real = predicted_maps0_spectrogram * torch.from_numpy(
            np.cos(angle_mix)).cuda()
        predicted_maps0_imag = predicted_maps0_spectrogram * torch.from_numpy(
            np.sin(angle_mix)).cuda()
        predicted_maps1_real = predicted_maps1_spectrogram * torch.from_numpy(
            np.cos(angle_mix)).cuda()
        predicted_maps1_imag = predicted_maps1_spectrogram * torch.from_numpy(
            np.sin(angle_mix)).cuda()

        stft_matrix_spk0 = torch.cat((predicted_maps0_real.unsqueeze(-1),
                                      predicted_maps0_imag.unsqueeze(-1)),
                                     3).transpose(1, 2)  # bs,F,T,2
        stft_matrix_spk1 = torch.cat((predicted_maps1_real.unsqueeze(-1),
                                      predicted_maps1_imag.unsqueeze(-1)),
                                     3).transpose(1, 2)  # bs,F,T,2
        wav_spk0 = models.istft_irfft(stft_matrix_spk0,
                                      length=config.MAX_LEN,
                                      hop_length=config.FRAME_SHIFT,
                                      win_length=config.FRAME_LENGTH,
                                      window='hann')
        wav_spk1 = models.istft_irfft(stft_matrix_spk1,
                                      length=config.MAX_LEN,
                                      hop_length=config.FRAME_SHIFT,
                                      win_length=config.FRAME_LENGTH,
                                      window='hann')
        predict_wav = torch.cat((wav_spk0.unsqueeze(1), wav_spk1.unsqueeze(1)),
                                1)  # bs,topk,time_len
        # With several GPUs the model is wrapped and exposes its API via .module.
        if len(opt.gpus) > 1:
            ss_loss, pmt_list, max_snr_idx, *__ = model.module.separation_tas_loss(
                padded_mixture, predict_wav, padded_source, mixture_lengths)
        else:
            ss_loss, pmt_list, max_snr_idx, *__ = model.separation_tas_loss(
                padded_mixture, predict_wav, padded_source, mixture_lengths)

        best_pmt = [
            list(pmt_list[int(mm)].data.cpu().numpy()) for mm in max_snr_idx
        ]
        print('loss for SS,this batch:', ss_loss.cpu().item())
        print('best perms for this batch:', best_pmt)
        # NOTE(review): this logs under the generic 'ss_loss' tag at step
        # `updates` for every batch, in addition to the split-specific tag below.
        writer.add_scalars('scalar/loss', {'ss_loss': ss_loss.cpu().item()},
                           updates)
        lera.log({
            'ss_loss_' + test_or_valid: ss_loss.cpu().item(),
        })
        writer.add_scalars('scalar/loss',
                           {'ss_loss_' + test_or_valid: ss_loss.cpu().item()},
                           updates + batch_idx)
        del ss_loss

        if False:  # Disabled: evaluate saved checkpoints one after another.
            batch_idx += 1
            if batch_idx % 100 == 0:
                updates = updates + 1000
                opt.restore = '/data1/shijing_data/2020-02-14-04:58:17/Transformer_PIT_{}.pt'.format(
                    updates)
                print('loading checkpoint...\n', opt.restore)
                checkpoints = torch.load(opt.restore)
                model.module.load_state_dict(checkpoints['model'])
                break
            continue

        if batch_idx <= (500 / config.batch_size):
            utils.bss_eval_tas(config,
                               predict_wav,
                               eval_data['multi_spk_fea_list'],
                               raw_tgt,
                               eval_data,
                               dst=log_path + 'batch_output')
            sdr_aver_batch, snri_aver_batch = bss_test.cal(log_path +
                                                           'batch_output/')
            lera.log({'SDR sample': sdr_aver_batch})
            lera.log({'SI-SNRi sample': snri_aver_batch})
            writer.add_scalars('scalar/loss', {
                'SDR_sample': sdr_aver_batch,
                'SDRi_sample': snri_aver_batch
            }, updates)
            SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
            SDRi_SUM = np.append(SDRi_SUM, snri_aver_batch)
            print(('SDR_aver_now:', SDR_SUM.mean()))
            print(('SNRi_aver_now:', SDRi_SUM.mean()))

        batch_idx += 1
        if batch_idx > 100:
            break
        # NOTE(review): computed once per batch inside the loop on
        # `reference` / `candidate`, which this function never fills.
        result = utils.eval_metrics(reference, candidate, dict_spk2idx,
                                    log_path)
        print((
            'hamming_loss: %.8f | micro_f1: %.4f |recall: %.4f | precision: %.4f'
            % (
                result['hamming_loss'],
                result['micro_f1'],
                result['micro_recall'],
                result['micro_precision'],
            )))
Ejemplo n.º 13
0
def eval(epoch):
    """Run one evaluation pass over the 'valid' split and report metrics.

    Batches come from ``prepare_data('once', ...)`` until it yields False.
    For each batch this:

    * predicts speaker indices and separation masks with ``pit_sample``;
    * logs the separation loss (skipped when evaluating the 'test' split);
    * for the first ``500 // config.batch_size`` batches, writes separated
      outputs with ``utils.bss_eval2`` and accumulates SDR / SDRi from
      ``bss_test.cal``;
    * accumulates predicted vs. reference speaker labels.

    Right after the SDR window ends, the global ``best_SDR`` is raised if the
    mean SDR improved.

    Args:
        epoch: current epoch number; only written to the CSV log.

    Returns:
        dict with ``'hamming_loss'`` and ``'micro_f1'`` over all batches.
    """
    global best_SDR
    model.eval()
    reference, candidate = [], []
    # Split to evaluate ('valid'; 'test' / 'train' are also accepted by
    # prepare_data). On 'test' the separation loss is not computed.
    test_or_valid = 'valid'
    print('Test or valid:', test_or_valid)
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX, config.MAX_MIX)
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])
    batch_idx = 0
    # Only batches with index <= sdr_batches are scored with SDR (slow).
    sdr_batches = 500 // config.batch_size
    # Under DataParallel (multi-GPU) the model's own methods live on .module.
    base_model = model.module if len(opt.gpus) > 1 else model

    while True:
        print('-' * 30)
        eval_data = next(eval_data_gen)
        if eval_data == False:
            # The generator yields False once the split is exhausted.
            print('SDR_aver_eval_epoch:', SDR_SUM.mean())
            print('SDRi_aver_eval_epoch:', SDRi_SUM.mean())
            break
        src = Variable(torch.from_numpy(eval_data['mix_feas']))

        raw_tgt = eval_data['batch_order']
        feas_tgt = models.rank_feas(raw_tgt, eval_data['multi_spk_fea_list'])  # target spectrograms

        # Fixed dummy target sequence [0, 1, 2, 102] per sample (presumably
        # BOS, two speaker slots, EOS -- TODO confirm). Must be int64 so that
        # torch.from_numpy produces a LongTensor.
        tgt = Variable(torch.from_numpy(np.array([[0, 1, 2, 102] for __ in range(config.batch_size)],
                                                 dtype=np.int64))).transpose(0, 1)

        src_len = Variable(torch.LongTensor(config.batch_size).zero_() + mix_speech_len).unsqueeze(0)
        tgt_len = Variable(torch.LongTensor([len(one_spk) for one_spk in eval_data['multi_spk_fea_list']])).unsqueeze(0)
        if config.WFM:
            tmp_size = feas_tgt.size()
            assert len(tmp_size) == 3
            feas_tgt_square = feas_tgt * feas_tgt
            feas_tgt_sum_square = torch.sum(feas_tgt_square, dim=0, keepdim=True).expand(tmp_size)
            # Epsilon avoids 0/0 in bins where every target is silent.
            WFM_mask = feas_tgt_square / (feas_tgt_sum_square + 1e-15)

        # The model consumes src with its first two axes swapped; this is
        # undone after sampling, so it must happen on every device (not only
        # on the CUDA path).
        src = src.transpose(0, 1)
        if use_cuda:
            src = src.cuda()
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
            if config.WFM:
                WFM_mask = WFM_mask.cuda()

        samples, predicted_masks = base_model.pit_sample(src, src_len, dict_spk2idx, tgt,
                                                         beam_size=config.beam_size)
        # Argmax over dim 2 -> predicted index per output step.
        samples = samples.max(2)[1].data.cpu().numpy()

        # Restore the original axis order and replicate the mixture once per
        # predicted output channel.
        src = src.transpose(0, 1)
        siz = src.size()
        assert len(siz) == 3
        topk_max = len(samples[0]) - 1
        x_input_map_multi = torch.unsqueeze(src, 1).expand(siz[0], topk_max, siz[1], siz[2])
        if config.WFM:
            feas_tgt = x_input_map_multi.data * WFM_mask

        if test_or_valid != 'test':
            ss_loss = base_model.separation_loss(x_input_map_multi, predicted_masks, feas_tgt)
            print('loss for ss,this batch:', ss_loss.cpu().item())
            lera.log({
                'ss_loss_' + test_or_valid: ss_loss.cpu().item(),
            })
            del ss_loss

        if batch_idx <= sdr_batches:  # only the former batches count toward SDR
            predicted_maps = predicted_masks * x_input_map_multi
            utils.bss_eval2(config, predicted_maps, eval_data['multi_spk_fea_list'], raw_tgt, eval_data,
                            dst='batch_output_test')
            del predicted_maps, predicted_masks, x_input_map_multi
            try:
                sdr_aver_batch, sdri_aver_batch = bss_test.cal('batch_output_test/')
            except AssertionError:
                print('Errors in calculating the SDR')
            else:
                SDR_SUM = np.append(SDR_SUM, sdr_aver_batch)
                SDRi_SUM = np.append(SDRi_SUM, sdri_aver_batch)
                # Log this batch's SDR only when it was actually computed;
                # otherwise the name is unbound (first batch) or stale.
                writer.add_scalars('scalar/loss', {'SDR_sample_' + test_or_valid: sdr_aver_batch}, updates)
            print('SDR_aver_now:', SDR_SUM.mean())
            print('SDRi_aver_now:', SDRi_SUM.mean())
            lera.log({'SDR sample' + test_or_valid: SDR_SUM.mean()})
            lera.log({'SDRi sample' + test_or_valid: SDRi_SUM.mean()})
        elif batch_idx == sdr_batches + 1 and SDR_SUM.mean() > best_SDR:
            # First batch after the SDR window: record the best SDR once.
            print('Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean()))
            best_SDR = SDR_SUM.mean()

        candidate += [convertToLabels(dict_idx2spk, s, dict_spk2idx['<EOS>']) for s in samples]
        reference += raw_tgt
        print('samples:', samples)
        print('can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:], reference[-1 * config.batch_size:]))
        batch_idx += 1

        # Running metrics over everything seen so far (monitoring only).
        result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path)
        print('hamming_loss: %.8f | micro_f1: %.4f |recall: %.4f | precision: %.4f'
              % (result['hamming_loss'], result['micro_f1'], result['micro_recall'], result['micro_precision']))

    result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path)
    logging_csv([epoch, updates, result['hamming_loss'],
                 result['micro_f1'], result['micro_precision'], result['micro_recall'], SDR_SUM.mean()])
    print('hamming_loss: %.8f | micro_f1: %.4f'
          % (result['hamming_loss'], result['micro_f1']))
    return {
        'hamming_loss': result['hamming_loss'],
        'micro_f1': result['micro_f1'],
    }
Ejemplo n.º 14
0
def eval(epoch):
    """Run one evaluation pass over the 'valid' split and report metrics.

    Batches come from ``prepare_data('once', ...)`` until it yields False.
    For each batch this:

    * predicts speakers and separation masks with ``beam_sample``;
    * prints the separation loss;
    * for the first ``3000 // config.batch_size`` batches, writes separated
      outputs with ``utils.bss_eval`` and accumulates SDR from
      ``bss_test.cal``;
    * accumulates predicted vs. reference speaker labels.

    Right after the SDR window ends, the global ``best_SDR`` is raised if the
    mean SDR improved.

    Args:
        epoch: current epoch number; only written to the CSV log.

    Returns:
        dict with ``'hamming_loss'`` and ``'micro_f1'`` over all batches.
    """
    global best_SDR
    model.eval()
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    test_or_valid = 'valid'
    print('Test or valid: %s' % test_or_valid)
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX,
                                 config.MAX_MIX)
    SDR_SUM = np.array([])
    batch_idx = 0
    # Only batches with index <= sdr_batches are scored with SDR (slow).
    sdr_batches = 3000 // config.batch_size
    # Under DataParallel (multi-GPU) the model's own methods live on .module;
    # both paths must call the same beam_sample / separation_loss.
    base_model = model.module if len(opt.gpus) > 1 else model
    while True:
        print('-' * 30)
        eval_data = next(eval_data_gen)
        if eval_data == False:
            break  # the generator is exhausted for this epoch
        src = Variable(torch.from_numpy(eval_data['mix_feas']))

        # Reference labels: the sorted speaker keys of each mixture.
        raw_tgt = [
            sorted(spk.keys()) for spk in eval_data['multi_spk_fea_list']
        ]
        top_k = len(raw_tgt[0])
        # Arbitrary target of shape (top_k + 2, batch); its content does not
        # matter at evaluation time.
        tgt = Variable(torch.ones(top_k + 2, config.batch_size))

        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            len(eval_data['multi_spk_fea_list'][0])).unsqueeze(0)
        feas_tgt = models.rank_feas(raw_tgt,
                                    eval_data['multi_spk_fea_list'])  # target spectrograms
        # NOTE(review): debug plots (relitu, presumably a heat-map helper) of
        # target channels 0 and 1 of sample 0 on every batch; consider removing.
        relitu(mix_speech_len, speech_fre, feas_tgt.numpy()[0, 0].transpose())
        relitu(mix_speech_len, speech_fre, feas_tgt.numpy()[0, 1].transpose())

        if config.WFM:
            tmp_size = feas_tgt.size()
            assert len(tmp_size) == 4
            feas_tgt_sum = torch.sum(feas_tgt, dim=1, keepdim=True)
            feas_tgt_sum_square = (feas_tgt_sum *
                                   feas_tgt_sum).expand(tmp_size)
            feas_tgt_square = feas_tgt * feas_tgt
            # Epsilon avoids 0/0 -> NaN in bins where every target is silent.
            WFM_mask = feas_tgt_square / (feas_tgt_sum_square + 1e-15)

        if use_cuda:
            src = src.cuda()
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
            if config.WFM:
                WFM_mask = WFM_mask.cuda()

        samples, alignment, hiddens, predicted_masks = base_model.beam_sample(
            src, src_len, dict_spk2idx, tgt, beam_size=config.beam_size)

        if config.top1:
            # A single predicted mask: use its complement as the second one.
            predicted_masks = torch.cat([predicted_masks, 1 - predicted_masks],
                                        1)

        # Replicate the raw mixed features once per target channel.
        siz = src.size()
        assert len(siz) == 3
        topk = feas_tgt.size()[1]
        x_input_map_multi = torch.unsqueeze(src,
                                            1).expand(siz[0], topk, siz[1],
                                                      siz[2])
        if config.WFM:
            feas_tgt = x_input_map_multi.data * WFM_mask
        ss_loss = base_model.separation_loss(x_input_map_multi, predicted_masks,
                                             feas_tgt)
        print('loss for ss,this batch: %s' % ss_loss.data[0])
        del ss_loss, hiddens

        if batch_idx <= sdr_batches:  # only the former batches count toward SDR
            predicted_maps = predicted_masks * x_input_map_multi
            utils.bss_eval(config,
                           predicted_maps,
                           eval_data['multi_spk_fea_list'],
                           raw_tgt,
                           eval_data,
                           dst='batch_output23jo')
            del predicted_maps, predicted_masks, x_input_map_multi
            SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_output23jo/'))
            print('SDR_aver_now: %s' % SDR_SUM.mean())
            # Fall through (no interactive pause, no `continue`) so that
            # predictions are accumulated and batch_idx advances.
        elif batch_idx == sdr_batches + 1 and SDR_SUM.mean() > best_SDR:
            # First batch after the SDR window: record the best SDR once.
            print('Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean()))
            best_SDR = SDR_SUM.mean()

        candidate += [
            convertToLabels(dict_idx2spk, s, dict_spk2idx['<EOS>'])
            for s in samples
        ]
        reference += raw_tgt
        print('samples: %s' % (samples,))
        print('can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:],
                                        reference[-1 * config.batch_size:]))
        alignments.extend(alignment)
        batch_idx += 1

    # Replace UNK predictions with the aligned source token. Raw sources are
    # not collected in this pipeline, so `source` may be empty; zipping with
    # an empty list would silently discard every candidate.
    if opt.unk and source:
        cands = []
        for s, c, align in zip(source, candidate, alignments):
            cand = []
            for word, idx in zip(c, align):
                if word == dict.UNK_WORD and idx < len(s):
                    try:
                        cand.append(s[idx])
                    except (IndexError, KeyError, TypeError):
                        cand.append(word)
                        print("%d %d\n" % (len(s), idx))
                else:
                    cand.append(word)
            cands.append(cand)
        candidate = cands

    result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path)
    logging_csv([e, updates, result['hamming_loss'],
                 result['micro_f1'], result['micro_precision'], result['micro_recall']])
    print('hamming_loss: %.8f | micro_f1: %.4f' %
          (result['hamming_loss'], result['micro_f1']))
    return {
        'hamming_loss': result['hamming_loss'],
        'micro_f1': result['micro_f1'],
    }