filenames = [pjoin(root_dir, filename) + postfix for filename in filenames]
    return filenames


def print_flags():
    """Print the run configuration to stdout.

    Emits a header line, then one ``name : value`` line per attribute of the
    module-level ``args`` namespace, then the module-level ``device``, and
    finally the effective batch size computed as ``bs_train * num_segments``.
    """
    print('--------------------------- Flags -----------------------------')
    for flag_name in vars(args):
        flag_value = getattr(args, flag_name)
        print('{} : {}'.format(flag_name, flag_value))
    print('{} : {}'.format('device', device))
    per_batch, segments = args.bs_train, args.num_segments
    print('Actual batch size {} x {} = {}'.format(per_batch, segments, per_batch * segments))


if __name__ == '__main__':
    # Training entry point: build train/valid data loaders, construct the SEUNet
    # model and move it to the selected device.
    print('---------------------------------- Data Preparation -----------------------------')
    # Training split: one segment per item (num_segments=1), random muting
    # options disabled, shuffled batches loaded by 16 worker processes.
    data_tr = DataETL('train', signal_filelist_path=args.train_speech_seeds, noise_filelist_path=args.train_noise_seeds,
                      feature=args.feature, short_version=False, slice_win=args.slice_win, num_segments=1,
                      mute_random=False, mute_random_snr=False, padding_slice=False, visualise=False)
    dl_tr = DataLoader(data_tr, shuffle=True, batch_size=args.bs_train, num_workers=16, drop_last=False)

    # Validation split: mute_random / mute_random_snr / visualise enabled,
    # loaded in order in the main process (num_workers=0).
    data_va = DataETL('valid', signal_filelist_path=args.valid_speech_seeds, noise_filelist_path=args.valid_noise_seeds,
                      feature=args.feature, slice_win=args.slice_win,
                      mute_random=True, mute_random_snr=True, padding_slice=False, visualise=True)
    dl_va = DataLoader(data_va, shuffle=False, batch_size=args.bs_valid, num_workers=0, drop_last=False)

    print('---------------------------------- Build Neural Networks ------------------------')
    # `pars` (model hyper-parameters) is defined outside this view.
    ss = SEUNet(pars)
    # First CUDA GPU when available, otherwise CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    ss = to_device(ss, device)
    print(ss)

    # NOTE(review): the body of this optimiser-selection branch is missing here;
    # the following line is a stray `return filenames` from another file at the
    # same indent, so this script does not parse as-is. Restore the original
    # branch (and the rest of the training loop) before running.
    if args.optimiser == 'amsgrad':
    return filenames


if __name__ == '__main__':
    # Evaluation entry point: load the SEUNet checkpoint selected by
    # args.global_steps and run it over the test split.
    print('=> Loading model - {}'.format(args.global_steps))
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # map_location remaps stored tensors onto `device`, so a checkpoint saved
    # from a GPU run can still be loaded on a CPU-only host (the fallback
    # chosen above) instead of failing while deserialising CUDA tensors.
    se_cunet = torch.load(pjoin(args.save_models, 'seunet_{}.pkl'.format(args.global_steps)),
                          map_location=device)
    # torch.no_grad() below only disables autograd; eval() is what puts dropout
    # and batch-norm layers into inference mode. Assumes the pickled object is a
    # torch.nn.Module (it is called like one in the loop below).
    se_cunet.eval()
    se_cunet = to_device(se_cunet, device)
    print('-------------------------------- Validation ---------------------------')
    # NOTE(review): loss_te / counter_te (and noise/target segments, denoised_)
    # are not used in the visible part of this loop; the rest of the loop body
    # appears to be truncated in this copy of the file.
    loss_te = 0.
    counter_te = 0

    # `p.load` (presumably pickle) can execute arbitrary code embedded in the
    # file - only point args.given_snr_dict at trusted files.
    with open(args.given_snr_dict, 'rb') as snr_dict:
        given_snr_dict = p.load(snr_dict)
    data_te = DataETL('test', signal_filelist_path=args.test_speech_seeds, noise_filelist_path=args.test_noise_seeds,
                      feature=args.feature, slice_win=args.slice_win,
                      mute_random=True, mute_random_snr=True, padding_slice=False, visualise=True,
                      given_snr_dict=given_snr_dict)
    dl_te = DataLoader(data_te, shuffle=False, batch_size=args.bs_test, num_workers=0, drop_last=False)

    for idx_te, (batch_te, batch_info) in enumerate(dl_te):
        mixed_segments_ = stack_batches(batch_te['features_segments']['mixed'], split='test')
        signal_segments_ = stack_batches(batch_te['features_segments']['signal'], split='test')
        noise_segments_ = stack_batches(batch_te['features_segments']['noise'], split='test')
        target_segments_ = stack_batches(batch_te['features_segments']['target'], split='test')

        # Only the network input and the clean reference are moved to `device`.
        mixed_segments_, signal_segments_ = to_device([mixed_segments_, signal_segments_], device)

        with torch.no_grad():
            denoised_ = se_cunet(mixed_segments_)

        mixed_segments_ = mixed_segments_.squeeze()
Exemplo n.º 3
0
                feature,
                nfft,
                win_length=int(self.win_size * fs['audio_sr'].item()),
                hop_length=int(self.hop_size * fs['audio_sr'].item()))
            torchaudio.save(
                pjoin(save2, str(feature_name)) + 'fromFeatures.wav', waveform,
                fs['audio_sr'].item())


if __name__ == '__main__':
    # Manual smoke test: build the test split of the Edinburgh noisy-speech seed
    # lists with time-domain features, then hand it to DataVisualisation with
    # './audio_demo/' (presumably its output directory) and render 3 segments.
    speech_pkl = '/nas/staff/data_work/Sure/Edinburg_Noisy_Speech_Database/Speech/pkls/test.pkl'
    noise_pkl = '/nas/staff/data_work/Sure/Edinburg_Noisy_Speech_Database/Noise/pkls/test.pkl'
    etl_options = dict(
        signal_filelist_path=speech_pkl,
        noise_filelist_path=noise_pkl,
        feature='time-domain',
        short_version=2,
        slice_win=16384,
        mute_random=True,
        mute_random_snr=True,
        padding_slice=False,
        visualise=True,
    )
    test_etl = DataETL('test', **etl_options)

    demo_dir = './audio_demo/'
    visualiser = DataVisualisation(test_etl, demo_dir)
    visualiser.visualise(num_segments=3)
    print('--------EOF---------')
import time
from DataETL import DataETL, DataLoader
signal_seeds = '/nas/staff/data_work/Sure/Edinburg_Noisy_Speech_Database/Speech/pkls/train.pkl'
noise_seeds = '/nas/staff/data_work/Sure/Edinburg_Noisy_Speech_Database/Noise/pkls/train.pkl'

# Benchmark: report how long the DataETL/DataLoader pipeline takes to yield each
# training batch (log-magnitude features, 5 segments per item, 16 workers).
# NOTE(review): this runs at import time with no `__main__` guard; with
# num_workers=16 on spawn-based platforms (Windows/macOS) worker processes
# re-import this module - consider adding a guard.
data_etl = DataETL('train',
                   signal_filelist_path=signal_seeds,
                   noise_filelist_path=noise_seeds,
                   feature='log-magnitude',
                   slice_win=5,
                   target_scope=2,
                   num_segments=5)
data_loader = DataLoader(data_etl, shuffle=True, batch_size=20, num_workers=16)

# time.clock() was removed in Python 3.8 (AttributeError on modern Python);
# perf_counter() is the high-resolution monotonic clock for elapsed wall time.
current_time = time.perf_counter()
for idx_batch, batch in enumerate(data_loader):
    update_time = time.perf_counter()
    duration = update_time - current_time
    print('No.{}, Loading Time: {}'.format(idx_batch, duration))
    print('Elements in one batch: {}'.format(list(batch.keys())))
    # Restart the timer after printing so the next measurement covers only the
    # wait for the next batch.
    current_time = time.perf_counter()