# Example #1
# 0
def main(opts):
    """Run a pretrained PASE model over a list of wav files, with reverb.

    Reads the wav files named in ``opts.in_files`` (relative to
    ``opts.files_root``), applies a fixed impulse-response reverb to each,
    zero-pads them into batches of ``opts.batch_size``, forwards them through
    the frontend plus the first minion, and writes both the minion output and
    the reverberated input to ``opts.out_path``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # build network
    # do_losses=False: pure inference, training losses are not needed
    minions_cfg = pase_parser(opts.minions_cfg, do_losses=False)
    remove_Dcfg(minions_cfg)
    pase = wf_builder(opts.cfg)
    model = Waveminionet(minions_cfg=minions_cfg,
                         num_devices=0,
                         pretrained_ckpts=opts.ckpt,
                         z_minion=False,
                         frontend=pase)
    model.eval()
    model.to(device)
    # Fixed room impulse response applied to every input waveform.
    transf = Reverb(['data/omologo_revs/IRs_2/IR_223108.imp'], ir_fmt='imp')
    # Only the first minion is used; drop its loss for inference.
    minion = model.minions[0]
    minion.loss = None
    pase = model.frontend
    #print(opts.in_files)
    in_files = [os.path.join(opts.files_root, inf) for inf in opts.in_files]
    wavs = []      # waveforms accumulated for the current batch
    wfiles = []    # file paths matching `wavs`
    max_len = 0    # longest waveform in the current batch (padding target)
    print('Total batches: ', len(in_files) // opts.batch_size)
    with torch.no_grad():
        # start=1 so the modulo test below fires exactly on full batches
        for wi, wfile in tqdm.tqdm(enumerate(in_files, start=1),
                                   total=len(in_files)):
            wfiles.append(wfile)
            wav, rate = sf.read(wfile)
            wavs.append(wav)
            if len(wav) > max_len:
                max_len = len(wav)
            # Flush on a full batch, or on the final (possibly partial) one.
            if wi % opts.batch_size == 0 or wi >= len(in_files):
                lens = []
                batch = []
                for bi in range(len(wavs)):
                    # zero-pad each waveform up to the batch maximum
                    P_ = max_len - len(wavs[bi])
                    lens.append(len(wavs[bi]))
                    if P_ > 0:
                        pad = np.zeros((P_))
                        wav_ = np.concatenate((wavs[bi], pad), axis=0)
                    else:
                        wav_ = wavs[bi]
                    wav = torch.FloatTensor(wav_)
                    # reverberate the chunk before feeding the network
                    wav_r = transf({'chunk': wav})
                    batch.append(wav_r['chunk'].view(1, 1, -1))
                batch = torch.cat(batch, dim=0)
                x = batch.to(device)
                h = pase(x)
                #print('frontend size: ', h.size())
                y = minion(h).cpu()
                for bi in range(len(wavs)):
                    bname = os.path.basename(wfiles[bi])
                    # trim the zero padding off before writing
                    y_ = y[bi].squeeze().data.numpy()
                    y_ = y_[:lens[bi]]
                    # NOTE(review): output rate is hard-coded to 16 kHz even
                    # though `rate` is read above — confirm inputs are 16 kHz.
                    sf.write(os.path.join(opts.out_path, '{}'.format(bname)),
                             y_, 16000)
                    x_ = x[bi].squeeze().data.numpy()
                    x_ = x_[:lens[bi]]
                    sf.write(
                        os.path.join(opts.out_path, 'input_{}'.format(bname)),
                        x_, 16000)
                # reset the accumulators for the next batch
                max_len = 0
                wavs = []
                wfiles = []
                batch = None
    # (stray string delimiter from example concatenation removed)
# Example #2
# 0
def train(opts):
    """Train a Waveminionet self-supervised model.

    Builds the model from the frontend/worker configs referenced by ``opts``,
    wires train (and optionally validation) DataLoaders over PairWavDataset,
    and delegates the optimization loop to ``model.train_``.
    """
    CUDA = True if torch.cuda.is_available() and not opts.no_cuda else False
    device = 'cuda' if CUDA else 'cpu'
    num_devices = 1
    # Seed every RNG source so runs are reproducible.
    np.random.seed(opts.seed)
    random.seed(opts.seed)
    torch.manual_seed(opts.seed)
    if CUDA:
        torch.cuda.manual_seed_all(opts.seed)
        num_devices = torch.cuda.device_count()
        print('[*] Using CUDA {} devices'.format(num_devices))
    else:
        print('[!] Using CPU')
    print('Seeds initialized to {}'.format(opts.seed))

    # ---------------------
    # Build Model
    if opts.fe_cfg is not None:
        with open(opts.fe_cfg, 'r') as fe_cfg_f:
            fe_cfg = json.load(fe_cfg_f)
            print(fe_cfg)
    else:
        fe_cfg = None
    minions_cfg = pase_parser(opts.net_cfg)
    # FIX: make_transforms was called twice (first result discarded); the
    # sibling train() builds transforms exactly once. The single call is kept
    # here, before model construction, so any side effects it has on
    # minions_cfg still happen before the model is built.
    trans = make_transforms(opts, minions_cfg)
    model = Waveminionet(minions_cfg=minions_cfg,
                         adv_loss=opts.adv_loss,
                         num_devices=num_devices,
                         pretrained_ckpt=opts.pretrained_ckpt,
                         frontend_cfg=fe_cfg)

    print(model)
    if opts.net_ckpt is not None:
        model.load_pretrained(opts.net_ckpt, load_last=True, verbose=True)
    print('Frontend params: ', model.frontend.describe_params())
    model.to(device)
    print(trans)
    # Build Dataset(s) and DataLoader(s)
    dset = PairWavDataset(opts.data_root,
                          opts.data_cfg,
                          'train',
                          transform=trans,
                          preload_wav=opts.preload_wav)
    dloader = DataLoader(dset,
                         batch_size=opts.batch_size,
                         shuffle=True,
                         collate_fn=DictCollater(),
                         num_workers=opts.num_workers,
                         pin_memory=CUDA)
    # Compute estimation of bpe. As we sample chunks randomly, we
    # should say that an epoch happened after seeing at least as many
    # chunks as total_train_wav_dur // chunk_size
    bpe = (dset.total_wav_dur // opts.chunk_size) // opts.batch_size
    opts.bpe = bpe
    if opts.do_eval:
        va_dset = PairWavDataset(opts.data_root,
                                 opts.data_cfg,
                                 'valid',
                                 transform=trans,
                                 preload_wav=opts.preload_wav)
        va_dloader = DataLoader(va_dset,
                                batch_size=opts.batch_size,
                                shuffle=False,
                                collate_fn=DictCollater(),
                                num_workers=opts.num_workers,
                                pin_memory=CUDA)
        va_bpe = (va_dset.total_wav_dur // opts.chunk_size) // opts.batch_size
        opts.va_bpe = va_bpe
    else:
        va_dloader = None
    # fastet lr to MI
    #opts.min_lrs = {'mi':0.001}
    model.train_(dloader, vars(opts), device=device, va_dloader=va_dloader)
# Example #3
# 0
def eval(opts):
    """Evaluate a sequence of Waveminionet checkpoints on the valid split.

    Each checkpoint is reloaded into the same model and its metrics are
    logged through ``model.eval_`` to a TensorBoard SummaryWriter, indexed by
    the epoch parsed from the checkpoint name (``fullmodel_e{epoch}.ckpt``).
    """
    use_cuda = torch.cuda.is_available() and not opts.no_cuda
    device = 'cuda' if use_cuda else 'cpu'
    # Seed all RNG sources for reproducibility.
    np.random.seed(opts.seed)
    random.seed(opts.seed)
    torch.manual_seed(opts.seed)
    if use_cuda:
        torch.cuda.manual_seed_all(opts.seed)
    print('Seeds initialized to {}'.format(opts.seed))
    # ---------------------
    # Transforms
    trans = Compose([
        ToTensor(),
        MIChunkWav(opts.chunk_size, random_scale=opts.random_scale),
        LPS(opts.nfft, hop=opts.stride, win=400),
        MFCC(hop=opts.stride),
        Prosody(hop=opts.stride, win=400),
        ZNorm(opts.stats)
    ])
    print(trans)

    # ---------------------
    # Build Dataset(s) and DataLoader(s)
    va_dset = PairWavDataset(opts.data_root, opts.data_cfg, 'valid',
                             transform=trans)
    va_loader = DataLoader(va_dset, batch_size=opts.batch_size,
                           shuffle=False, collate_fn=DictCollater(),
                           num_workers=opts.num_workers)
    # Estimated batches per epoch: chunks are sampled randomly, so an epoch
    # means "seen at least total_wav_dur // chunk_size chunks".
    bpe = (va_dset.total_wav_dur // opts.chunk_size) // opts.batch_size

    # ---------------------
    # Build Model
    fe_cfg = None
    if opts.fe_cfg is not None:
        with open(opts.fe_cfg, 'r') as fe_cfg_f:
            fe_cfg = json.load(fe_cfg_f)
            print(fe_cfg)
    model = Waveminionet(minions_cfg=pase_parser(opts.net_cfg),
                         adv_loss=opts.adv_loss,
                         pretrained_ckpt=opts.pretrained_ckpt,
                         frontend_cfg=fe_cfg)

    print(model)
    model.to(device)
    writer = SummaryWriter(opts.save_path)
    if opts.max_epoch is None:
        ckpt_names = opts.ckpts
    else:
        # just make a sequential search til max epoch ckpts
        ckpt_names = ['fullmodel_e{}.ckpt'.format(e)
                      for e in range(opts.max_epoch)]
    for ckpt_name in ckpt_names:
        # name format is fullmodel_e{}.ckpt
        tail = ckpt_name.split('_')[-1]          # e.g. 'e12.ckpt'
        epoch = int(tail.split('.')[0][1:])
        ckpt_path = os.path.join(opts.ckpt_root, ckpt_name)
        print('Loading ckpt ', ckpt_path)
        model.load_pretrained(ckpt_path, load_last=True, verbose=False)
        model.eval_(va_loader,
                    opts.batch_size,
                    bpe,
                    log_freq=opts.log_freq,
                    epoch_idx=epoch,
                    writer=writer,
                    device=device)
# Example #4
# 0
def train(opts):
    """Train a Waveminionet with a prebuilt frontend and optional distortions.

    Builds the waveform frontend and worker configuration, optionally loads a
    JSON-configured distortion pipeline, creates train/valid DataLoaders from
    the dataset class named by ``opts.dataset``, and hands control to
    ``model.train_``.
    """
    use_cuda = torch.cuda.is_available() and not opts.no_cuda
    device = 'cuda' if use_cuda else 'cpu'
    num_devices = 1
    # Seed every RNG source so runs are reproducible.
    np.random.seed(opts.seed)
    random.seed(opts.seed)
    torch.manual_seed(opts.seed)
    if use_cuda:
        torch.cuda.manual_seed_all(opts.seed)
        num_devices = torch.cuda.device_count()
        print('[*] Using CUDA {} devices'.format(num_devices))
    else:
        print('[!] Using CPU')
    print('Seeds initialized to {}'.format(opts.seed))

    # ---------------------
    # Build Model
    frontend = wf_builder(opts.fe_cfg)
    minions_cfg = pase_parser(opts.net_cfg,
                              batch_acum=opts.batch_acum,
                              device=device,
                              frontend=frontend)
    model = Waveminionet(minions_cfg=minions_cfg,
                         adv_loss=opts.adv_loss,
                         num_devices=num_devices,
                         frontend=frontend)

    print(model)
    print('Frontend params: ', model.frontend.describe_params())
    model.to(device)
    trans = make_transforms(opts, minions_cfg)
    print(trans)
    # Optional waveform distortion pipeline built from a JSON config.
    dist_trans = None
    if opts.dtrans_cfg is not None:
        with open(opts.dtrans_cfg, 'r') as dtr_cfg:
            dtr = json.load(dtr_cfg)
            #dtr['trans_p'] = opts.distortion_p
            dist_trans = config_distortions(**dtr)
            print(dist_trans)
    # Build Dataset(s) and DataLoader(s)
    dataset_cls = getattr(pase.dataset, opts.dataset)
    # Shared keyword arguments for the train and valid dataset instances.
    dset_kwargs = dict(transform=trans,
                       noise_folder=opts.noise_folder,
                       whisper_folder=opts.whisper_folder,
                       distortion_probability=opts.distortion_p,
                       distortion_transforms=dist_trans,
                       preload_wav=opts.preload_wav)
    tr_dset = dataset_cls(opts.data_root, opts.data_cfg, 'train',
                          **dset_kwargs)
    tr_loader = DataLoader(tr_dset,
                           batch_size=opts.batch_size,
                           shuffle=True,
                           collate_fn=DictCollater(),
                           num_workers=opts.num_workers,
                           pin_memory=use_cuda)
    # Batches-per-epoch estimate: chunks are sampled randomly, so an epoch
    # means "seen at least total_wav_dur // chunk_size chunks".
    opts.bpe = (tr_dset.total_wav_dur // opts.chunk_size) // opts.batch_size
    va_loader = None
    if opts.do_eval:
        va_dset = dataset_cls(opts.data_root, opts.data_cfg, 'valid',
                              **dset_kwargs)
        va_loader = DataLoader(va_dset,
                               batch_size=opts.batch_size,
                               shuffle=False,
                               collate_fn=DictCollater(),
                               num_workers=opts.num_workers,
                               pin_memory=use_cuda)
        opts.va_bpe = ((va_dset.total_wav_dur // opts.chunk_size) //
                       opts.batch_size)
    # fastet lr to MI
    #opts.min_lrs = {'mi':0.001}
    model.train_(tr_loader, vars(opts), device=device, va_dloader=va_loader)
def eval(opts):
    """Evaluate prosody predictions of saved checkpoints on the test split.

    For every checkpoint, the frontend plus the prosody minion are run over
    the test set through ``forward_dloader``; per-checkpoint results are
    printed and the full list is dumped as JSON to ``opts.out_file``.
    """
    use_cuda = torch.cuda.is_available() and not opts.no_cuda
    device = 'cuda' if use_cuda else 'cpu'
    # Seed all RNG sources for reproducibility.
    np.random.seed(opts.seed)
    random.seed(opts.seed)
    torch.manual_seed(opts.seed)
    if use_cuda:
        torch.cuda.manual_seed_all(opts.seed)
    print('Seeds initialized to {}'.format(opts.seed))
    # ---------------------
    # Transforms
    trans = Compose([
        ToTensor(),
        MIChunkWav(opts.chunk_size, random_scale=opts.random_scale),
        Prosody(hop=opts.stride, win=400)
    ])

    # Pre-computed normalization statistics for forward_dloader.
    with open(opts.stats, 'rb') as stats_f:
        stats = pickle.load(stats_f)

    # ---------------------
    # Build Dataset(s) and DataLoader(s)
    te_dset = PairWavDataset(opts.data_root, opts.data_cfg, 'test',
                             transform=trans)
    te_loader = DataLoader(te_dset,
                           batch_size=opts.batch_size,
                           shuffle=False,
                           collate_fn=DictCollater(),
                           num_workers=opts.num_workers)
    # Batches-per-epoch estimate: chunks are sampled randomly, so an epoch
    # means "seen at least total_wav_dur // chunk_size chunks".
    bpe = (te_dset.total_wav_dur // opts.chunk_size) // opts.batch_size

    # ---------------------
    # Build Model
    fe_cfg = None
    if opts.fe_cfg is not None:
        with open(opts.fe_cfg, 'r') as fe_cfg_f:
            fe_cfg = json.load(fe_cfg_f)
            print(fe_cfg)
    model = Waveminionet(minions_cfg=pase_parser(opts.net_cfg),
                         frontend_cfg=fe_cfg)

    print(model)
    model.to(device)

    # Checkpoints may be specified as explicit names or as epoch indices.
    use_epid = opts.ckpt_epochs is not None
    ckpts = opts.ckpt_epochs if use_epid else opts.ckpts
    if ckpts is None:
        raise ValueError('Please specify either ckpts or ckpt_epochs')
    if opts.ckpt_root is None:
        raise ValueError('Please specify ckpt_root!')

    ckpts_res = []
    for ckpt in ckpts:
        ckpt_name = 'fullmodel_e{}.ckpt'.format(ckpt) if use_epid else ckpt
        ckpt_path = os.path.join(opts.ckpt_root, ckpt_name)
        print('Loading ckpt: ', ckpt_path)
        model.load_pretrained(ckpt_path, load_last=True, verbose=True)

        # select prosodic minion (the last one whose name contains 'prosody')
        pmodel = None
        for minion in model.minions:
            if 'prosody' in minion.name:
                pmodel = minion

        # select frontend
        fe = model.frontend

        res = forward_dloader(te_loader, bpe, fe, pmodel, stats, opts.tags,
                              device)
        ckpts_res.append(res)
        print('Results for ckpt {}'.format(ckpt_name))
        print('-' * 25)
        for k, v in res.items():
            print('{}: {}'.format(k, np.mean(v)))
        print('=' * 25)

    with open(opts.out_file, 'w') as out_f:
        out_f.write(json.dumps(ckpts_res, indent=2))