Exemplo n.º 1
0
def main():
    """Enhance a noisy test set with the trained model and save features as .mat files.

    Relies on module-level ``args`` and ``config`` (parsed elsewhere in the
    file) and on ``load_dataset``/``load_noisy_dataset`` project helpers.
    """
    # Clean test loader; feat_dim/vocab_size/tokenizer come from the data config.
    _, tt_set, feat_dim, vocab_size, tokenizer, msg = load_dataset(
        args.njobs, args.gpu, args.pin_memory,
        config['hparas']['curriculum'] > 0, **config['data'])
    verbose(msg)

    # model select
    print('Model initializing\n')
    # NOTE(review): input dim is 120 here while other examples use 257/40 —
    # presumably matches this dataset's feature dimension; confirm.
    net = torch.nn.DataParallel(
        AttentionModel(120,
                       hidden_size=args.hidden_size,
                       dropout_p=args.dropout_p,
                       use_attn=args.attn_use,
                       stacked_encoder=args.stacked_encoder,
                       attn_len=args.attn_len))
    # net = AttentionModel(257, 112, dropout_p = args.dropout_p, use_attn = args.attn_use)
    net = net.cuda()
    print(net)

    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)

    # Check point load
    print('Trying Checkpoint Load\n')
    ckpt_dir = args.ck_dir
    ckpt_path = os.path.join(ckpt_dir, args.ck_name)

    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path)
        try:
            net.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
            print('checkpoint is loaded !')
        except RuntimeError as e:
            # State-dict key/shape mismatch — continue with fresh weights.
            print('wrong checkpoint\n')
    else:
        print('checkpoint not exist!')

    # test phase
    net.eval()
    with torch.no_grad():
        sr1 = 16000  # sample rate (Hz)
        window_size = 25  # int, window size for FFT (ms)
        stride = 10  # frame shift (ms)
        ws = int(sr1 * 0.001 * window_size)  # window size in samples (unused below)
        st = int(sr1 * 0.001 * stride)  # stride in samples (unused below)
        name_sum = 0
        for input in tqdm(tt_set):
            # Build the matching noisy loader for this batch's utterance ids.
            tt_noisy_set, feat_dim, vocab_size, tokenizer = load_noisy_dataset(
                "test", input[0], args.njobs, args.gpu, args.pin_memory,
                config['hparas']['curriculum'] > 0, **config['data_noisy'])
            for input_noisy in tt_noisy_set:
                # test_clean_feat = input[1].to(device='cuda')
                test_noisy_feat = input_noisy[1].to(device='cuda')
                # feed data
                test_mixed_feat, attn_weight = net(test_noisy_feat)

                # Save each enhanced feature matrix under its utterance name.
                for i in range(len(test_mixed_feat)):
                    name = args.out_path + input_noisy[0][i] + '.mat'
                    feat = test_mixed_feat[i].to(device='cpu').numpy()
                    scio.savemat(name, {'feat': feat})
Exemplo n.º 2
0
    def _build_graph(self):
        """Build the joint ASR + dialogue-act graph and return the combined loss.

        The ASR loss comes from the parent ``AttentionModel`` graph; the DA
        branch encodes the dialogue history and classifies the dialogue act.
        The two losses are mixed with ``hparams.da_attention_lambda`` unless
        the ASR loss is exactly 0.0, in which case only the DA loss is used.
        """
        with tf.variable_scope("asr"):
            # Parent class builds the attention-based ASR graph and its loss.
            loss_asr = AttentionModel._build_graph(self)

        with tf.variable_scope("da_recog"):
            da_inputs, da_input_len = self.get_da_inputs()

            # Assemble the running dialogue history from the current DA inputs.
            history_targets, history_target_seq_len, history_seq_len = self._build_history(
                da_inputs, da_input_len, rank=1, dtype=tf.float32)

            history_inputs = self._build_word_encoder(
                history_targets,
                history_target_seq_len,
            )

            encoded_history = self._build_utt_encoder(history_inputs,
                                                      history_seq_len)

            loss_da, self.predicted_da_labels = self._get_loss(encoded_history)
            # Ensure the DA loss is computed before the previous-inputs state
            # is overwritten for the next step.
            with tf.control_dependencies([loss_da]):
                self.update_prev_inputs = self._build_update_prev_inputs(
                    da_inputs, da_input_len)

        if loss_asr == 0.0:
            loss = loss_da
        else:
            # Convex combination of ASR and DA losses.
            loss = self.hparams.da_attention_lambda * loss_asr + (
                1 - self.hparams.da_attention_lambda) * loss_da
        return loss
Exemplo n.º 3
0
def main():
    """Train the speech-enhancement AttentionModel on STFT magnitudes.

    Trains with MSE loss on magnitude spectra, validates every epoch, halves
    the learning rate when training loss stops improving, and checkpoints the
    best validation loss. Relies on module-level ``args``, ``stft``/``istft``
    helpers, ``AudioDataset`` and ``AttentionModel`` defined elsewhere in
    this file.
    """
    summary = SummaryWriter('./log')
    # os.system('tensorboard --logdir=log')
    #
    # set Hyper parameter
    # json_path = os.path.join(args.model_dir)
    # params = train_utils.Params(json_path)

    # data loader
    train_dataset = AudioDataset(data_type='train')
    # modify:num_workers=4
    train_data_loader = DataLoader(dataset=train_dataset,
                                   batch_size=args.batch_size,
                                   collate_fn=train_dataset.collate,
                                   shuffle=True,
                                   num_workers=4)
    test_dataset = AudioDataset(data_type='valid')
    test_data_loader = DataLoader(dataset=test_dataset,
                                  batch_size=args.batch_size,
                                  collate_fn=test_dataset.collate,
                                  shuffle=False,
                                  num_workers=4)

    # model select
    print('Model initializing\n')
    # 257 = STFT magnitude bins (n_fft/2 + 1 for a 512-point FFT).
    net = torch.nn.DataParallel(
        AttentionModel(257,
                       hidden_size=args.hidden_size,
                       dropout_p=args.dropout_p,
                       use_attn=args.attn_use,
                       stacked_encoder=args.stacked_encoder,
                       attn_len=args.attn_len))
    net = net.cuda()
    print(net)

    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)

    # Halve the LR each time scheduler.step() is called (on plateau, below).
    scheduler = ExponentialLR(optimizer, 0.5)

    # Check point load
    print('Trying Checkpoint Load\n')
    ckpt_dir = 'ckpt_dir'
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    best_PESQ = 0.
    best_STOI = 0.
    best_loss = 200000.  # sentinel "worst" loss so the first epoch checkpoints
    ckpt_path = os.path.join(ckpt_dir, args.ck_name)
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path)
        try:
            net.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
            best_loss = ckpt['best_loss']

            print('checkpoint is loaded !')
            print('current best loss : %.4f' % best_loss)
        except RuntimeError as e:
            # State-dict key/shape mismatch — continue with fresh weights.
            print('wrong checkpoint\n')
    else:
        print('checkpoint not exist!')
        print('current best loss : %.4f' % best_loss)

    print('Training Start!')
    # train
    iteration = 0
    train_losses = []
    test_losses = []
    for epoch in range(args.num_epochs):
        train_bar = tqdm(train_data_loader)
        n = 0
        avg_loss = 0
        avg_pesq = 0  # kept for the progress printout; never accumulated here
        avg_stoi = 0
        net.train()
        for input in train_bar:
            iteration += 1
            # load data
            train_mixed, train_clean, seq_len = map(lambda x: x.cuda(), input)

            # STFT -> (batch, freq, time, 2) -> magnitude/phase.
            mixed = stft(train_mixed)
            cleaned = stft(train_clean)
            mixed = mixed.transpose(1, 2)
            cleaned = cleaned.transpose(1, 2)
            real, imag = mixed[..., 0], mixed[..., 1]
            clean_real, clean_imag = cleaned[..., 0], cleaned[..., 1]
            mag = torch.sqrt(real**2 + imag**2)
            clean_mag = torch.sqrt(clean_real**2 + clean_imag**2)
            phase = torch.atan2(imag, real)

            # feed data: enhance the magnitude, reuse the noisy phase.
            out_mag, attn_weight = net(mag)
            out_real = out_mag * torch.cos(phase)
            out_imag = out_mag * torch.sin(phase)
            out_real, out_imag = torch.squeeze(out_real,
                                               1), torch.squeeze(out_imag, 1)
            out_real = out_real.transpose(1, 2)
            out_imag = out_imag.transpose(1, 2)

            out_audio = istft(out_real, out_imag, train_mixed.size(1))
            out_audio = torch.squeeze(out_audio, dim=1)
            # Zero out samples beyond each utterance's true length.
            for i, l in enumerate(seq_len):
                out_audio[i, l:] = 0

            loss = 0
            PESQ = 0
            STOI = 0
            origin_PESQ = 0
            origin_STOI = 0

            # NOTE(review): third positional arg is the deprecated
            # size_average=True (mean reduction, the default anyway).
            loss = F.mse_loss(out_mag, clean_mag, True)
            if torch.any(torch.isnan(loss)):
                # Dump the offending tensors for post-mortem debugging.
                torch.save(
                    {
                        'clean_mag': clean_mag,
                        'out_mag': out_mag,
                        'mag': mag
                    }, 'nan_mag')
                # FIX: the original `raise ('loss is NaN')` raised a plain
                # string, which is a TypeError in Python 3, masking the real
                # error. Raise a proper exception instead.
                raise ValueError('loss is NaN')
            avg_loss += loss.item()
            n += 1
            # gradient optimizer
            optimizer.zero_grad()

            loss.backward()

            # update weight
            optimizer.step()

        avg_loss /= n
        avg_pesq /= n
        avg_stoi /= n
        print('result:')
        print(
            '[epoch: {}, iteration: {}] avg_loss : {:.4f} avg_pesq : {:.4f} avg_stoi : {:.4f} '
            .format(epoch, iteration, avg_loss, avg_pesq, avg_stoi))

        summary.add_scalar('Train Loss', avg_loss, iteration)

        train_losses.append(avg_loss)
        # Decay LR when this epoch's loss is worse than two epochs ago.
        if (len(train_losses) > 2) and (train_losses[-2] < avg_loss):
            print("Learning rate Decay")
            scheduler.step()

        # test phase
        n = 0
        avg_test_loss = 0
        avg_test_pesq = 0
        avg_test_stoi = 0
        test_bar = tqdm(test_data_loader)

        net.eval()
        with torch.no_grad():
            for input in test_bar:
                test_mixed, test_clean, seq_len = map(lambda x: x.cuda(),
                                                      input)
                mixed = stft(test_mixed)
                cleaned = stft(test_clean)
                mixed = mixed.transpose(1, 2)
                cleaned = cleaned.transpose(1, 2)
                real, imag = mixed[..., 0], mixed[..., 1]
                clean_real, clean_imag = cleaned[..., 0], cleaned[..., 1]
                mag = torch.sqrt(real**2 + imag**2)
                clean_mag = torch.sqrt(clean_real**2 + clean_imag**2)
                phase = torch.atan2(imag, real)

                logits_mag, logits_attn_weight = net(mag)
                logits_real = logits_mag * torch.cos(phase)
                logits_imag = logits_mag * torch.sin(phase)
                logits_real, logits_imag = torch.squeeze(logits_real,
                                                         1), torch.squeeze(
                                                             logits_imag, 1)
                logits_real = logits_real.transpose(1, 2)
                logits_imag = logits_imag.transpose(1, 2)

                logits_audio = istft(logits_real, logits_imag,
                                     test_mixed.size(1))
                logits_audio = torch.squeeze(logits_audio, dim=1)
                for i, l in enumerate(seq_len):
                    logits_audio[i, l:] = 0

                # NOTE(review): test_PESQ/test_STOI are never updated in this
                # loop, so the "best" values saved below are always 0.
                test_PESQ = 0
                test_STOI = 0

                test_loss = F.mse_loss(logits_mag, clean_mag, True)

                avg_test_loss += test_loss.item()
                n += 1

            avg_test_loss /= n
            avg_test_pesq /= n
            avg_test_stoi /= n

            test_losses.append(avg_test_loss)
            summary.add_scalar('Test Loss', avg_test_loss, iteration)

            print(
                '[epoch: {}, iteration: {}] test loss : {:.4f} avg_test_pesq : {:.4f} avg_test_stoi : {:.4f}'
                .format(epoch, iteration, avg_test_loss, avg_test_pesq,
                        avg_test_stoi))
            if avg_test_loss < best_loss:
                best_PESQ = test_PESQ
                best_STOI = test_STOI
                best_loss = avg_test_loss
                # Note: optimizer also has states ! don't forget to save them as well.
                ckpt = {
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_loss': best_loss
                }
                torch.save(ckpt, ckpt_path)
                print('checkpoint is saved !')
Exemplo n.º 4
0
 def __init__(self):
     """Initialize the parent AttentionModel with alignment history forced on."""
     # force_alignment_history=True keeps per-step attention alignments,
     # presumably for later visualization — confirm against AttentionModel.
     AttentionModel.__init__(self, force_alignment_history=True)
Exemplo n.º 5
0
def main():
    """Evaluate the enhancement model on the test set and report loss/PESQ/STOI.

    Relies on module-level ``args``, ``stft``/``istft`` helpers,
    ``AudioDataset``/``AttentionModel`` and the ``pesq``/``stoi`` metrics
    imported elsewhere in this file.
    """
    test_dataset = AudioDataset(data_type=args.test_set)
    test_data_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, collate_fn=test_dataset.collate, shuffle=False, num_workers=4)

    # model select
    print('Model initializing\n')
    net = torch.nn.DataParallel(AttentionModel(257, hidden_size = args.hidden_size, dropout_p = args.dropout_p, use_attn = args.attn_use, stacked_encoder = args.stacked_encoder, attn_len = args.attn_len))
    net = net.cuda()
    print(net)

    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)

    # Check point load
    print('Trying Checkpoint Load\n')
    ckpt_dir = args.ck_dir
    best_PESQ = 0.
    best_STOI = 0.
    # FIX: best_loss was printed below but never defined (NameError at
    # runtime); initialize it with the same sentinel the training script uses.
    best_loss = 200000.
    ckpt_path = os.path.join(ckpt_dir, args.ck_name)

    if os.path.exists(ckpt_path):
        # FIX: this branch originally mixed tabs and spaces, which is a
        # TabError under Python 3; indentation normalized to spaces.
        ckpt = torch.load(ckpt_path)
        try:
            net.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
            best_STOI = ckpt['best_STOI']

            print('checkpoint is loaded !')
            print('current best loss : %.4f' % best_loss)
        except RuntimeError as e:
            print('wrong checkpoint\n')
    else:
        print('checkpoint not exist!')
        print('current best loss : %.4f' % best_loss)

    # test phase
    n = 0
    avg_test_loss = 0

    net.eval()
    with torch.no_grad():
        test_bar = tqdm(test_data_loader)
        for input in test_bar:
            test_mixed, test_clean, seq_len = map(lambda x: x.cuda(), input)
            # STFT -> magnitude/phase for both noisy and clean signals.
            mixed = stft(test_mixed)
            cleaned = stft(test_clean)
            mixed = mixed.transpose(1,2)
            cleaned = cleaned.transpose(1,2)
            real, imag = mixed[..., 0], mixed[..., 1]
            clean_real, clean_imag = cleaned[..., 0], cleaned[..., 1]
            mag = torch.sqrt(real**2 + imag**2)
            clean_mag = torch.sqrt(clean_real**2 + clean_imag**2)
            phase = torch.atan2(imag, real)

            # Enhance the magnitude, then rebuild audio with the noisy phase.
            logits_mag, logits_attn_weight = net(mag)
            logits_real = logits_mag * torch.cos(phase)
            logits_imag = logits_mag * torch.sin(phase)
            logits_real, logits_imag = torch.squeeze(logits_real, 1), torch.squeeze(logits_imag, 1)
            logits_real = logits_real.transpose(1,2)
            logits_imag = logits_imag.transpose(1,2)

            logits_audio = istft(logits_real, logits_imag, test_mixed.size(1))
            logits_audio = torch.squeeze(logits_audio, dim=1)
            # Zero out padding beyond each utterance's true length.
            for i, l in enumerate(seq_len):
                logits_audio[i, l:] = 0

            test_loss = 0
            test_PESQ = 0
            test_STOI = 0

            # NOTE(review): third positional arg is the deprecated
            # size_average=True (mean reduction, the default anyway).
            test_loss = F.mse_loss(logits_mag, clean_mag, True)

            for i in range(len(test_mixed)):
                # Overwrites the same file for each utterance (debug output).
                librosa.output.write_wav('test_out.wav', logits_audio[i].cpu().data.numpy()[:seq_len[i].cpu().data.numpy()], 16000)
                cur_PESQ = pesq(test_clean[i].detach().cpu().numpy(), logits_audio[i].detach().cpu().numpy(), 16000)
                cur_STOI = stoi(test_clean[i].detach().cpu().numpy(), logits_audio[i].detach().cpu().numpy(), 16000, extended=False)

                test_PESQ += cur_PESQ
                test_STOI += cur_STOI

            test_PESQ /= len(test_mixed)
            test_STOI /= len(test_mixed)
            avg_test_loss += test_loss
            n += 1

        avg_test_loss /= n
        # NOTE(review): PESQ/STOI printed here are the last batch's averages,
        # not averages over the whole set.
        print('test loss : {:.4f} PESQ : {:.4f} STOI : {:.4f}'.format(avg_test_loss, test_PESQ, test_STOI))
Exemplo n.º 6
0
def main():
    """Evaluate enhancement with STFT + fbank + MFCC MSE losses and save wavs.

    Writes each enhanced utterance to ``args.out_path`` and reports PESQ/STOI
    against both the enhanced audio and the original noisy mixture. Relies on
    module-level ``args``, ``stft``/``istft`` helpers and project classes.
    """
    test_dataset = AudioDataset(data_type=args.test_set)
    test_data_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, collate_fn=test_dataset.collate,
                                  shuffle=False, num_workers=multiprocessing.cpu_count())

    # model select
    print('Model initializing\n')
    net = torch.nn.DataParallel(
        AttentionModel(257, hidden_size=args.hidden_size, dropout_p=args.dropout_p, use_attn=args.attn_use,
                       stacked_encoder=args.stacked_encoder, attn_len=args.attn_len))
    # net = AttentionModel(257, 112, dropout_p = args.dropout_p, use_attn = args.attn_use)
    net = net.cuda()
    print(net)

    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)

    # Check point load

    print('Trying Checkpoint Load\n')
    ckpt_dir = args.ck_dir
    best_PESQ = 0.
    best_STOI = 0.
    ckpt_path = os.path.join(ckpt_dir, args.ck_name)

    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path)
        try:
            net.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
            # best_STOI = ckpt['best_STOI']

            print('checkpoint is loaded !')
            # print('current best loss : %.4f' % best_loss)
        except RuntimeError as e:
            # State-dict key/shape mismatch — continue with fresh weights.
            print('wrong checkpoint\n')
    else:
        print('checkpoint not exist!')
        # print('current best loss : %.4f' % best_loss)

    # test phase
    n = 0
    name_sum = 0  # running index into test_dataset.file_names for output names
    avg_test_loss = 0

    net.eval()
    with torch.no_grad():
        test_bar = tqdm(test_data_loader)
        for input in test_bar:
            test_mixed, test_clean, seq_len = map(lambda x: x.cuda(), input)
            # STFT -> magnitude/phase for both noisy and clean signals.
            mixed = stft(test_mixed)
            cleaned = stft(test_clean)
            mixed = mixed.transpose(1, 2)
            cleaned = cleaned.transpose(1, 2)
            real, imag = mixed[..., 0], mixed[..., 1]
            clean_real, clean_imag = cleaned[..., 0], cleaned[..., 1]
            mag = torch.sqrt(real ** 2 + imag ** 2)
            clean_mag = torch.sqrt(clean_real ** 2 + clean_imag ** 2)
            phase = torch.atan2(imag, real)

            # Enhance the magnitude, then rebuild audio with the noisy phase.
            logits_mag, logits_attn_weight = net(mag)
            logits_real = logits_mag * torch.cos(phase)
            logits_imag = logits_mag * torch.sin(phase)
            logits_real, logits_imag = torch.squeeze(logits_real, 1), torch.squeeze(logits_imag, 1)
            logits_real = logits_real.transpose(1, 2)
            logits_imag = logits_imag.transpose(1, 2)

            logits_audio = istft(logits_real, logits_imag, test_mixed.size(1))
            logits_audio = torch.squeeze(logits_audio, dim=1)
            # Zero out padding beyond each utterance's true length.
            for i, l in enumerate(seq_len):
                logits_audio[i, l:] = 0

            test_loss = 0
            test_PESQ = 0
            test_STOI = 0
            origin_test_PESQ = 0
            origin_test_STOI = 0

            # NOTE(review): third positional arg is the deprecated
            # size_average=True (mean reduction, the default anyway).
            test_loss1 = F.mse_loss(logits_mag, clean_mag, True)

            # other_loss:
            wav_test_mix = logits_audio
            wav_test_clean = test_clean
            # fbank:
            # NOTE(review): kaldi.fbank expects a (channel, time) waveform —
            # passing the whole batch treats utterances as channels; verify
            # this is intended for batch_size > 1.
            fbank_wav_test_mix = torchaudio.compliance.kaldi.fbank(wav_test_mix)
            fbank_wav_test_clean = torchaudio.compliance.kaldi.fbank(wav_test_clean)
            test_loss2 = F.mse_loss(fbank_wav_test_mix, fbank_wav_test_clean, True)
            # mfcc:
            mfcc_wav_test_mix = torchaudio.compliance.kaldi.mfcc(wav_test_mix)
            mfcc_wav_test_clean = torchaudio.compliance.kaldi.mfcc(wav_test_clean)
            test_loss3 = F.mse_loss(mfcc_wav_test_mix, mfcc_wav_test_clean, True)

            # Combined spectral + fbank + MFCC loss (unweighted sum).
            test_loss = test_loss1 + test_loss2 + test_loss3

            for i in range(len(test_mixed)):
                # Recover the original file name from the dataset's path list.
                name = str(test_dataset.file_names[name_sum]).split('/', 7)[7]
                librosa.output.write_wav(args.out_path + name,
                                         logits_audio[i].cpu().data.numpy()[:seq_len[i].cpu().data.numpy()],
                                         16000)
                cur_PESQ = pesq(test_clean[i].detach().cpu().numpy(), logits_audio[i].detach().cpu().numpy(), 16000)
                cur_STOI = stoi(test_clean[i].detach().cpu().numpy(), logits_audio[i].detach().cpu().numpy(), 16000,
                                extended=False)
                # Baseline metrics on the unprocessed noisy mixture.
                origin_cur_PESQ = pesq(test_clean[i].detach().cpu().numpy(), test_mixed[i].detach().cpu().numpy(),
                                       16000)
                origin_cur_STOI = stoi(test_clean[i].detach().cpu().numpy(), test_mixed[i].detach().cpu().numpy(),
                                       16000,
                                       extended=False)

                test_PESQ += cur_PESQ
                test_STOI += cur_STOI
                origin_test_PESQ += origin_cur_PESQ
                origin_test_STOI += origin_cur_STOI

                name_sum += 1

            test_PESQ /= len(test_mixed)
            test_STOI /= len(test_mixed)
            origin_test_PESQ /= len(test_mixed)
            origin_test_STOI /= len(test_mixed)
            avg_test_loss += test_loss
            n += 1

        avg_test_loss /= n
        # NOTE(review): metrics printed here are the last batch's averages,
        # not averages over the whole test set.
        print(
            'test loss : {:.4f} origin_test_PESQ : {:.4f} origin_test_STOI : {:.4f} PESQ : {:.4f} STOI : {:.4f}'.format(
                avg_test_loss, origin_test_PESQ, origin_test_STOI,
                test_PESQ, test_STOI))
Exemplo n.º 7
0
def main():
    """Evaluate the model in the fbank feature domain and report MSE loss.

    Computes 40-bin log-mel fbank features for noisy and clean audio, runs the
    enhancement network on the noisy features, and accumulates MSE against the
    clean features. Relies on module-level ``args`` and the project classes
    imported elsewhere in this file.
    """
    test_dataset = AudioDataset(data_type=args.test_set)
    test_data_loader = DataLoader(dataset=test_dataset,
                                  batch_size=args.batch_size,
                                  collate_fn=test_dataset.collate,
                                  shuffle=False,
                                  num_workers=0)

    # model select
    print('Model initializing\n')
    net = torch.nn.DataParallel(
        AttentionModel(257,
                       hidden_size=args.hidden_size,
                       dropout_p=args.dropout_p,
                       use_attn=args.attn_use,
                       stacked_encoder=args.stacked_encoder,
                       attn_len=args.attn_len))
    net = net.cuda()
    print(net)

    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)

    # Check point load

    print('Trying Checkpoint Load\n')
    ckpt_dir = args.ck_dir
    best_PESQ = 0.
    best_STOI = 0.
    ckpt_path = os.path.join(ckpt_dir, args.ck_name)

    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path)
        try:
            net.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])

            print('checkpoint is loaded !')
        except RuntimeError as e:
            # State-dict key/shape mismatch — continue with fresh weights.
            print('wrong checkpoint\n')
    else:
        print('checkpoint not exist!')

    # test phase
    n = 0
    avg_test_loss = 0

    net.eval()
    with torch.no_grad():
        # Kaldi-style framing: 25 ms windows with a 10 ms shift.
        audio_config = {
            "frame_length": 25,
            "frame_shift": 10,
        }
        test_bar = tqdm(test_data_loader)
        for input in test_bar:
            test_mixed, test_clean, seq_len = map(lambda x: x.cuda(), input)
            lt_test_mixed_vstack = []
            lt_test_clean_vstack = []
            # Compute per-utterance fbank features (one utterance at a time,
            # since kaldi.fbank works on a single waveform).
            for i in range(len(test_mixed)):
                test_mixed_vstack = torchaudio.compliance.kaldi.fbank(
                    test_mixed[i].unsqueeze(0),
                    num_mel_bins=40,
                    channel=-1,
                    sample_frequency=16000,
                    **audio_config)
                test_clean_vstack = torchaudio.compliance.kaldi.fbank(
                    test_clean[i].unsqueeze(0),
                    num_mel_bins=40,
                    channel=-1,
                    sample_frequency=16000,
                    **audio_config)

                # (time, mel) -> (1, mel, time), detached for stacking.
                test_mixed_vstack_data = test_mixed_vstack.transpose(
                    0, 1).unsqueeze(0).detach()
                test_clean_vstack_data = test_clean_vstack.transpose(
                    0, 1).unsqueeze(0).detach()

                lt_test_mixed_vstack.append(test_mixed_vstack_data)
                lt_test_clean_vstack.append(test_clean_vstack_data)

            # Stack back into a batch: (batch, time, mel).
            lt_test_mixed_feat = torch.cat(lt_test_mixed_vstack,
                                           dim=0).transpose(1, 2)
            lt_test_clean_feat = torch.cat(lt_test_clean_vstack,
                                           dim=0).transpose(1, 2)

            out_lt_test_mixed_feat, attn_weight = net(lt_test_mixed_feat)

            # NOTE(review): third positional arg is the deprecated
            # size_average=True (mean reduction, the default anyway).
            test_loss = F.mse_loss(out_lt_test_mixed_feat, lt_test_clean_feat,
                                   True)

            # FIX: the original wrote 'test_out.wav' from `logits_audio`, a
            # variable that is never defined in this function (copy-paste from
            # the waveform-domain script), raising NameError on the first
            # batch. No waveform is reconstructed here, so the write-out loop
            # was removed.

            avg_test_loss += test_loss
            n += 1

        avg_test_loss /= n
        print('test loss : {:.4f} '.format(avg_test_loss, ))
Exemplo n.º 8
0
def main():
    """Enhance one noisy wav, save it with visualizations, and report metrics.

    Loads ``args.noisy_wav``/``args.clean_wav``, runs the model on the noisy
    STFT magnitude, writes ``./out.wav``, and saves attention/spectrogram
    plots. Relies on module-level ``args``, ``stft``/``istft`` and ``utils``.
    """
    #model select
    print('Model initializing\n')
    net = torch.nn.DataParallel(AttentionModel(257, hidden_size = args.hidden_size, dropout_p = args.dropout_p, use_attn = args.attn_use, stacked_encoder = args.stacked_encoder, attn_len = args.attn_len))

    #Check point load
    print('Trying Checkpoint Load\n')
    best_PESQ = 0.
    best_STOI = 0.
    ckpt_path = args.ckpt_path

    if os.path.exists(ckpt_path):
        # FIX: this branch originally mixed tabs and spaces, which is a
        # TabError under Python 3; indentation normalized to spaces.
        ckpt = torch.load(ckpt_path)
        try:
            net.load_state_dict(ckpt['model'])
            net = net.module # uncover DataParallel
            best_STOI = ckpt['best_STOI']

            print('checkpoint is loaded !')
            # NOTE(review): the label says "loss" but the value is best_STOI —
            # kept as-is to preserve output; looks like a copy-paste slip.
            print('current best loss : %.4f' % best_STOI)
        except RuntimeError as e:
            print('wrong checkpoint\n')
    else:
        print('checkpoint not exist!')
        print('current best loss : %.4f' % best_STOI)

    #test phase
    net.eval()
    with torch.no_grad():
        # Load both wavs at their native sample rate as float32.
        inputData, sr = librosa.load(args.noisy_wav, sr=None)
        outputData, sr = librosa.load(args.clean_wav, sr=None)
        inputData = np.float32(inputData)
        outputData = np.float32(outputData)
        mixed_audio = torch.from_numpy(inputData).type(torch.FloatTensor)
        clean_audio = torch.from_numpy(outputData).type(torch.FloatTensor)

        # STFT -> (1, freq, time, 2) -> magnitude/phase.
        mixed = stft(mixed_audio)
        mixed = mixed.unsqueeze(0)
        mixed = mixed.transpose(1,2)
        cleaned = stft(clean_audio)
        cleaned = cleaned.unsqueeze(0)
        cleaned = cleaned.transpose(1,2)
        real, imag = mixed[..., 0], mixed[..., 1]
        clean_real, clean_imag = cleaned[..., 0], cleaned[..., 1]
        mag = torch.sqrt(real**2 + imag**2)
        clean_mag = torch.sqrt(clean_real**2 + clean_imag**2)
        phase = torch.atan2(imag, real)

        # Enhance the magnitude, then rebuild the waveform with noisy phase.
        logits_mag, logits_attn_weight = net(mag)
        logits_real = logits_mag * torch.cos(phase)
        logits_imag = logits_mag * torch.sin(phase)
        logits_real, logits_imag = torch.squeeze(logits_real, 1), torch.squeeze(logits_imag, 1)
        logits_real = logits_real.transpose(1,2)
        logits_imag = logits_imag.transpose(1,2)

        logits_audio = istft(logits_real, logits_imag, inputData.shape[0])
        logits_audio = torch.squeeze(logits_audio, dim=1)

        print(logits_audio[0])
        librosa.output.write_wav('./out.wav', logits_audio[0].cpu().data.numpy(), 16000)
        test_loss = F.mse_loss(logits_mag, clean_mag, True)
        test_PESQ = pesq(outputData, logits_audio[0].detach().cpu().numpy(), 16000)
        test_STOI = stoi(outputData, logits_audio[0].detach().cpu().numpy(), 16000, extended=False)

        print("Saved attention weight visualization to attention_viz.png")
        utils.plot_head_map(logits_attn_weight[0])

        # FIXME - Issue with pcm_f32le. Require pcm_s16le
        print("Saved clean spectrogram visualization to spec_clean.png")
        clean_spect = utils.make_spectrogram_array(args.clean_wav)
        utils.save_spectrogram(clean_spect, 'clean')

        print("Saved noisy spectrogram visualization to spec_noisy.png")
        noisy_spect = utils.make_spectrogram_array(args.noisy_wav)
        utils.save_spectrogram(noisy_spect, 'noisy')

        print("Saved enhanced spectrogram visualization to spec_enhanced.png")
        enhanced_spect = utils.make_spectrogram_array('./out.wav')
        utils.save_spectrogram(enhanced_spect, 'enhanced')

        #test accuracy
        print('test loss : {:.4f} PESQ : {:.4f} STOI : {:.4f}'.format(test_loss, test_PESQ, test_STOI))
Exemplo n.º 9
0
def main():
    """Train the enhancement model on paired clean/noisy feature datasets.

    For each clean batch, loads the matching noisy batch, minimizes MSE
    between enhanced and clean features, validates each epoch, decays the LR
    on plateau, and checkpoints the best validation loss. Relies on
    module-level ``args``/``config`` and project helpers (``load_dataset``,
    ``load_noisy_dataset``, ``fetch_data``).
    """
    summary = SummaryWriter('./log')
    tr_set, dv_set, feat_dim, msg = load_dataset(args.njobs, args.gpu, args.pin_memory,
                                                 config['hparas']['curriculum'] > 0,
                                                 **config['data'])

    verbose(msg)
    # model select
    print('Model initializing\n')
    # 120 = feature dimension of this dataset's inputs.
    net = torch.nn.DataParallel(
        AttentionModel(120, hidden_size=args.hidden_size, dropout_p=args.dropout_p, use_attn=args.attn_use,
                       stacked_encoder=args.stacked_encoder, attn_len=args.attn_len))
    net = net.cuda()
    print(net)

    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)

    # Halve the LR each time scheduler.step() is called (on plateau, below).
    scheduler = ExponentialLR(optimizer, 0.5)

    # Check point load
    print('Trying Checkpoint Load\n')
    ckpt_dir = 'ckpt_dir'
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    best_loss = 200000.  # sentinel "worst" loss so the first epoch checkpoints
    ckpt_path = os.path.join(ckpt_dir, args.ck_name)
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path)
        try:
            net.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
            best_loss = ckpt['best_loss']

            print('checkpoint is loaded !')
            print('current best loss : %.4f' % best_loss)
        except RuntimeError as e:
            # State-dict key/shape mismatch — continue with fresh weights.
            print('wrong checkpoint\n')
    else:
        print('checkpoint not exist!')
        print('current best loss : %.4f' % best_loss)

    print('Training Start!')
    # train
    iteration = 0
    train_losses = []
    test_losses = []
    for epoch in range(args.num_epochs):
        n = 0
        avg_loss = 0
        net.train()
        for input in tqdm(tr_set):
            # Build the noisy loader matching this clean batch's utterances.
            tr_noisy_set, feat_dim = load_noisy_dataset("train", input[0], args.njobs,
                                                        args.gpu,
                                                        args.pin_memory,
                                                        config['hparas']['curriculum'] > 0,
                                                        **config['data_noisy'])
            for input_noisy in tr_noisy_set:
                train_clean_feat, feat_len = fetch_data(input)
                train_noisy_feat, feat_len = fetch_data(input_noisy)

                iteration += 1

                # feed data
                train_mixed_feat, attn_weight = net(train_noisy_feat)
                # Skip pairs whose padded shapes don't line up.
                if train_mixed_feat.shape == train_clean_feat.shape:
                    loss = F.mse_loss(train_mixed_feat, train_clean_feat, True)

                    if torch.any(torch.isnan(loss)):
                        # Dump the offending tensors for post-mortem debugging.
                        torch.save(
                            {'clean_mag': train_clean_feat, 'noisy_mag': train_noisy_feat, 'out_mag': train_mixed_feat},
                            'nan_mag')
                        # FIX: the original `raise ('loss is NaN')` raised a
                        # plain string, which is a TypeError in Python 3,
                        # masking the real error. Raise a proper exception.
                        raise ValueError('loss is NaN')
                    avg_loss += loss.item()

                    n += 1
                    # gradient optimizer
                    optimizer.zero_grad()

                    loss.backward()

                    # update weight
                    optimizer.step()

        avg_loss /= n
        print('result:')
        print('[epoch: {}, iteration: {}] avg_loss : {:.4f}'.format(epoch, iteration, avg_loss))

        summary.add_scalar('Train Loss', avg_loss, iteration)

        train_losses.append(avg_loss)
        # Decay LR when this epoch's loss is worse than two epochs ago.
        if (len(train_losses) > 2) and (train_losses[-2] < avg_loss):
            print("Learning rate Decay")
            scheduler.step()

        # test phase
        n = 0
        avg_test_loss = 0
        net.eval()
        with torch.no_grad():
            for input in tqdm(dv_set):
                dv_noisy_set, feat_dim = load_noisy_dataset("dev", input[0], args.njobs,
                                                            args.gpu,
                                                            args.pin_memory,
                                                            config['hparas']['curriculum'] > 0,
                                                            **config['data_noisy'])
                for input_noisy in dv_noisy_set:
                    test_clean_feat = input[1].to(device='cuda')
                    test_noisy_feat = input_noisy[1].to(device='cuda')

                    test_mixed_feat, logits_attn_weight = net(test_noisy_feat)
                    if test_mixed_feat.shape == test_clean_feat.shape:
                        test_loss = F.mse_loss(test_mixed_feat, test_clean_feat, True)

                        avg_test_loss += test_loss.item()
                        n += 1

            avg_test_loss /= n

            test_losses.append(avg_test_loss)
            summary.add_scalar('Test Loss', avg_test_loss, iteration)

            print('[epoch: {}, iteration: {}] test loss : {:.4f} '.format(epoch, iteration, avg_test_loss))
            if avg_test_loss < best_loss:
                best_loss = avg_test_loss
                # Note: optimizer also has states ! don't forget to save them as well.
                ckpt = {'model': net.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_loss': best_loss}
                torch.save(ckpt, ckpt_path)
                print('checkpoint is saved !')
Exemplo n.º 10
0
def main():
    """Train the attention-based speech-enhancement model; checkpoint the best epoch.

    Builds train/valid data loaders, restores an existing checkpoint if one is
    found, then alternates a training pass and a validation pass per epoch.
    The learning rate is halved whenever the epoch training loss stops
    improving, and the checkpoint (model + optimizer + best loss) is rewritten
    each time the validation loss reaches a new minimum.

    Relies on module-level globals: ``args`` (CLI hyper-parameters),
    ``num_fbank`` (filterbank feature dimension), and the project classes
    ``AudioDataset`` / ``AttentionModel``.
    """
    summary = SummaryWriter('./log')

    # Data loaders: shuffled training set, fixed-order validation set.
    train_dataset = AudioDataset(data_type='train')
    train_data_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size,
                                   collate_fn=train_dataset.collate,
                                   shuffle=True, num_workers=4)
    test_dataset = AudioDataset(data_type='valid')
    test_data_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size,
                                  collate_fn=test_dataset.collate,
                                  shuffle=False, num_workers=4)

    # Model: 40-dim fbank input, wrapped for multi-GPU data parallelism.
    print('Model initializing\n')
    net = torch.nn.DataParallel(
        AttentionModel(40, hidden_size=args.hidden_size, dropout_p=args.dropout_p,
                       use_attn=args.attn_use,
                       stacked_encoder=args.stacked_encoder, attn_len=args.attn_len))
    net = net.cuda()
    print(net)

    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)
    # Halves the LR on every scheduler.step(); stepped only on loss plateau below.
    scheduler = ExponentialLR(optimizer, 0.5)

    # Checkpoint restore: resume model/optimizer state and the best loss so far.
    print('Trying Checkpoint Load\n')
    ckpt_dir = 'ckpt_dir'
    os.makedirs(ckpt_dir, exist_ok=True)

    best_loss = 200000.  # sentinel: any real validation loss will beat this
    ckpt_path = os.path.join(ckpt_dir, args.ck_name)
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path)
        try:
            net.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
            best_loss = ckpt['best_loss']

            print('checkpoint is loaded !')
            print('current best loss : %.4f' % best_loss)
        except RuntimeError:
            # State dict does not match the current architecture; train from scratch.
            print('wrong checkpoint\n')
    else:
        print('checkpoint not exist!')
        print('current best loss : %.4f' % best_loss)

    print('Training Start!')
    iteration = 0
    train_losses = []
    test_losses = []
    for epoch in range(args.num_epochs):
        # ---- training pass ----
        train_bar = tqdm(train_data_loader)
        n = 0
        avg_loss = 0
        net.train()
        for batch in train_bar:  # renamed from `input`: avoid shadowing the builtin
            iteration += 1
            train_mixed, train_clean, seq_len = map(lambda x: x.cuda(), batch)

            # The collate fn returns flattened features; restore the
            # (batch, frames, num_fbank) layout the model expects.
            train_mixed_feat = train_mixed.reshape(
                len(train_mixed), int(len(train_mixed[0]) / num_fbank),
                num_fbank).to(device='cuda')
            train_clean_feat = train_clean.reshape(
                len(train_clean), int(len(train_clean[0]) / num_fbank),
                num_fbank).to(device='cuda')

            # Forward pass: enhanced features plus attention weights (unused here).
            out_train_mixed_feat, attn_weight = net(train_mixed_feat)

            # Mean-squared error; the original passed the deprecated
            # `size_average=True` positionally, which is the default anyway.
            loss = F.mse_loss(out_train_mixed_feat, train_clean_feat)
            if torch.any(torch.isnan(loss)):
                # Dump the offending tensors for post-mortem inspection, then abort.
                torch.save(
                    {'clean_mag': train_clean_feat, 'out_mag': train_mixed_feat,
                     'mag': out_train_mixed_feat},
                    'nan_mag')
                # BUG FIX: `raise ('loss is NaN')` raised a bare string, which is
                # itself a TypeError; raise a real exception instead.
                raise RuntimeError('loss is NaN')
            avg_loss += loss.item()
            n += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        avg_loss /= n
        print('result:')
        print('[epoch: {}, iteration: {}] avg_loss : {:.4f}'.format(epoch, iteration, avg_loss))
        summary.add_scalar('Train Loss', avg_loss, iteration)

        train_losses.append(avg_loss)
        # Decay the LR when this epoch's loss exceeds the previous epoch's
        # (train_losses[-1] is the value just appended, so compare against [-2]).
        if (len(train_losses) > 2) and (train_losses[-2] < avg_loss):
            print("Learning rate Decay")
            scheduler.step()

        # ---- validation pass ----
        n = 0
        avg_test_loss = 0
        test_bar = tqdm(test_data_loader)

        net.eval()
        with torch.no_grad():
            for batch in test_bar:
                test_mixed, test_clean, seq_len = map(lambda x: x.cuda(), batch)
                test_mixed_feat = test_mixed.reshape(
                    len(test_mixed), int(len(test_mixed[0]) / num_fbank),
                    num_fbank).to(device='cuda')
                test_clean_feat = test_clean.reshape(
                    len(test_clean), int(len(test_clean[0]) / num_fbank),
                    num_fbank).to(device='cuda')

                logits_test_mixed_feat, logits_attn_weight = net(test_mixed_feat)

                test_loss = F.mse_loss(logits_test_mixed_feat, test_clean_feat)

                avg_test_loss += test_loss.item()
                n += 1

            avg_test_loss /= n

            test_losses.append(avg_test_loss)
            summary.add_scalar('Test Loss', avg_test_loss, iteration)

            print('[epoch: {}, iteration: {}] test loss : {:.4f} '.format(epoch, iteration, avg_test_loss))
            # Persist the best-so-far checkpoint; the optimizer state is saved
            # too so training can resume exactly where it left off.
            if avg_test_loss < best_loss:
                best_loss = avg_test_loss
                ckpt = {'model': net.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_loss': best_loss}
                torch.save(ckpt, ckpt_path)
                print('checkpoint is saved !')