コード例 #1
0
def evaluate(model, dataset, batch_size=2, verbose=1, cal_sdr=False):
    """Evaluate a separation model on a dataset, printing average SI-SNRi
    (and optionally SDRi) over all utterances.

    Args:
        model: separation network mapping a mixture [B, T] to sources [B, C, T].
        dataset: dataset yielding (audio, mixture_lengths) batches where
            audio channel 0 is the mixture and channels 1: are references.
        batch_size: batch size for the evaluation loader.
        verbose: if 1, print per-utterance metrics.
        cal_sdr: also compute SDRi (slow) when True.
    """
    total_SISNRi = 0
    total_SDRi = 0
    total_cnt = 0

    model.eval()
    # NOTE(review): `device` is read from enclosing/module scope — confirm
    # it is defined where this function is used.
    model.to(device)

    data_loader = AudioDataLoader(dataset,
                                  batch_size=batch_size,
                                  shuffle=False)

    with torch.no_grad():
        for i, (audio, mixture_lengths) in enumerate(data_loader):
            # Batch layout: channel 0 is the mixture, the rest are sources.
            padded_mixture = audio[:, 0].to(device)
            padded_source = audio[:, 1:].to(device)
            mixture_lengths = mixture_lengths.to(device)

            # Forward; cal_loss resolves the best source permutation (PIT).
            estimate_source = model(padded_mixture)  # [B, C, T]
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)

            # Remove padding and flatten to per-utterance arrays.
            mixture = remove_pad(padded_mixture, mixture_lengths)
            source = remove_pad(padded_source, mixture_lengths)
            # NOTE: use the permutation-reordered estimates for the metrics.
            estimate_source = remove_pad(reorder_estimate_source,
                                         mixture_lengths)
            # for each utterance
            for mix, src_ref, src_est in zip(mixture, source, estimate_source):
                if verbose == 1:
                    print("Utt", total_cnt + 1)
                # Compute SDRi
                if cal_sdr:
                    avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                    total_SDRi += avg_SDRi
                    if verbose == 1:
                        # BUG FIX: ':.{2}' is precision-2 *general* format
                        # (2 significant digits, e.g. '1.2e+01'); use ':.2f'
                        # for 2 decimal places, as the other scripts do.
                        print(f"\tSDRi={avg_SDRi:.2f}")

                # Compute SI-SNRi
                avg_SISNRi = cal_SISNRi(src_ref, src_est, mix)
                if verbose == 1:
                    print(f"\tSI-SNRi={avg_SISNRi:.2f}")
                total_SISNRi += avg_SISNRi
                total_cnt += 1

    # Guard against an empty dataset (previously raised ZeroDivisionError).
    if total_cnt == 0:
        print("No utterances evaluated.")
        return
    if cal_sdr:
        print(f"Average SDR improvement: {total_SDRi / total_cnt:.2f}")
    print(f"Average SISNR improvement: {total_SISNRi / total_cnt:.2f}")
コード例 #2
0
def evaluate(args):
    """Compute the average SI-SNR improvement (and optionally SDR
    improvement) of a trained Conv-TasNet checkpoint over a dataset.

    args must provide: model_path, use_cuda, data_dir, batch_size,
    sample_rate, pit, cal_sdr.
    """
    total_SISNRi = 0.0
    total_SDRi = 0.0
    total_cnt = 0

    # Restore the trained separator.
    model = ConvTasNet.load_model(args.model_path)
    model.eval()
    if args.use_cuda:
        model.cuda()

    # segment=-1 keeps each utterance whole (no chunking) for evaluation.
    dataset = AudioDataset(args.data_dir, args.batch_size,
                           sample_rate=args.sample_rate, segment=-1)
    data_loader = AudioDataLoader(dataset, batch_size=1, num_workers=2)

    with torch.no_grad():
        for batch in data_loader:
            padded_mixture, mixture_lengths, padded_source = batch
            if args.use_cuda:
                padded_mixture = padded_mixture.cuda()
                mixture_lengths = mixture_lengths.cuda()
                padded_source = padded_source.cuda()

            # Separate, then let the PIT loss pick the best permutation.
            estimate_source = model(padded_mixture)  # [B, C, T]
            loss, max_snr, estimate_source, reorder_estimate_source = cal_loss(
                padded_source, estimate_source, mixture_lengths, args.pit)

            # Strip padding; metrics are computed per utterance.
            mixture = remove_pad(padded_mixture, mixture_lengths)
            source = remove_pad(padded_source, mixture_lengths)
            # NOTE: metrics use the permutation-reordered estimates.
            estimate_source = remove_pad(reorder_estimate_source,
                                         mixture_lengths)

            for mix, src_ref, src_est in zip(mixture, source, estimate_source):
                if args.cal_sdr:
                    avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                    total_SDRi += avg_SDRi
                    print("\tSDRi={0:.2f}".format(avg_SDRi))
                avg_SISNRi = cal_SISNRi(src_ref, src_est, mix)
                print("\tSI-SNRi={0:.2f}".format(avg_SISNRi))
                total_SISNRi += avg_SISNRi
                total_cnt += 1

    if args.cal_sdr:
        print("Average SDR improvement: {0:.2f}".format(total_SDRi / total_cnt))
    print("Average SISNR improvement: {0:.2f}".format(total_SISNRi / total_cnt))
コード例 #3
0
def separate(args):
    """Run a trained Conv-TasNet over a set of mixtures and write the
    mixture plus each estimated source as .wav files into args.out_dir.

    args must provide: mix_dir, mix_json, model_path, use_cuda,
    batch_size, sample_rate, out_dir.
    """
    if args.mix_dir is None and args.mix_json is None:
        print("Must provide mix_dir or mix_json! When providing mix_dir, "
              "mix_json is ignored.")
        # BUG FIX: previously fell through and crashed inside EvalDataset.
        return

    # Load model
    model = ConvTasNet.load_model(args.model_path)
    print(model)
    model.eval()
    if args.use_cuda:
        model.cuda()

    # Load data
    eval_dataset = EvalDataset(args.mix_dir,
                               args.mix_json,
                               batch_size=args.batch_size,
                               sample_rate=args.sample_rate)
    eval_loader = EvalDataLoader(eval_dataset, batch_size=1)
    os.makedirs(args.out_dir, exist_ok=True)

    def write(inputs, filename, sr=args.sample_rate):
        # NOTE(review): librosa.output.write_wav was removed in librosa 0.8 —
        # confirm the pinned librosa version still provides it.
        librosa.output.write_wav(filename, inputs, sr, norm=True)

    with torch.no_grad():
        for data in eval_loader:
            # Get batch data
            mixture, mix_lengths, filenames = data
            if args.use_cuda:
                mixture, mix_lengths = mixture.cuda(), mix_lengths.cuda()
            # Forward
            estimate_source = model(mixture)  # [B, C, T]
            # Remove padding and flat
            flat_estimate = remove_pad(estimate_source, mix_lengths)
            mixture = remove_pad(mixture, mix_lengths)
            # Write result
            for i, filename in enumerate(filenames):
                # BUG FIX: str.strip('.wav') removes any of the characters
                # '.', 'w', 'a', 'v' from both ends (e.g. 'av_mix.wav' ->
                # '_mix'); splitext drops only the extension.
                filename = os.path.join(
                    args.out_dir,
                    os.path.splitext(os.path.basename(filename))[0])
                write(mixture[i], filename + '.wav')
                C = flat_estimate[i].shape[0]
                # Rescale each estimate so its L1 energy matches the mixture.
                m = np.sum(np.abs(mixture[i]))
                for c in range(C):
                    p = np.sum(np.abs(flat_estimate[i][c]))
                    # Guard against an all-zero estimate (division by zero).
                    if p > 0:
                        flat_estimate[i][c] = flat_estimate[i][c] / p * m
                    write(flat_estimate[i][c],
                          filename + '_s{}.wav'.format(c + 1))
コード例 #4
0
def predict_fn(input_data, model):
    """Separate sources for each batch in input_data and return, for the
    LAST batch, the mixture stacked on top of the separated estimates.

    Args:
        input_data: iterable of (mixture, mix_lengths, filenames) batches
            (e.g. an eval_loader).
        model: separation network mapping [B, T] -> [B, C, T].

    Returns:
        np.ndarray of shape (C + 1, T): row 0 is the mixture, rows 1..C
        are the estimated sources.

    Raises:
        ValueError: if input_data yields no batches (the original code
            raised an obscure NameError in that case).

    NOTE(review): assumes each batch contains a single sample — only
    element 0 of each batch (and only the last batch) is used; confirm
    against the caller.  (Original comment: "默认只有一个sample".)
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    result = None
    with torch.no_grad():
        for data in input_data:
            mixture, mix_lengths, filenames = data
            # BUG FIX: the model is moved to `device` above, but the inputs
            # were left on the CPU (the .cuda() call was commented out),
            # which crashes whenever CUDA is available.
            mixture = mixture.to(device)
            mix_lengths = mix_lengths.to(device)
            # Forward
            estimate_source = model(mixture)  # [B, C, T]
            # Remove padding and flatten
            flat_estimate = remove_pad(estimate_source, mix_lengths)
            mixture = remove_pad(mixture, mix_lengths)
            num_src, n_samples = flat_estimate[0].shape
            stacked = np.zeros((num_src + 1, n_samples))
            stacked[0, :] = mixture[0]
            stacked[1:, :] = flat_estimate[0]
            result = stacked
    if result is None:
        raise ValueError("input_data yielded no batches")
    return result
コード例 #5
0
def separate(args):
    """Separate mixtures with a trained FaSNet_base model and write the
    mixture plus each estimated source to args.out_dir as .wav files.

    args must provide: mix_dir, mix_json, model_path, use_cuda,
    batch_size, sample_rate, out_dir.
    """
    if args.mix_dir is None and args.mix_json is None:
        print("Must provide mix_dir or mix_json! When providing mix_dir, "
              "mix_json is ignored.")
        # BUG FIX: previously fell through and crashed inside EvalDataset.
        return

    # Load model — hyper-parameters must match the trained checkpoint.
    model = FaSNet_base(enc_dim=256, feature_dim=64, hidden_dim=128, layer=6,
                        segment_size=250, nspk=2, win_len=2)
    model_state = torch.load(args.model_path)
    model.load_state_dict(model_state)
    print(model)
    model.eval()
    if args.use_cuda:
        model.cuda()

    # Load data
    eval_dataset = EvalDataset(args.mix_dir, args.mix_json,
                               batch_size=args.batch_size,
                               sample_rate=args.sample_rate)
    eval_loader = EvalDataLoader(eval_dataset, batch_size=1)
    os.makedirs(args.out_dir, exist_ok=True)

    def write(inputs, filename, sr=args.sample_rate):
        # NOTE(review): librosa.output.write_wav was removed in librosa 0.8 —
        # confirm the pinned librosa version still provides it.
        librosa.output.write_wav(filename, inputs, sr)

    with torch.no_grad():
        for data in eval_loader:
            # Get batch data
            mixture, mix_lengths, filenames = data
            if args.use_cuda:
                mixture, mix_lengths = mixture.cuda(), mix_lengths.cuda()
            # Forward
            estimate_source = model(mixture)  # [B, C, T]
            # Remove padding and flat
            flat_estimate = remove_pad(estimate_source, mix_lengths)
            mixture = remove_pad(mixture, mix_lengths)
            # Write result
            for i, filename in enumerate(filenames):
                # BUG FIX: str.strip('.wav') removes any of the characters
                # '.', 'w', 'a', 'v' from both ends of the basename;
                # splitext drops only the extension.
                filename = os.path.join(
                    args.out_dir,
                    os.path.splitext(os.path.basename(filename))[0])
                write(mixture[i], filename + '.wav')
                C = flat_estimate[i].shape[0]
                for c in range(C):
                    write(flat_estimate[i][c], filename + '_s{}.wav'.format(c+1))
コード例 #6
0
def evaluate(args):
    """Evaluate a DPTNet checkpoint: average SI-SNRi (and optionally SDRi)
    over an evaluation set.

    args must provide: N, C, L, H, K, B (model hyper-parameters),
    model_path, use_cuda, data_dir, batch_size, sample_rate, cal_sdr.
    """
    total_SISNRi = 0
    total_SDRi = 0
    total_cnt = 0

    # Build the bare model and load weights BEFORE wrapping in DataParallel.
    # BUG FIX: the original wrapped the model in DataParallel first and then
    # loaded a state dict whose 'module.' prefixes had been stripped, so the
    # keys no longer matched the wrapped model when use_cuda was set.
    model = DPTNet(args.N, args.C, args.L, args.H, args.K, args.B)

    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host.
    model_info = torch.load(args.model_path, map_location='cpu')

    state_dict = OrderedDict()
    for k, v in model_info['model_state_dict'].items():
        name = k.replace("module.", "")  # remove DataParallel's 'module.'
        state_dict[name] = v
    model.load_state_dict(state_dict)

    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()

    print(model)

    # Load data (segment=-1 keeps each utterance whole for evaluation).
    dataset = AudioDataset(args.data_dir,
                           args.batch_size,
                           sample_rate=args.sample_rate,
                           segment=-1)
    data_loader = AudioDataLoader(dataset, batch_size=1, num_workers=2)

    with torch.no_grad():
        for i, (data) in enumerate(data_loader):
            # Get batch data
            padded_mixture, mixture_lengths, padded_source = data
            if args.use_cuda:
                padded_mixture = padded_mixture.cuda()
                mixture_lengths = mixture_lengths.cuda()
                padded_source = padded_source.cuda()
            # Forward; cal_loss resolves the best source permutation (PIT).
            estimate_source = model(padded_mixture)  # [B, C, T]
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)
            # Remove padding and flat
            mixture = remove_pad(padded_mixture, mixture_lengths)
            source = remove_pad(padded_source, mixture_lengths)
            # NOTE: use reorder estimate source
            estimate_source = remove_pad(reorder_estimate_source,
                                         mixture_lengths)
            # for each utterance
            for mix, src_ref, src_est in zip(mixture, source, estimate_source):
                print("Utt", total_cnt + 1)
                # Compute SDRi
                if args.cal_sdr:
                    avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                    total_SDRi += avg_SDRi
                    print("\tSDRi={0:.2f}".format(avg_SDRi))
                # Compute SI-SNRi
                avg_SISNRi = cal_SISNRi(src_ref, src_est, mix)
                print("\tSI-SNRi={0:.2f}".format(avg_SISNRi))
                total_SISNRi += avg_SISNRi
                total_cnt += 1
    if args.cal_sdr:
        print("Average SDR improvement: {0:.2f}".format(total_SDRi /
                                                        total_cnt))
    print("Average SISNR improvement: {0:.2f}".format(total_SISNRi /
                                                      total_cnt))
コード例 #7
0
ファイル: run.py プロジェクト: szxSpark/hit-cosem-2018
def train(args):
    """Train the biLSTM slot-filling model on the ATIS data, evaluate on
    the test set after every epoch (CoNLL eval), checkpoint per epoch,
    and plot the training loss curve at the end.

    args must provide: vocab_dir, max_len, slot_name_file, train_files,
    test_files, batch_size, learning_rate, embed_learning_rate,
    weight_decay, epoch_num, has_cuda, begin_epoch, lr_decay, save_path.
    """
    logger = logging.getLogger("hit-cosem-2018")
    logger.info('Load data_set and vocab...')
    with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
        vocab = pickle.load(fin)
    atis_data = AtisDataSet(args.max_len, args.slot_name_file,
                            train_files=args.train_files,
                            test_files=args.test_files)
    logger.info('Converting text into ids...')
    atis_data.convert_to_ids(vocab)
    atis_data.dynamic_padding(vocab.token2id[vocab.pad_token])
    train_data, test_data = atis_data.get_numpy_data()
    train_loader = DataLoader(dataset=AtisLoader(train_data),
                              batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(dataset=AtisLoader(test_data),
                             batch_size=args.batch_size, shuffle=False)

    model = biLSTM(args, vocab.size(), len(atis_data.slot_names))
    optimizer = model.get_optimizer(args.learning_rate,
                                    args.embed_learning_rate,
                                    args.weight_decay)
    loss_fn = torch.nn.CrossEntropyLoss()

    score = []
    losses = []
    for eidx in range(1, args.epoch_num + 1):
        for bidx, (sentences, labels) in enumerate(train_loader, 1):
            optimizer.zero_grad()
            if args.has_cuda:
                sentences, labels = Variable(sentences).cuda(), Variable(labels).cuda()
            else:
                sentences, labels = Variable(sentences), Variable(labels)
            output = model(sentences)
            labels = labels.view(-1)
            loss = loss_fn(output, labels)
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(output.data, 1)
            predicted = predicted.cpu().numpy()
            labels = labels.data.cpu().numpy()
            # BUG FIX: loss.data[0] indexes a 0-dim tensor and raises
            # IndexError on PyTorch >= 0.5; use .item() instead.
            loss_value = loss.item()
            losses.append(loss_value)
            logger.info('epoch: {} batch: {} loss: {:.4f} acc: {:.4f}'.format(
                eidx, bidx, loss_value, accuracy(predicted, labels)))

        # After begin_epoch, start (or continue) decaying both learning
        # rates; the embedding LR is kicked off at 2e-4 and floored at 1e-5.
        if eidx >= args.begin_epoch:
            if args.embed_learning_rate == 0:
                args.embed_learning_rate = 2e-4
            elif args.embed_learning_rate > 0:
                args.embed_learning_rate *= args.lr_decay
                if args.embed_learning_rate <= 1e-5:
                    args.embed_learning_rate = 1e-5
            args.learning_rate = args.learning_rate * args.lr_decay
            optimizer = model.get_optimizer(args.learning_rate,
                                            args.embed_learning_rate,
                                            args.weight_decay)

        logger.info('do eval on test set...')
        # Write "token gold-slot predicted-slot" lines for conlleval.
        with open("./result/result_epoch" + str(eidx) + ".txt",
                  'w', encoding='utf-8') as f:
            for sentences, labels in test_loader:
                if args.has_cuda:
                    sentences, labels = Variable(sentences).cuda(), Variable(labels).cuda()
                else:
                    sentences, labels = Variable(sentences), Variable(labels)  # batch_size * max_len
                output = model(sentences)

                sentences = sentences.data.cpu().numpy().tolist()
                sentences = [vocab.recover_from_ids(remove_pad(s, vocab.token2id[vocab.pad_token]))
                             for s in sentences]
                labels = labels.data.cpu().numpy().tolist()
                _, predicted = torch.max(output.data, 1)
                predicted = predicted.view(-1, args.max_len).cpu().numpy().tolist()
                # Pair each token with its gold and predicted slot names
                # (renamed from `iter`, which shadowed the builtin).
                rows = [zip(s,
                            map(lambda x: atis_data.slot_names[x], labels[i][:len(s)]),
                            map(lambda x: atis_data.slot_names[x], predicted[i][:len(s)]))
                        for i, s in enumerate(sentences)]
                for row in rows:
                    for z in row:
                        f.write(' '.join(map(str, z)) + '\n')
        score.append(conlleval(eidx))
        torch.save(model.state_dict(),
                   args.save_path + "biLSTM_epoch" + str(eidx) + ".model")

    # Report the best epoch and rerun the official scorer on its output.
    max_score_eidx = score.index(max(score)) + 1
    logger.info('epoch {} gets max score.'.format(max_score_eidx))
    os.system('perl ./eval/conlleval.pl < ./result/result_epoch' + str(max_score_eidx) + '.txt')

    x = [i + 1 for i in range(len(losses))]
    plt.plot(x, losses, 'r')
    plt.xlabel("time_step")
    plt.ylabel("loss")
    plt.title("CrossEntropyLoss")
    plt.show()
コード例 #8
0
def separate(args):
    """Separate mixtures with a Conv-TasNet checkpoint and write the
    mixture plus each (peak-normalized) estimated source to args.out_dir.

    args must provide: mix_dir, mix_json, model_path, use_cuda,
    batch_size, sample_rate, out_dir.
    """
    if args.mix_dir is None and args.mix_json is None:
        print("Must provide mix_dir or mix_json! When providing mix_dir, "
              "mix_json is ignored.")
        # BUG FIX: previously fell through and crashed inside EvalDataset.
        return

    # Load model — hyper-parameters must match the trained checkpoint.
    model = ConvTasNet(256,
                       20,
                       256,
                       512,
                       3,
                       8,
                       4,
                       2,
                       norm_type="gLN",
                       causal=0,
                       mask_nonlinear="relu")
    # BUG FIX: the model was moved to CUDA unconditionally even when
    # args.use_cuda was False (and torch.load assumed a GPU checkpoint);
    # load on CPU first, then move only if requested — matching how the
    # inputs are handled below.
    model.load_state_dict(
        torch.load(args.model_path, map_location='cpu')['sep_state_dict'])
    if args.use_cuda:
        model.cuda()
    print(model)
    model.eval()

    # Load data
    eval_dataset = EvalDataset(args.mix_dir,
                               args.mix_json,
                               batch_size=args.batch_size,
                               sample_rate=args.sample_rate)
    eval_loader = EvalDataLoader(eval_dataset, batch_size=1)
    os.makedirs(args.out_dir, exist_ok=True)

    def write(inputs, filename, sr=args.sample_rate):
        # Peak-normalize before writing; guard against an all-zero signal
        # (previously a division by zero).
        peak = np.max(np.abs(inputs))
        if peak > 0:
            inputs = inputs / peak
        sf.write(filename, inputs, sr)

    with torch.no_grad():
        for data in eval_loader:
            # Get batch data
            mixture, mix_lengths, filenames = data
            if args.use_cuda:
                mixture, mix_lengths = mixture.cuda(), mix_lengths.cuda()
            # Forward
            estimate_source = model(mixture)  # [B, C, T]
            # Remove padding and flat
            flat_estimate = remove_pad(estimate_source, mix_lengths)
            mixture = remove_pad(mixture, mix_lengths)
            # Write result
            for i, filename in enumerate(filenames):
                # BUG FIX: str.strip('.wav') removes any of the characters
                # '.', 'w', 'a', 'v' from both ends of the basename;
                # splitext drops only the extension.
                filename = os.path.join(
                    args.out_dir,
                    os.path.splitext(os.path.basename(filename))[0])
                write(mixture[i], filename + '.wav')
                C = flat_estimate[i].shape[0]
                for c in range(C):
                    write(flat_estimate[i][c],
                          filename + '_s{}.wav'.format(c + 1))
コード例 #9
0
def evaluate(args):
    """Evaluate a Conv-TasNet variant whose forward pass also returns
    speaker embeddings; print average SI-SNRi (and optionally SDRi).

    args must provide: model_path, use_cuda, data_dir, batch_size,
    sample_rate, cal_sdr.

    Returns:
        numberEsti: list of per-utterance source-count estimates. The
        estimation code is currently disabled, so this list stays empty.
    """
    total_SISNRi = 0
    total_SDRi = 0
    total_cnt = 0
    numberEsti = []

    # Load model
    model = ConvTasNet.load_model(args.model_path)
    model.eval()
    if args.use_cuda:
        model.cuda(0)

    # Load data (segment=2: evaluate on fixed 2-second chunks)
    dataset = AudioDataset(args.data_dir, args.batch_size,
                           sample_rate=args.sample_rate, segment=2)
    data_loader = AudioDataLoader(dataset, batch_size=1, num_workers=2)

    with torch.no_grad():
        for i, (data) in enumerate(data_loader):
            print(i)
            # Get batch data
            padded_mixture, mixture_lengths, padded_source = data
            if args.use_cuda:
                padded_mixture = padded_mixture.cuda(0)
                mixture_lengths = mixture_lengths.cuda(0)
                # BUG FIX: padded_source was left on the CPU, so cal_loss
                # mixed CPU and CUDA tensors and raised a RuntimeError.
                padded_source = padded_source.cuda(0)
            # Forward: estimates plus a speaker embedding.
            estimate_source, s_embed = model(padded_mixture)  # [B, C, T], [B, N, K, E]
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)
            # Remove padding and flat
            mixture = remove_pad(padded_mixture, mixture_lengths)
            source = remove_pad(padded_source, mixture_lengths)
            # NOTE: use reorder estimate source
            estimate_source = remove_pad(reorder_estimate_source,
                                         mixture_lengths)
            # for each utterance
            for mix, src_ref, src_est in zip(mixture, source, estimate_source):
                print("Utt", total_cnt + 1)
                # Compute SDRi
                if args.cal_sdr:
                    avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                    total_SDRi += avg_SDRi
                    print("\tSDRi={0:.2f}".format(avg_SDRi))
                # Compute SI-SNRi
                avg_SISNRi = cal_SISNRi(src_ref, src_est, mix)
                print("\tSI-SNRi={0:.2f}".format(avg_SISNRi))
                total_SISNRi += avg_SISNRi
                total_cnt += 1

    if args.cal_sdr:
        print("Average SDR improvement: {0:.2f}".format(total_SDRi / total_cnt))
    print("Average SISNR improvement: {0:.2f}".format(total_SISNRi / total_cnt))
    print("speaker:2 ./ClustertrainTFSE1New/final_paper_2_3_2chobatch6.pth.tar")

    return numberEsti