Example #1
    def test(self, args):
        with open(args.test_list, 'r') as test_list_file:
            self.test_list = [line.strip() for line in test_list_file.readlines()]
        self.model_name = args.model_name
        self.model_file = args.model_file
        self.test_mixture_path = args.test_mixture_path
        self.prediction_path = args.prediction_path

        # create a network
        print('model', self.model_name)
        net = Net(device=self.device, L=self.frame_size, width=self.width)
        # net = torch.nn.DataParallel(net)
        net.to(self.device)
        print('Number of learnable parameters: %d' % numParams(net))
        print(net)
        # put the network in evaluation mode for inference
        net.eval()
        print('Load model from "%s"' % self.model_file)
        checkpoint = Checkpoint()
        checkpoint.load(self.model_file)
        net.load_state_dict(checkpoint.state_dict)
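        # note: Checkpoint is assumed to expose the saved weights via its
        # .state_dict attribute (see the interface sketch after Example #3)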
        with torch.no_grad():
            for i in range(len(self.test_list)):
                # read the mixture for resynthesis
                filename_input = self.test_list[i].split('/')[-1]
                start1 = timeit.default_timer()
                print('{}/{}, Started working on {}.'.format(i + 1, len(self.test_list), self.test_list[i]))
                print('')
                filename_mix = filename_input.replace('.samp', '_mix.dat')

                filename_s_ideal = filename_input.replace('.samp', '_s_ideal.dat')
                filename_s_est = filename_input.replace('.samp', '_s_est.dat')
                f_mix = h5py.File(os.path.join(self.test_mixture_path, filename_mix), 'r')
                f_s_ideal = h5py.File(os.path.join(self.prediction_path, filename_s_ideal), 'w')
                f_s_est = h5py.File(os.path.join(self.prediction_path, filename_s_est), 'w')
                # create a test dataset
                testSet = EvalDataset(os.path.join(self.test_mixture_path, self.test_list[i]),
                                      self.num_test_sentences)

                # create a data loader for test
                test_loader = DataLoader(testSet,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=2,
                                         collate_fn=EvalCollate())
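                # batch_size=1 with shuffle=False walks utterances in list
                # order; EvalCollate is presumed to return the raw
                # (mixture, clean) pair for a single utterance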

                accu_test_loss = 0.0

                ttime = 0.
                mtime = 0.
                cnt = 0.
                for k, (mix_raw, cln_raw) in enumerate(test_loader):
                    start = timeit.default_timer()
                    est_s = self.eval_forward(mix_raw, net)
                    est_s = est_s[:mix_raw.size]  # trim any padding back to the mixture length
                    mix = f_mix[str(k)][:]  # mixture read for reference; not used below

                    ideal_s = cln_raw

                    f_s_ideal.create_dataset(str(k), data=ideal_s.astype(np.float32), chunks=True)
                    f_s_est.create_dataset(str(k), data=est_s.astype(np.float32), chunks=True)
                    # evaluation loss: per-utterance time-domain MSE between
                    # the estimate and the clean reference
                    test_loss = np.mean((est_s - ideal_s) ** 2)

                    accu_test_loss += test_loss
                    cnt += 1
                    end = timeit.default_timer()
                    curr_time = end - start
                    ttime += curr_time
                    mtime = ttime / cnt  # running mean time per utterance
                    print('{}/{}, test_loss = {:.4f}, time/utterance = {:.4f}, '
                          'mtime/utterance = {:.4f}'.format(k + 1, self.num_test_sentences, test_loss, curr_time,
                                                            mtime))

                avg_test_loss = accu_test_loss / cnt
                end1 = timeit.default_timer()
                print('********** Finished working on {}. time taken = {:.4f} **********'.format(filename_input,
                                                                                                 end1 - start1))
                print('')
                f_mix.close()
                f_s_est.close()
                f_s_ideal.close()
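
A minimal sketch of how this test() method might be driven. The argparse flag
names below are assumptions inferred from the attributes read at the top of the
method, and Trainer is a hypothetical name for the class that owns test():

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--test_list')
    parser.add_argument('--model_name')
    parser.add_argument('--model_file')
    parser.add_argument('--test_mixture_path')
    parser.add_argument('--prediction_path')
    args = parser.parse_args()

    # trainer = Trainer(...)  # hypothetical: construct with device, frame size, etc.
    # trainer.test(args)
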
Example #2
    def train(self, args):
        with open(args.train_list, 'r') as train_list_file:
            self.train_list = [line.strip() for line in train_list_file.readlines()]
        self.eval_file = args.eval_file
        self.num_train_sentences = args.num_train_sentences
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.max_epoch = args.max_epoch
        self.model_path = args.model_path
        self.log_path = args.log_path
        self.fig_path = args.fig_path
        self.eval_plot_num = args.eval_plot_num
        self.eval_steps = args.eval_steps
        self.resume_model = args.resume_model
        self.wav_path = args.wav_path
        self.train_wav_path = args.train_wav_path
        self.tool_path = args.tool_path

        # create a training dataset and an evaluation dataset
        trainSet = TrainingDataset(self.train_list,
                                   frame_size=self.frame_size,
                                   frame_shift=self.frame_shift)
        evalSet = EvalDataset(self.eval_file,
                              self.num_test_sentences)
        # create data loaders for training and evaluation
        train_loader = DataLoader(trainSet,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  num_workers=16,
                                  collate_fn=TrainCollate())

        eval_loader = DataLoader(evalSet,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=EvalCollate())

        # create a network
        print('model', self.model_name)
        net = Net(device=self.device, L=self.frame_size, width=self.width)
        # net = torch.nn.DataParallel(net)
        net.to(self.device)
        print('Number of learnable parameters: %d' % numParams(net))
        print(net)

        criterion = mse_loss()
        criterion1 = stftm_loss(device=self.device)
        optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
        self.lr_list = [0.0002] * 3 + [0.0001] * 6 + [0.00005] * 3 + [0.00001] * 3
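        # hard-coded 15-epoch staged schedule: 3 epochs at 2e-4, 6 at 1e-4,
        # 3 at 5e-5, 3 at 1e-5; it overwrites args.lr at the start of each epoch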
        if self.resume_model:
            print('Resume model from "%s"' % self.resume_model)
            checkpoint = Checkpoint()
            checkpoint.load(self.resume_model)
            start_epoch = checkpoint.start_epoch
            start_iter = checkpoint.start_iter
            best_loss = checkpoint.best_loss
            net.load_state_dict(checkpoint.state_dict)
            optimizer.load_state_dict(checkpoint.optimizer)
        else:
            print('Training from scratch.')
            start_epoch = 0
            start_iter = 0
            best_loss = np.inf

        num_train_batches = self.num_train_sentences // self.batch_size
        total_train_batch = self.max_epoch * num_train_batches
        print('num_train_sentences', self.num_train_sentences)
        print('batches_per_epoch', num_train_batches)
        print('total_train_batch', total_train_batch)
        print('batch_size', self.batch_size)
        print('model_name', self.model_name)
        batch_timings = 0.
        counter = int(start_epoch * num_train_batches + start_iter)
        counter1 = 0
        print('counter', counter)
        ttime = 0.
        cnt = 0.
        iteration = 0
        print('best_loss', best_loss)
        for epoch in range(start_epoch, self.max_epoch):
            accu_train_loss = 0.0
            net.train()
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.lr_list[epoch]

            start = timeit.default_timer()
            for i, (features, labels, nframes, feat_size, label_size, get_filename) in enumerate(
                    train_loader):  # features:torch.Size([4, 1, 250, 512])
                iteration += 1
                i += start_iter
                features, labels = features.to(self.device), labels.to(self.device)  # torch.Size([4, 1, 250, 512])

                loss_mask = compLossMask(labels, nframes=nframes)

                # forward + backward + optimize
                optimizer.zero_grad()

                outputs = net(features)  # torch.Size([4, 1, 64256])

                feature_maker = Fbank(sample_rate=16000, n_fft=400, n_mels=40)
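                # Fbank appears to be SpeechBrain's filterbank front end
                # (40 log-mel bands over 400-sample FFT frames at 16 kHz);
                # constructing it once outside the batch loop would avoid
                # rebuilding it on every iteration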
                loss_fbank = 0

                for t in range(len(get_filename)):
                    with h5py.File(get_filename[t], 'r') as reader:
                        feature_asr = reader['noisy_raw'][:]
                        label_asr = reader['clean_raw'][:]

                    feat_asr_size = int(feat_size[t][0].item())
                    label_asr_size = int(label_size[t][0].item())

                    output_asr = self.train_asr_forward(feature_asr, net)
                    est_output_asr = output_asr[:feat_asr_size]
                    ideal_labels_asr = label_asr

                    # save the estimated and ideal training wavs
                    est_path = os.path.join(self.train_wav_path, '{}_est.wav'.format(t + 1))
                    ideal_path = os.path.join(self.train_wav_path, '{}_ideal.wav'.format(t + 1))
                    sf.write(est_path, normalize_wav(est_output_asr)[0], self.srate)
                    sf.write(ideal_path, normalize_wav(ideal_labels_asr)[0], self.srate)

                    # read wav
                    est_sig = sb.dataio.dataio.read_audio(est_path).unsqueeze(dim=0).to(self.device)
                    ideal_sig = sb.dataio.dataio.read_audio(ideal_path).unsqueeze(dim=0).to(self.device)
                    est_sig_feats = feature_maker(est_sig)
                    ideal_sig_feats = feature_maker(ideal_sig)

                    # fbank_loss
                    loss_fbank += F.mse_loss(est_sig_feats, ideal_sig_feats, reduction='mean')

                loss_fbank /= 100 * len(get_filename)  # average over the batch, with an extra 1/100 scaling

                outputs = outputs[:, :, :labels.shape[-1]]

                loss1 = criterion(outputs, labels, loss_mask, nframes)
                loss2 = criterion1(outputs, labels, loss_mask, nframes)

                # loss = 0.8 * loss1 + 0.2 * loss2
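                # combined objective: time-domain MSE (loss1), STFT-magnitude
                # loss (loss2), and the filterbank term, weighted 0.4/0.1/0.5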
                loss = 0.4 * loss1 + 0.1 * loss2 + 0.5 * loss_fbank

                loss.backward()
                optimizer.step()
                # calculate losses
                running_loss = loss.item()
                accu_train_loss += running_loss

                # log the accumulated training loss to the summary writer
                summary.add_scalar('Train Loss', accu_train_loss, iteration)

                cnt += 1.
                counter += 1
                counter1 += 1

                del loss, loss_fbank, loss1, loss2, outputs, loss_mask, features, labels
                end = timeit.default_timer()
                curr_time = end - start
                ttime += curr_time
                mtime = ttime / counter1
                print(
                    'iter = {}/{}, epoch = {}/{}, loss = {:.5f}, time/batch = {:.5f}, mtime/batch = {:.5f}'.format(
                        i + 1,
                        num_train_batches, epoch + 1, self.max_epoch, running_loss, curr_time, mtime))
                start = timeit.default_timer()
                if (i + 1) % self.eval_steps == 0:
                    start = timeit.default_timer()

                    avg_train_loss = accu_train_loss / cnt

                    avg_eval_loss = self.validate(net, eval_loader, iteration)

                    net.train()

                    print('Epoch [%d/%d], Iter [%d/%d]  ( TrainLoss: %.4f | EvalLoss: %.4f )' % (
                        epoch + 1, self.max_epoch, i + 1, self.num_train_sentences // self.batch_size,
                        avg_train_loss,
                        avg_eval_loss))

                    is_best = avg_eval_loss < best_loss
                    if is_best:
                        best_loss = avg_eval_loss

                    checkpoint = Checkpoint(epoch, i, avg_train_loss, avg_eval_loss, best_loss, net.state_dict(),
                                            optimizer.state_dict())

                    model_name = self.model_name + '_latest.model'
                    best_model = self.model_name + '_best.model'
                    checkpoint.save(is_best, os.path.join(self.model_path, model_name),
                                    os.path.join(self.model_path, best_model))

                    logging(self.log_path, self.model_name + '_loss_log.txt', checkpoint, self.eval_steps)
                    # metric_logging(self.log_path, self.model_name +'_metric_log.txt', epoch+1, [avg_st, avg_sn, avg_pe])
                    accu_train_loss = 0.0
                    cnt = 0.

                if (i + 1) % num_train_batches == 0:
                    break

        avg_st, avg_sn, avg_pe = self.validate_with_metrics(net, eval_loader)
        net.train()
        print('#' * 50)
        print('')
        print('After {} epochs, performance on the validation set is as follows:'.format(epoch + 1))
        print('')
        print('STOI: {:.4f}'.format(avg_st))
        print('SNR: {:.4f}'.format(avg_sn))
        print('PESQ: {:.4f}'.format(avg_pe))
        for param_group in optimizer.param_groups:
            print('learning_rate', param_group['lr'])
        print('')
        print('#' * 50)
        checkpoint = Checkpoint(epoch, 0, None, None, best_loss, net.state_dict(), optimizer.state_dict())
        checkpoint.save(False, os.path.join(self.model_path, self.model_name + '-{}.model'.format(epoch + 1)),
                        os.path.join(self.model_path, self.model_name + '_best.model'))
        metric_logging(self.log_path, self.model_name + '_metric_log.txt', epoch, [avg_st, avg_sn, avg_pe])
        start_iter = 0
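
The hard-coded lr_list above implements a 15-epoch staged learning-rate
schedule. A standalone sketch of the same idea, with a stand-in model (the
names here are illustrative, not from the source):

    import torch

    net = torch.nn.Linear(10, 10)  # stand-in model
    optimizer = torch.optim.Adam(net.parameters(), lr=2e-4)
    lr_list = [2e-4] * 3 + [1e-4] * 6 + [5e-5] * 3 + [1e-5] * 3

    for epoch in range(len(lr_list)):
        # overwrite the learning rate of every parameter group, exactly as
        # the training loop above does at the top of each epoch
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_list[epoch]
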
Example #3
    def train(self, args):
        with open(args.train_list, 'r') as train_list_file:
            self.train_list = [
                line.strip() for line in train_list_file.readlines()
            ]
        self.eval_file = args.eval_file
        self.num_train_sentences = args.num_train_sentences
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.max_epoch = args.max_epoch
        self.model_path = args.model_path
        self.log_path = args.log_path
        self.fig_path = args.fig_path
        self.eval_plot_num = args.eval_plot_num
        self.eval_steps = args.eval_steps
        self.resume_model = args.resume_model
        self.wav_path = args.wav_path
        self.tool_path = args.tool_path

        # create a training dataset and an evaluation dataset
        trainSet = TrainingDataset(self.train_list,
                                   frame_size=self.frame_size,
                                   frame_shift=self.frame_shift)
        evalSet = EvalDataset(self.eval_file, self.num_test_sentences)
        # create data loaders for training and evaluation
        train_loader = DataLoader(trainSet,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  num_workers=16,
                                  collate_fn=TrainCollate())
        eval_loader = DataLoader(evalSet,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=EvalCollate())

        # create a network
        print('model', self.model_name)
        net = Net(device=self.device, L=self.frame_size, width=self.width)
        #net = torch.nn.DataParallel(net)
        net.to(self.device)
        print('Number of learnable parameters: %d' % numParams(net))
        print(net)

        criterion = mse_loss()
        criterion1 = stftm_loss(device=self.device)
        optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
        self.lr_list = [0.0002] * 3 + [0.0001] * 6 + [0.00005] * 3 + [0.00001] * 3
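        # same 15-epoch staged schedule as in Example #2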
        if self.resume_model:
            print('Resume model from "%s"' % self.resume_model)
            checkpoint = Checkpoint()
            checkpoint.load(self.resume_model)
            start_epoch = checkpoint.start_epoch
            start_iter = checkpoint.start_iter
            best_loss = checkpoint.best_loss
            net.load_state_dict(checkpoint.state_dict)
            optimizer.load_state_dict(checkpoint.optimizer)
        else:
            print('Training from scratch.')
            start_epoch = 0
            start_iter = 0
            best_loss = np.inf

        num_train_batches = self.num_train_sentences // self.batch_size
        total_train_batch = self.max_epoch * num_train_batches
        print('num_train_sentences', self.num_train_sentences)
        print('batches_per_epoch', num_train_batches)
        print('total_train_batch', total_train_batch)
        print('batch_size', self.batch_size)
        print('model_name', self.model_name)
        batch_timings = 0.
        counter = int(start_epoch * num_train_batches + start_iter)
        counter1 = 0
        print('counter', counter)
        ttime = 0.
        cnt = 0.
        print('best_loss', best_loss)
        for epoch in range(start_epoch, self.max_epoch):
            accu_train_loss = 0.0
            net.train()
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.lr_list[epoch]

            start = timeit.default_timer()
            for i, (features, labels, nframes) in enumerate(train_loader):
                i += start_iter
                features, labels = features.to(self.device), labels.to(
                    self.device)

                loss_mask = compLossMask(labels, nframes=nframes)

                # forward + backward + optimize
                optimizer.zero_grad()

                outputs = net(features)
                outputs = outputs[:, :, :labels.shape[-1]]

                loss1 = criterion(outputs, labels, loss_mask, nframes)
                loss2 = criterion1(outputs, labels, loss_mask, nframes)

                # combined objective without the filterbank term: time-domain
                # MSE and STFT-magnitude loss, weighted 0.8/0.2
                loss = 0.8 * loss1 + 0.2 * loss2
                loss.backward()
                optimizer.step()
                # calculate losses
                running_loss = loss.item()
                accu_train_loss += running_loss

                cnt += 1.
                counter += 1
                counter1 += 1

                del loss, loss1, loss2, outputs, loss_mask, features, labels
                end = timeit.default_timer()
                curr_time = end - start
                ttime += curr_time
                mtime = ttime / counter1
                print(
                    'iter = {}/{}, epoch = {}/{}, loss = {:.5f}, time/batch = {:.5f}, mtime/batch = {:.5f}'
                    .format(i + 1, num_train_batches, epoch + 1,
                            self.max_epoch, running_loss, curr_time, mtime))
                start = timeit.default_timer()
                if (i + 1) % self.eval_steps == 0:
                    start = timeit.default_timer()

                    avg_train_loss = accu_train_loss / cnt

                    avg_eval_loss = self.validate(net, eval_loader)

                    net.train()

                    print(
                        'Epoch [%d/%d], Iter [%d/%d]  ( TrainLoss: %.4f | EvalLoss: %.4f )'
                        % (epoch + 1, self.max_epoch, i + 1,
                           self.num_train_sentences // self.batch_size,
                           avg_train_loss, avg_eval_loss))

                    is_best = avg_eval_loss < best_loss
                    if is_best:
                        best_loss = avg_eval_loss

                    checkpoint = Checkpoint(epoch, i, avg_train_loss,
                                            avg_eval_loss, best_loss,
                                            net.state_dict(),
                                            optimizer.state_dict())

                    model_name = self.model_name + '_latest.model'
                    best_model = self.model_name + '_best.model'
                    checkpoint.save(is_best,
                                    os.path.join(self.model_path, model_name),
                                    os.path.join(self.model_path, best_model))

                    logging(self.log_path, self.model_name + '_loss_log.txt',
                            checkpoint, self.eval_steps)
                    #metric_logging(self.log_path, self.model_name +'_metric_log.txt', epoch+1, [avg_st, avg_sn, avg_pe])
                    accu_train_loss = 0.0
                    cnt = 0.

                if (i + 1) % num_train_batches == 0:
                    break

            avg_st, avg_sn, avg_pe = self.validate_with_metrics(
                net, eval_loader)
            net.train()
            print('#' * 50)
            print('')
            print('After {} epochs, performance on the validation set is as follows:'.format(epoch + 1))
            print('')
            print('STOI: {:.4f}'.format(avg_st))
            print('SNR: {:.4f}'.format(avg_sn))
            print('PESQ: {:.4f}'.format(avg_pe))
            for param_group in optimizer.param_groups:
                print('learning_rate', param_group['lr'])
            print('')
            print('#' * 50)
            checkpoint = Checkpoint(epoch, 0, None, None, best_loss,
                                    net.state_dict(), optimizer.state_dict())
            checkpoint.save(
                False,
                os.path.join(self.model_path,
                             self.model_name + '-{}.model'.format(epoch + 1)),
                os.path.join(self.model_path, best_model))
            metric_logging(self.log_path, self.model_name + '_metric_log.txt',
                           epoch, [avg_st, avg_sn, avg_pe])
            start_iter = 0  # only offset batch indices in the resumed (first) epoch
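
Both train() variants (and test() in Example #1) rely on a Checkpoint object.
Below is a minimal sketch of the interface implied by those call sites; the
real class in this codebase may differ:

    import torch

    class Checkpoint:
        # attribute and method names reconstructed from usage in the examples above
        def __init__(self, start_epoch=None, start_iter=None, train_loss=None,
                     eval_loss=None, best_loss=None, state_dict=None,
                     optimizer=None):
            self.start_epoch = start_epoch
            self.start_iter = start_iter
            self.train_loss = train_loss
            self.eval_loss = eval_loss
            self.best_loss = best_loss
            self.state_dict = state_dict
            self.optimizer = optimizer

        def save(self, is_best, latest_path, best_path):
            # always write the latest checkpoint; also copy it to the
            # best-model path when the eval loss improved
            torch.save(self.__dict__, latest_path)
            if is_best:
                torch.save(self.__dict__, best_path)

        def load(self, path):
            # restore all saved attributes onto this instance
            self.__dict__.update(torch.load(path, map_location='cpu'))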