コード例 #1
0
def train(epoch):
    e = epoch
    model.train()

    if config.schedule:
        scheduler.step()
        print("Decaying learning rate to %g" % scheduler.get_lr()[0])
        if config.is_dis:
            scheduler_dis.step()
        lera.log({
            'lr': scheduler.get_lr()[0],
        })

    if opt.model == 'gated':
        model.current_epoch = epoch

    global e, updates, total_loss, start_time, report_total, total_loss_sgm, total_loss_ss
    if config.MLMSE:
        global Var

    train_data_gen = prepare_data('once', 'train')
    # for raw_src, src, src_len, raw_tgt, tgt, tgt_len in trainloader:
    while True:
        train_data = train_data_gen.next()
        if train_data == False:
            break  #如果这个epoch的生成器没有数据了,直接进入下一个epoch

        src = Variable(torch.from_numpy(train_data['mix_feas']))
        # raw_tgt = [spk.keys() for spk in train_data['multi_spk_fea_list']]
        raw_tgt = [
            sorted(spk.keys()) for spk in train_data['multi_spk_fea_list']
        ]
        feas_tgt = models.rank_feas(
            raw_tgt, train_data['multi_spk_fea_list'])  #这里是目标的图谱

        # 要保证底下这几个都是longTensor(长整数)
        tgt = Variable(
            torch.from_numpy(
                np.array([[0] + [dict_spk2idx[spk]
                                 for spk in spks] + [dict_spk2idx['<EOS>']]
                          for spks in raw_tgt],
                         dtype=np.int))).transpose(0, 1)  #转换成数字,然后前后加开始和结束符号。
        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            len(train_data['multi_spk_fea_list'][0])).unsqueeze(0)
        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()

        model.zero_grad()
        # optim.optimizer.zero_grad()
        outputs, targets, multi_mask = model(
            src, src_len, tgt,
            tgt_len)  #这里的outputs就是hidden_outputs,还没有进行最后分类的隐层,可以直接用
        print 'mask size:', multi_mask.size()

        if 1 and len(opt.gpus) > 1:
            sgm_loss, num_total, num_correct = model.module.compute_loss(
                outputs, targets, opt.memory)
        else:
            sgm_loss, num_total, num_correct = model.compute_loss(
                outputs, targets, opt.memory)
        print 'loss for SGM,this batch:', sgm_loss.data[0] / num_total

        src = src.transpose(0, 1)
        # expand the raw mixed-features to topk channel.
        siz = src.size()
        assert len(siz) == 3
        topk = feas_tgt.size()[1]
        x_input_map_multi = torch.unsqueeze(src,
                                            1).expand(siz[0], topk, siz[1],
                                                      siz[2])
        multi_mask = multi_mask.transpose(0, 1)

        if 1 and len(opt.gpus) > 1:
            if config.MLMSE:
                Var = model.module.update_var(x_input_map_multi, multi_mask,
                                              feas_tgt)
                lera.log_image(u'Var weight',
                               Var.data.cpu().numpy().reshape(
                                   config.speech_fre, config.speech_fre,
                                   1).repeat(3, 2),
                               clip=(-1, 1))
                ss_loss = model.module.separation_loss(x_input_map_multi,
                                                       multi_mask, feas_tgt,
                                                       Var)
            else:
                ss_loss = model.module.separation_loss(x_input_map_multi,
                                                       multi_mask, feas_tgt)
        else:
            ss_loss = model.separation_loss(x_input_map_multi, multi_mask,
                                            feas_tgt)

        loss = sgm_loss + ss_loss
        # dis_loss model
        if config.is_dis:
            dis_loss = models.loss.dis_loss(config, topk, model_dis,
                                            x_input_map_multi, multi_mask,
                                            feas_tgt, func_dis)
            loss = loss + dis_loss
            # print 'dis_para',model_dis.parameters().next()[0]
            # print 'ss_para',model.parameters().next()[0]

        loss.backward()
        # print 'totallllllllllll loss:',loss
        total_loss_sgm += sgm_loss.data[0]
        total_loss_ss += ss_loss.data[0]
        lera.log({
            'sgm_loss': sgm_loss.data[0],
            'ss_loss': ss_loss.data[0],
        })
        total_loss += loss.data[0]
        report_total += num_total
        optim.step()
        if config.is_dis:
            optim_dis.step()

        updates += 1
        if updates % 30 == 0:
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss this batch: %6.3f,sgm loss: %6.6f,ss loss: %6.6f\n"
                % (time.time() - start_time, epoch, updates, loss / num_total,
                   total_loss_sgm / 30.0, total_loss_ss / 30.0))
            total_loss_sgm, total_loss_ss = 0, 0

        # continue

        if 0 or updates % config.eval_interval == 0 and epoch > 1:
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss: %6.5f\n" %
                (time.time() - start_time, epoch, updates,
                 total_loss / report_total))
            print('evaluating after %d updates...\r' % updates)
            score = eval(epoch)
            for metric in config.metric:
                scores[metric].append(score[metric])
                if metric == 'micro_f1' and score[metric] >= max(
                        scores[metric]):
                    save_model(log_path + 'best_' + metric + '_checkpoint.pt')
                if metric == 'hamming_loss' and score[metric] <= min(
                        scores[metric]):
                    save_model(log_path + 'best_' + metric + '_checkpoint.pt')

            model.train()
            total_loss = 0
            start_time = 0
            report_total = 0

        if updates % config.save_interval == 1:
            save_model(log_path +
                       'checkpoint_v2_withdis{}.pt'.format(config.is_dis))
コード例 #2
0
def eval(epoch):
    model.eval()
    reference, candidate, source, alignments = [], [], [], []
    e = epoch
    test_or_valid = 'valid'
    print 'Test or valid:', test_or_valid
    eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX,
                                 config.MAX_MIX)
    # for raw_src, src, src_len, raw_tgt, tgt, tgt_len in validloader:
    SDR_SUM = np.array([])
    batch_idx = 0
    global best_SDR, Var
    while True:
        # for ___ in range(2):
        print '-' * 30
        eval_data = eval_data_gen.next()
        if eval_data == False:
            break  #如果这个epoch的生成器没有数据了,直接进入下一个epoch
        src = Variable(torch.from_numpy(eval_data['mix_feas']))

        raw_tgt = [
            sorted(spk.keys()) for spk in eval_data['multi_spk_fea_list']
        ]
        top_k = len(raw_tgt[0])
        # 要保证底下这几个都是longTensor(长整数)
        # tgt = Variable(torch.from_numpy(np.array([[0]+[dict_spk2idx[spk] for spk in spks]+[dict_spk2idx['<EOS>']] for spks in raw_tgt],dtype=np.int))).transpose(0,1) #转换成数字,然后前后加开始和结束符号。
        tgt = Variable(torch.ones(
            top_k + 2, config.batch_size))  # 这里随便给一个tgt,为了测试阶段tgt的名字无所谓其实。

        src_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            mix_speech_len).unsqueeze(0)
        tgt_len = Variable(
            torch.LongTensor(config.batch_size).zero_() +
            len(eval_data['multi_spk_fea_list'][0])).unsqueeze(0)
        feas_tgt = models.rank_feas(raw_tgt,
                                    eval_data['multi_spk_fea_list'])  #这里是目标的图谱
        if config.WFM:
            tmp_size = feas_tgt.size()
            assert len(tmp_size) == 4
            feas_tgt_sum = torch.sum(feas_tgt, dim=1, keepdim=True)
            feas_tgt_sum_square = (feas_tgt_sum *
                                   feas_tgt_sum).expand(tmp_size)
            feas_tgt_square = feas_tgt * feas_tgt
            WFM_mask = feas_tgt_square / feas_tgt_sum_square

        if use_cuda:
            src = src.cuda().transpose(0, 1)
            tgt = tgt.cuda()
            src_len = src_len.cuda()
            tgt_len = tgt_len.cuda()
            feas_tgt = feas_tgt.cuda()
            if config.WFM:
                WFM_mask = WFM_mask.cuda()
        try:
            if 1 and len(opt.gpus) > 1:
                # samples, alignment = model.module.sample(src, src_len)
                samples, alignment, hiddens, predicted_masks = model.module.beam_sample(
                    src,
                    src_len,
                    dict_spk2idx,
                    tgt,
                    beam_size=config.beam_size)
            else:
                samples, alignment, hiddens, predicted_masks = model.beam_sample(
                    src,
                    src_len,
                    dict_spk2idx,
                    tgt,
                    beam_size=config.beam_size)
                # samples, alignment, hiddens, predicted_masks = model.beam_sample(src, src_len, dict_spk2idx, tgt, beam_size=config.beam_size)
        except Exception, info:
            print '**************Error occurs here************:', info
            continue

        if config.top1:
            predicted_masks = torch.cat([predicted_masks, 1 - predicted_masks],
                                        1)

        # '''
        # expand the raw mixed-features to topk channel.
        src = src.transpose(0, 1)
        siz = src.size()
        assert len(siz) == 3
        topk = feas_tgt.size()[1]
        x_input_map_multi = torch.unsqueeze(src,
                                            1).expand(siz[0], topk, siz[1],
                                                      siz[2])
        if config.WFM:
            feas_tgt = x_input_map_multi.data * WFM_mask
        if 1 and len(opt.gpus) > 1:
            ss_loss = model.module.separation_loss(x_input_map_multi,
                                                   predicted_masks, feas_tgt,
                                                   Var)
        else:
            ss_loss = model.separation_loss(x_input_map_multi, predicted_masks,
                                            feas_tgt, None)
        print 'loss for ss,this batch:', ss_loss.data[0]
        lera.log({
            'ss_loss_' + test_or_valid: ss_loss.data[0],
        })

        del ss_loss, hiddens

        # '''''
        if batch_idx <= (500 / config.batch_size
                         ):  #only the former batches counts the SDR
            predicted_maps = predicted_masks * x_input_map_multi
            # predicted_maps=Variable(feas_tgt)
            utils.bss_eval(config,
                           predicted_maps,
                           eval_data['multi_spk_fea_list'],
                           raw_tgt,
                           eval_data,
                           dst='batch_outputwaddd')
            del predicted_maps, predicted_masks, x_input_map_multi
            SDR_SUM = np.append(SDR_SUM, bss_test.cal('batch_outputwaddd/'))
            print 'SDR_aver_now:', SDR_SUM.mean()
            lera.log({'SDR sample': SDR_SUM.mean()})
            raw_input('Press any key to continue......')
        elif batch_idx == (500 / config.batch_size) + 1 and SDR_SUM.mean(
        ) > best_SDR:  #only record the best SDR once.
            print 'Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean())
            best_SDR = SDR_SUM.mean()
            # save_model(log_path+'checkpoint_bestSDR{}.pt'.format(best_SDR))

        # '''
        candidate += [
            convertToLabels(dict_idx2spk, s, dict_spk2idx['<EOS>'])
            for s in samples
        ]
        # source += raw_src
        reference += raw_tgt
        print 'samples:', samples
        print 'can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:],
                                        reference[-1 * config.batch_size:])
        alignments += [align for align in alignment]
        batch_idx += 1
コード例 #3
0
    print('loading checkpoint...\n', opt.restore)
    checkpoints = torch.load(opt.restore)

# cuda
use_cuda = torch.cuda.is_available() and len(opt.gpus) > 0
use_cuda = True
if use_cuda:
    torch.cuda.set_device(opt.gpus[0])
    torch.cuda.manual_seed(opt.seed)
print(use_cuda)

# data
print('loading data...\n')
start_time = time.time()

spk_global_gen = prepare_data(mode='global',
                              train_or_test='train')  #写一个假的数据生成,可以用来写模型先
global_para = spk_global_gen.next()
print global_para
spk_all_list, dict_spk2idx, dict_idx2spk, mix_speech_len, speech_fre, total_frames, spk_num_total, batch_total = global_para
config.speech_fre = speech_fre
config.mix_speech_len = mix_speech_len
del spk_global_gen
num_labels = len(spk_all_list)
print('loading the global setting cost: %.3f' % (time.time() - start_time))

if opt.pretrain:
    pretrain_embed = torch.load(config.emb_file)
else:
    pretrain_embed = None

# model