Code example #1
def main(args):
    # Construct Solver
    # data
    tr_dataset = AudioDataset(args.train_json,
                              sample_rate=args.sample_rate,
                              segment_length=args.segment_length)

    cv_dataset = AudioDataset(
        args.valid_json,
        sample_rate=args.sample_rate,
        segment_length=args.segment_length,
    )

    tr_loader = AudioDataLoader(tr_dataset,
                                batch_size=args.batch_size,
                                shuffle=args.shuffle,
                                num_workers=args.num_workers)

    cv_loader = AudioDataLoader(cv_dataset,
                                batch_size=args.batch_size,
                                num_workers=0)

    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}
    # model
    model = ConvTasNet(args.N,
                       args.L,
                       args.B,
                       args.H,
                       args.P,
                       args.X,
                       args.R,
                       args.C,
                       norm_type=args.norm_type,
                       causal=args.causal,
                       mask_nonlinear=args.mask_nonlinear)
    print(model)
    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()
    # optimizer
    lr = args.lr / args.batch_per_step
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    momentum=args.momentum,
                                    weight_decay=args.l2)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     weight_decay=args.l2)
    else:
        print("Unsupported optimizer")
        return

    # solver
    solver = Solver(data, model, optimizer, args)
    solver.train()
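A note on usage: these `main(args)` entry points are normally driven by `argparse`. Below is one plausible launcher for example #1; the flag names are inferred from the attributes `main` reads and are assumptions, not the project's actual CLI.

import argparse

parser = argparse.ArgumentParser(description='Conv-TasNet training (hypothetical flags)')
# Flag names mirror the attributes main() reads; the real script may differ.
parser.add_argument('--train_json', required=True)
parser.add_argument('--valid_json', required=True)
parser.add_argument('--sample_rate', type=int, default=8000)
parser.add_argument('--segment_length', type=int, default=32000)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--shuffle', type=int, default=1)
parser.add_argument('--num_workers', type=int, default=4)
# ... plus the model flags (N, L, B, H, P, X, R, C, ...) and optimizer flags.

if __name__ == '__main__':
    main(parser.parse_args())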
Code example #2
def main(args):

    # data
    tr_dataset = AudioDataset('tr',
                              batch_size=args.batch_size,
                              sample_rate=args.sample_rate,
                              nmic=args.mic)
    cv_dataset = AudioDataset('val',
                              batch_size=args.batch_size,
                              sample_rate=args.sample_rate,
                              nmic=args.mic)
    tr_loader = AudioDataLoader(tr_dataset,
                                batch_size=1,
                                shuffle=args.shuffle,
                                num_workers=0)  #num_workers=0 for PC
    cv_loader = AudioDataLoader(cv_dataset, batch_size=1,
                                num_workers=0)  #num_workers=0 for PC

    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}

    # model
    model = FaSNet_TAC(enc_dim=args.enc_dim,
                       feature_dim=args.feature_dim,
                       hidden_dim=args.hidden_dim,
                       layer=args.layer,
                       segment_size=args.segment_size,
                       nspk=args.nspk,
                       win_len=args.win_len,
                       context_len=args.context_len,
                       sr=args.sample_rate)

    k = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('# of parameters:', k)

    #print(model)
    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()

    # optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.l2)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.l2)
    else:
        print("Unsupported optimizer")
        return

    # solver
    solver = Solver(data, model, optimizer, args)
    solver.train()
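The `use_cuda` branch above wraps the model in `DataParallel` unconditionally. A slightly more defensive variant (a sketch using only standard torch calls, not the project's code) checks availability and GPU count first:

import torch

def maybe_cuda(model, use_cuda):
    if use_cuda and torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            # DataParallel only pays off with more than one visible GPU.
            model = torch.nn.DataParallel(model)
        model = model.cuda()
    return model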
Code example #3
File: train.py Project: b06502162Lu/HW3-1_Lu
def main(args):
    # Construct Solver
    # data
    tr_dataset = AudioDataset(args.train_dir,
                              args.batch_size,
                              sample_rate=args.sample_rate,
                              segment=args.segment)
    cv_dataset = AudioDataset(
        args.valid_dir,
        batch_size=1,  # 1 -> use less GPU memory to do cv
        sample_rate=args.sample_rate,
        segment=-1,
        cv_maxlen=args.cv_maxlen)  # -1 -> use full audio
    tr_loader = AudioDataLoader(tr_dataset,
                                batch_size=1,
                                shuffle=args.shuffle,
                                num_workers=4)
    cv_loader = AudioDataLoader(cv_dataset,
                                batch_size=1,
                                num_workers=4,
                                pin_memory=True)
    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}
    # model
    model = ConvTasNet(args.N,
                       args.L,
                       args.B,
                       args.H,
                       args.P,
                       args.X,
                       args.R,
                       args.C,
                       norm_type=args.norm_type,
                       causal=args.causal,
                       mask_nonlinear=args.mask_nonlinear)
    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()
    # optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.l2)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.l2)
    else:
        print("Unsupported optimizer")
        return

    # solver
    solver = Solver(data, model, optimizer, args)
    solver.train()
Code example #4
File: train.py Project: JJoving/SMLAT
def main(args):
    # Construct Solver
    # data
    tr_dataset = AudioDataset(args.train_json, args.batch_size, args.maxlen_in,
                              args.maxlen_out)
    cv_dataset = AudioDataset(args.valid_json, args.batch_size, args.maxlen_in,
                              args.maxlen_out)
    tr_loader = AudioDataLoader(tr_dataset,
                                batch_size=1,
                                num_workers=args.num_workers)
    cv_loader = AudioDataLoader(cv_dataset,
                                batch_size=1,
                                num_workers=args.num_workers)
    # load dictionary and generate char_list, sos_id, eos_id
    char_list, sos_id, eos_id = process_dict(args.dict)
    vocab_size = len(char_list)
    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}
    # model
    encoder = Encoder(args.einput,
                      args.ehidden,
                      args.elayer,
                      dropout=args.edropout,
                      bidirectional=args.ebidirectional,
                      rnn_type=args.etype)
    decoder = Decoder(vocab_size,
                      args.dembed,
                      sos_id,
                      eos_id,
                      args.dhidden,
                      args.dlayer,
                      bidirectional_encoder=args.ebidirectional)
    model = Seq2Seq(encoder, decoder)
    print(model)
    model.cuda()
    # optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.l2)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.l2)
    else:
        print("Unsupported optimizer")
        return

    # solver
    ctc = 0
    solver = Solver(data, model, optimizer, args)
    solver.train()
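`process_dict` is imported from the project's utilities and is not shown. A plausible sketch of what it does, assuming an ESPnet-style dictionary file with one `token index` pair per line and reserved `<sos>`/`<eos>` tokens (an assumption, not the project's source):

def process_dict(dict_path):
    with open(dict_path, encoding='utf-8') as f:
        char_list = [line.split(' ')[0] for line in f if line.strip()]
    sos_id = char_list.index('<sos>')
    eos_id = char_list.index('<eos>')
    return char_list, sos_id, eos_id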
Code example #5
File: noise_scheduler_n.py Project: yhgon/WaveGrad
def run(config, args):
    print(config)
    model = WaveGrad(config).cuda()
    print(f'Number of parameters: {model.nparams}')

    schedule_checkpoint = os.path.join(
        config.training_config.logdir,
        "checkpoint_{}.pt".format(args.checkpointnum))
    load_checkpoint_org(model, schedule_checkpoint)

    dataset = AudioDataset(config, training=False)
    mel_fn = MelSpectrogramFixed(sample_rate=config.data_config.sample_rate,
                                 n_fft=config.data_config.n_fft,
                                 win_length=config.data_config.win_length,
                                 hop_length=config.data_config.hop_length,
                                 f_min=config.data_config.f_min,
                                 f_max=config.data_config.f_max,
                                 n_mels=config.data_config.n_mels,
                                 window_fn=torch.hann_window).cuda()

    iters_best_schedule, stats = benchmark.iters_schedule_grid_search(
        model=model,
        n_iter=args.iter,
        config=config,
        step=args.step,
        test_batch_size=args.schedulebatch,
        path_to_store_stats='{}/gs_stats_{:d}iters.pt'.format(
            args.scheduledir, args.iter),
        verbose=args.verbose)
    torch.save(
        iters_best_schedule,
        '{}/iters{:d}_best_schedule.pt'.format(args.scheduledir, args.iter))

    print(args.iter)
    print(iters_best_schedule)
Code example #6
def estimate_average_rtf_on_filelist(filelist_path,
                                     config,
                                     model,
                                     verbose=True):
    device = next(model.parameters()).device
    config.training_config.test_filelist_path = filelist_path
    dataset = AudioDataset(config, training=False)
    mel_fn = MelSpectrogramFixed(sample_rate=config.data_config.sample_rate,
                                 n_fft=config.data_config.n_fft,
                                 win_length=config.data_config.win_length,
                                 hop_length=config.data_config.hop_length,
                                 f_min=config.data_config.f_min,
                                 f_max=config.data_config.f_max,
                                 n_mels=config.data_config.n_mels,
                                 window_fn=torch.hann_window).to(device)
    rtfs = []
    for i in (tqdm(range(len(dataset))) if verbose else range(len(dataset))):
        datapoint = dataset[i].to(device)
        mel = mel_fn(datapoint)[None]
        start = datetime.now()
        sample = model.forward(mel, store_intermediate_states=False)
        end = datetime.now()
        generation_time = (end - start).total_seconds()
        rtf = compute_rtf(sample,
                          generation_time,
                          sample_rate=config.data_config.sample_rate)
        rtfs.append(rtf)
    average_rtf = np.mean(rtfs)
    std_rtf = np.std(rtfs)

    show_message(f'DEVICE: {device}. average_rtf={average_rtf}, std={std_rtf}',
                 verbose=verbose)

    rtf_stats = {'rtfs': rtfs, 'average': average_rtf, 'std': std_rtf}
    return rtf_stats
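`compute_rtf` itself is not included in this excerpt. Under the usual definition, the real-time factor is generation time divided by the duration of the produced audio; a sketch consistent with the call sites above (an assumption about the helper, not its source):

def compute_rtf(sample, generation_time, sample_rate=22050):
    # RTF = seconds spent generating / seconds of audio produced.
    n_samples = sample.shape[-1]
    return generation_time * sample_rate / n_samples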
Code example #7
def main(args):
    # Construct Solver
    # data
    tr_dataset = AudioDataset(args.train_dir,
                              args.batch_size,
                              sample_rate=args.sample_rate,
                              segment=args.segment)
    cv_dataset = AudioDataset(
        args.valid_dir,
        batch_size=1,  # 1 -> use less GPU memory to do cv
        sample_rate=args.sample_rate,
        segment=-1,
        cv_maxlen=args.cv_maxlen)  # -1 -> use full audio
    tr_loader = AudioDataLoader(tr_dataset,
                                batch_size=1,
                                shuffle=args.shuffle,
                                num_workers=args.num_workers)
    cv_loader = AudioDataLoader(cv_dataset, batch_size=1, num_workers=0)
    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}
    # model
    model = DPTNet(args.N, args.C, args.L, args.H, args.K, args.B)
    #print(model)
    k = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('# of parameters:', k)

    if args.use_cuda:
        os.environ["CUDA_VISIBLE_DEVICES"] = '5,6,7'
        model = torch.nn.DataParallel(model)
        model.cuda()
    # optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.l2)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.l2)
    else:
        print("Unsupported optimizer")
        return

    # solver
    solver = Solver(data, model, optimizer, args)
    solver.train()
Code example #8
def main(args):
    model = torch.load(args.model).to(device)
    test_data = AudioDataset(args.data, 'test')
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=num_workers)
    test(args, model, test_loader)
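`torch.load(args.model)` assumes the checkpoint is a whole pickled module saved on a compatible device. When the checkpoint may have been saved on GPU and evaluated on CPU (or vice versa), `map_location` makes the load robust; a minimal sketch:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# map_location remaps the stored tensors onto whatever device is available.
model = torch.load('model.pth', map_location=device)
model.eval()  # disable dropout / batch-norm updates before testing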
Code example #9
File: train.py Project: entn-at/TasNet
def main(args):
    # Construct Solver
    # data
    tr_dataset = AudioDataset(args.train_dir,
                              args.batch_size,
                              sample_rate=args.sample_rate,
                              L=args.L)
    cv_dataset = AudioDataset(args.valid_dir,
                              args.batch_size,
                              sample_rate=args.sample_rate,
                              L=args.L)
    tr_loader = AudioDataLoader(tr_dataset,
                                batch_size=1,
                                shuffle=args.shuffle,
                                num_workers=args.num_workers)
    cv_loader = AudioDataLoader(cv_dataset,
                                batch_size=1,
                                shuffle=args.shuffle,
                                num_workers=args.num_workers)
    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}
    # model
    model = TasNet(args.L,
                   args.N,
                   args.hidden_size,
                   args.num_layers,
                   bidirectional=args.bidirectional,
                   nspk=args.nspk)
    print(model)
    model.cuda()
    # optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.l2)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.l2)
    else:
        print("Unsupported optimizer")
        return

    # solver
    solver = Solver(data, model, optimizer, args)
    solver.train()
Code example #10
def main(args):
    # Construct Solver

    # data
    tr_dataset = AudioDataset(args.tr_json,
                              sample_rate=args.sample_rate,
                              segment=args.segment,
                              drop=args.drop)
    cv_dataset = AudioDataset(args.cv_json,
                              sample_rate=args.sample_rate,
                              drop=0,
                              segment=-1)  # -1 -> use full audio
    tr_loader = AudioDataLoader(tr_dataset,
                                batch_size=args.batch_size,
                                shuffle=args.shuffle,
                                num_workers=args.num_workers)
    cv_loader = AudioDataLoader(cv_dataset, batch_size=1, num_workers=0)
    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}

    # model
    # N=512, L=32, B=128, Sc=128, H=512, X=8, R=3, P=3, C=2
    model = ConvTasNet(args.N, args.L, args.B, args.Sc, args.H, args.X, args.R,
                       args.P, args.C)

    print(model)
    if args.use_cuda:
        os.environ["CUDA_VISIBLE_DEVICES"] = '5,6,7'
        model = torch.nn.DataParallel(model)
        model.cuda()
    # optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.l2)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.l2)
    else:
        print("Unsupported optimizer")
        return

    # solver
    solver = Solver(data, model, optimizer, args)
    solver.train()
Code example #11
File: inference.py Project: dodoproptit99/WaveGrad
def get_mel(config, model):
    mel_fn = MelSpectrogramFixed(sample_rate=config.data_config.sample_rate,
                                 n_fft=config.data_config.n_fft,
                                 win_length=config.data_config.win_length,
                                 hop_length=config.data_config.hop_length,
                                 f_min=config.data_config.f_min,
                                 f_max=config.data_config.f_max,
                                 n_mels=config.data_config.n_mels,
                                 window_fn=torch.hann_window).cuda()

    dataset = AudioDataset(config, training=False)
    test_batch = dataset.sample_test_batch(1)

    # n_iter = 25
    # path_to_store_schedule = f'schedules/default/{n_iter}iters.pt'

    # iters_best_schedule, stats = iters_schedule_grid_search(
    #     model, config,
    #     n_iter=n_iter,
    #     betas_range=(1e-06, 0.01),
    #     test_batch_size=1, step=1,
    #     path_to_store_schedule=path_to_store_schedule,
    #     save_stats_for_grid=True,
    #     verbose=True, n_jobs=4
    # )

    i = 0
    for test in tqdm(test_batch):
        mel = mel_fn(test[None].cuda())
        start = datetime.now()
        t = time()
        outputs = model.forward(mel, store_intermediate_states=False)
        end = datetime.now()
        print("Time infer: ", str(time() - t))
        outputs = outputs.cpu().squeeze()
        save_path = str(i) + '.wav'
        i += 1
        torchaudio.save(save_path,
                        outputs,
                        sample_rate=config.data_config.sample_rate)
        inference_time = (end - start).total_seconds()
        rtf = compute_rtf(outputs,
                          inference_time,
                          sample_rate=config.data_config.sample_rate)
        show_message(f'Done. RTF estimate: {rtf}')
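One caveat with the save call above, reusing the names from the snippet: recent torchaudio releases expect a 2-D `[channels, time]` tensor, while `outputs.cpu().squeeze()` is 1-D for mono audio. A defensive variant (an adjustment sketch, not the original code):

wav = outputs.cpu().squeeze()
if wav.dim() == 1:
    wav = wav.unsqueeze(0)  # [time] -> [1, time], as torchaudio expects
torchaudio.save(save_path, wav, sample_rate=config.data_config.sample_rate)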
Code example #12
def evaluate(args):
    total_SISNRi = 0
    total_SDRi = 0
    total_cnt = 0

    # Load model
    model = ConvTasNet.load_model(args.model_path)
    print(model)
    model.eval()
    #if args.use_cuda:
    if True:
        model.cuda()

    # Load data
    dataset = AudioDataset(args.data_dir,
                           args.batch_size,
                           sample_rate=args.sample_rate,
                           segment=-1)
    data_loader = AudioDataLoader(dataset, batch_size=1, num_workers=2)

    with torch.no_grad():
        for i, (data) in enumerate(data_loader):
            # Get batch data
            padded_mixture, mixture_lengths, padded_source = data
            #if args.use_cuda:
            if True:
                padded_mixture = padded_mixture.cuda()
                mixture_lengths = mixture_lengths.cuda()
                padded_source = padded_source.cuda()
            # Forward
            estimate_source = model(padded_mixture)  # [B, C, T]
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)
            # Remove padding and flatten
            mixture = remove_pad(padded_mixture, mixture_lengths)
            source = remove_pad(padded_source, mixture_lengths)
            # NOTE: use reorder estimate source
            estimate_source = remove_pad(reorder_estimate_source,
                                         mixture_lengths)
            # for each utterance
            for mix, src_ref, src_est in zip(mixture, source, estimate_source):
                print("Utt", total_cnt + 1)
                # Compute SDRi
                if args.cal_sdr:
                    avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                    total_SDRi += avg_SDRi
                    print("\tSDRi={0:.2f}".format(avg_SDRi))
                # Compute SI-SNRi
                avg_SISNRi = cal_SISNRi(src_ref, src_est, mix)
                print("\tSI-SNRi={0:.2f}".format(avg_SISNRi))
                total_SISNRi += avg_SISNRi
                total_cnt += 1
    if args.cal_sdr:
        print("Average SDR improvement: {0:.2f}".format(total_SDRi /
                                                        total_cnt))
    print("Average SISNR improvement: {0:.2f}".format(total_SISNRi /
                                                      total_cnt))
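For reference, `cal_SISNRi` reports the improvement of scale-invariant SNR over the unprocessed mixture, i.e. `si_snr(est, ref) - si_snr(mix, ref)`. A standard formulation of SI-SNR itself (a generic sketch, not the project's helper):

import numpy as np

def si_snr(est, ref, eps=1e-8):
    # Zero-mean both signals, project est onto ref, compare energies in dB.
    est = est - est.mean()
    ref = ref - ref.mean()
    proj = (np.dot(est, ref) / (np.dot(ref, ref) + eps)) * ref
    noise = est - proj
    return 10 * np.log10(np.dot(proj, proj) / (np.dot(noise, noise) + eps))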
Code example #13
def main(args):
    # Construct Solver
    # data
    tr_dataset = AudioDataset(args.train_dir, args.batch_size,
                              sample_rate=args.sample_rate, segment=args.segment)
    cv_dataset = AudioDataset(args.valid_dir, batch_size=1,  # 1 -> use less GPU memory to do cv
                              sample_rate=args.sample_rate,
                              segment=-1, cv_maxlen=args.cv_maxlen)  # -1 -> use full audio
    tr_loader = AudioDataLoader(tr_dataset, batch_size=1,
                                shuffle=args.shuffle)
    cv_loader = AudioDataLoader(cv_dataset, batch_size=1)
    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}

    # model
    # model = FURCA(args.W, args.N, args.K, args.C, args.D, args.H, args.E,
    #                    norm_type=args.norm_type, causal=args.causal,
    #                    mask_nonlinear=args.mask_nonlinear)
    model = FaSNet_base(enc_dim=256, feature_dim=64, hidden_dim=128,
                        layer=6, segment_size=250, nspk=2, win_len=2)

    print(model)
    if args.use_cuda:
        # model = torch.nn.DataParallel(model)
        model.cuda()
        #model.to(device)
    # optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.l2)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.l2)
    else:
        print("Unsupported optimizer")
        return

    # solver
    solver = Solver(data, model, optimizer, args)
    solver.train()
Code example #14
def main(train_dir, batch_size, sample_rate, segment, valid_dir, cv_maxlen,
         shuffle, num_workers, N, L, B, H, P, X, R, C, norm_type, causal,
         mask_nonlinear, use_cuda, optimizer, lr, momentum, l2, epochs,
         half_lr, early_stop, max_norm, save_folder, checkpoint,
         continue_from, model_path, print_freq, visdom, visdom_epoch,
         visdom_id):
    # Construct Solver
    # data
    tr_dataset = AudioDataset(train_dir, batch_size,
                              sample_rate=sample_rate, segment=segment)
    cv_dataset = AudioDataset(valid_dir, batch_size=1,  # 1 -> use less GPU memory to do cv
                              sample_rate=sample_rate,
                              segment=-1, cv_maxlen=cv_maxlen)  # -1 -> use full audio
    tr_loader = AudioDataLoader(tr_dataset, batch_size=1,
                                shuffle=shuffle,
                                num_workers=num_workers)
    cv_loader = AudioDataLoader(cv_dataset, batch_size=1,
                                num_workers=0)
    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}
    # model
    model = ConvTasNet(N, L, B, H, P, X, R, C, 
                       norm_type=norm_type, causal=causal,
                       mask_nonlinear=mask_nonlinear)
    print(model)
    if use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()
    # optimizer
    if optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    momentum=momentum,
                                    weight_decay=l2)
    elif optimizer == 'adam':
        # fatemeh: changed optimizier to optimizer
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     weight_decay=l2)
    else:
        print("Unsupported optimizer")
        return

    # solver
    solver = Solver(data, model, optimizer, use_cuda, epochs, half_lr,
                    early_stop, max_norm, save_folder, checkpoint,
                    continue_from, model_path, print_freq, visdom,
                    visdom_epoch, visdom_id)
    solver.train()
Code example #15
def main(args):
    scheduler = None
    train_data = AudioDataset(args.input, 'train')
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers)

    val_data = AudioDataset(args.input, 'validation')
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers)

    model = Model(device, method)
    criterion = nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    if args.half_lr:
        print("half learning rate on plateau")
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.8,
                                      patience=2)

    with mlflow.start_run() as run:
        mlflow.log_param("num_epoch", num_epoch)
        mlflow.log_param("lr", lr)
        mlflow.log_param("model_hidden_size", model_hidden_size)
        mlflow.log_param("model_layers", model_layers)

        for epoch in range(num_epoch):
            train_loss = train(epoch, model, train_loader, criterion,
                               optimizer, scheduler, 5)
            if scheduler:
                scheduler.step(train_loss)
            mlflow.pytorch.log_model(model, "classifier")
            validate(args, model, val_loader, criterion)
        torch.save(model, f'{args.output}/classifier.pth')
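The run above logs hyperparameters and the model but not the loss curve itself. A natural extension, using only documented MLflow calls, records a metric per epoch so it plots in the UI:

import mlflow

with mlflow.start_run():
    for epoch, train_loss in enumerate([0.9, 0.7, 0.5]):  # stand-in values
        # step= gives the metric a proper x-axis in the MLflow UI.
        mlflow.log_metric('train_loss', train_loss, step=epoch)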
Code example #16
def estimate_average_rtf_on_filelist(filelist_path,
                                     config,
                                     model,
                                     verbose=True):
    """
    Runs RTF estimation over a filelist of audios and computes statistics.
    :param filelist_path (str): path to a filelist with the needed audios
    :param config (utils.ConfigWrapper): configuration dict
    :param model (torch.nn.Module): WaveGrad model
    :param verbose (bool, optional): verbosity level
    :return stats: statistics dict
    """
    device = next(model.parameters()).device
    config.training_config.test_filelist_path = filelist_path
    dataset = AudioDataset(config, training=False)
    mel_fn = MelSpectrogramFixed(sample_rate=config.data_config.sample_rate,
                                 n_fft=config.data_config.n_fft,
                                 win_length=config.data_config.win_length,
                                 hop_length=config.data_config.hop_length,
                                 f_min=config.data_config.f_min,
                                 f_max=config.data_config.f_max,
                                 n_mels=config.data_config.n_mels,
                                 window_fn=torch.hann_window).to(device)
    rtfs = []
    for i in (tqdm(range(len(dataset))) if verbose else range(len(dataset))):
        datapoint = dataset[i].to(device)
        mel = mel_fn(datapoint)[None]
        start = datetime.now()
        sample = model.forward(mel, store_intermediate_states=False)
        end = datetime.now()
        generation_time = (end - start).total_seconds()
        rtf = compute_rtf(sample,
                          generation_time,
                          sample_rate=config.data_config.sample_rate)
        rtfs.append(rtf)
    average_rtf = np.mean(rtfs)
    std_rtf = np.std(rtfs)

    show_message(f'DEVICE: {device}. average_rtf={average_rtf}, std={std_rtf}',
                 verbose=verbose)

    rtf_stats = {'rtfs': rtfs, 'average': average_rtf, 'std': std_rtf}
    return rtf_stats
Code example #17
File: train_org.py Project: yhgon/WaveGrad
def run(config, args):
    show_message('Initializing logger...', verbose=args.verbose)
    logger = Logger(config)

    show_message('Initializing model...', verbose=args.verbose)
    model = WaveGrad(config).cuda()
    show_message(f'Number of parameters: {model.nparams}',
                 verbose=args.verbose)
    mel_fn = MelSpectrogramFixed(sample_rate=config.data_config.sample_rate,
                                 n_fft=config.data_config.n_fft,
                                 win_length=config.data_config.win_length,
                                 hop_length=config.data_config.hop_length,
                                 f_min=config.data_config.f_min,
                                 f_max=config.data_config.f_max,
                                 n_mels=config.data_config.n_mels,
                                 window_fn=torch.hann_window).cuda()

    show_message('Initializing optimizer, scheduler and losses...',
                 verbose=args.verbose)
    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=config.training_config.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.training_config.scheduler_step_size,
        gamma=config.training_config.scheduler_gamma)

    show_message('Initializing data loaders...', verbose=args.verbose)
    train_dataset = AudioDataset(config, training=True)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.training_config.batch_size,
                                  drop_last=True)
    test_dataset = AudioDataset(config, training=False)
    test_dataloader = DataLoader(test_dataset, batch_size=1)
    test_batch = test_dataset.sample_test_batch(
        config.training_config.n_samples_to_test)

    if config.training_config.continue_training:
        show_message('Loading latest checkpoint to continue training...',
                     verbose=args.verbose)
        model, optimizer, iteration = logger.load_latest_checkpoint(
            model, optimizer)
        epoch_size = len(train_dataset) // config.training_config.batch_size
        epoch_start = iteration // epoch_size
    else:
        iteration = 0
        epoch_start = 0

    # Log ground truth test batch
    audios = {
        f'audio_{index}/gt': audio
        for index, audio in enumerate(test_batch)
    }
    logger.log_audios(0, audios)
    specs = {
        f'mel_{index}/gt': mel_fn(audio.cuda()).cpu().squeeze()
        for index, audio in enumerate(test_batch)
    }
    logger.log_specs(0, specs)

    show_message('Start training...', verbose=args.verbose)
    try:
        for epoch in range(epoch_start, config.training_config.n_epoch):
            # Training step
            model.set_new_noise_schedule(
                init=torch.linspace,
                init_kwargs={
                    'steps':
                    config.training_config.training_noise_schedule.n_iter,
                    'start':
                    config.training_config.training_noise_schedule.
                    betas_range[0],
                    'end':
                    config.training_config.training_noise_schedule.
                    betas_range[1]
                })
            for i, batch in enumerate(train_dataloader):
                tic_iter = time.time()
                batch = batch.cuda()
                mels = mel_fn(batch)

                # Training step
                model.zero_grad()
                loss = model.compute_loss(mels, batch)
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    parameters=model.parameters(),
                    max_norm=config.training_config.grad_clip_threshold)
                optimizer.step()
                toc_iter = time.time()
                dur_iter = toc_iter - tic_iter
                dur_iter = np.round(dur_iter, 4)
                loss_stats = {
                    'total_loss': loss.item(),
                    'grad_norm': grad_norm.item()
                }
                iter_size = len(train_dataloader)
                logger.log_training(epoch,
                                    iteration,
                                    i,
                                    iter_size,
                                    loss_stats,
                                    dur_iter,
                                    int(dur_iter * iter_size),
                                    verbose=args.verbose)

                iteration += 1

            # Test step
            if epoch % config.training_config.test_interval == 0:
                model.set_new_noise_schedule(
                    init=torch.linspace,
                    init_kwargs={
                        'steps':
                        config.training_config.test_noise_schedule.n_iter,
                        'start':
                        config.training_config.test_noise_schedule.
                        betas_range[0],
                        'end':
                        config.training_config.test_noise_schedule.
                        betas_range[1]
                    })
                with torch.no_grad():
                    # Calculate test set loss
                    test_loss = 0
                    for i, batch in enumerate(test_dataloader):
                        batch = batch.cuda()
                        mels = mel_fn(batch)
                        test_loss_ = model.compute_loss(mels, batch)
                        test_loss += test_loss_
                    test_loss /= (i + 1)
                    loss_stats = {'total_loss': test_loss.item()}

                    # Restore random batch from test dataset
                    audios = {}
                    specs = {}
                    test_l1_loss = 0
                    test_l1_spec_loss = 0
                    average_rtf = 0

                    for index, test_sample in enumerate(test_batch):
                        test_sample = test_sample[None].cuda()
                        test_mel = mel_fn(test_sample.cuda())

                        start = datetime.now()
                        y_0_hat = model.forward(
                            test_mel, store_intermediate_states=False)
                        y_0_hat_mel = mel_fn(y_0_hat)
                        end = datetime.now()
                        generation_time = (end - start).total_seconds()
                        average_rtf += compute_rtf(
                            y_0_hat, generation_time,
                            config.data_config.sample_rate)

                        test_l1_loss += torch.nn.L1Loss()(y_0_hat,
                                                          test_sample).item()
                        test_l1_spec_loss += torch.nn.L1Loss()(
                            y_0_hat_mel, test_mel).item()

                        audios[f'audio_{index}/predicted'] = y_0_hat.cpu(
                        ).squeeze()
                        specs[f'mel_{index}/predicted'] = y_0_hat_mel.cpu(
                        ).squeeze()

                    average_rtf /= len(test_batch)
                    show_message(f'Device: GPU. average_rtf={average_rtf}',
                                 verbose=args.verbose)

                    test_l1_loss /= len(test_batch)
                    loss_stats['l1_test_batch_loss'] = test_l1_loss
                    test_l1_spec_loss /= len(test_batch)
                    loss_stats['l1_spec_test_batch_loss'] = test_l1_spec_loss

                    logger.log_test(epoch, loss_stats, verbose=args.verbose)
                    #logger.log_audios(epoch, audios)
                    #logger.log_specs(epoch, specs)

                logger.save_checkpoint(iteration, model, optimizer)
            if epoch % (epoch // 10 + 1) == 0:
                scheduler.step()
    except KeyboardInterrupt:
        print('KeyboardInterrupt: training has been stopped.')
        return
Code example #18
File: train_rnnt.py Project: cloudthink/open_stt_e2e
torch.manual_seed(0)
np.random.seed(0)

labels = Labels()

model = Transducer(128,
                   len(labels),
                   512,
                   256,
                   am_layers=3,
                   lm_layers=3,
                   dropout=0.3,
                   am_checkpoint='exp/am.bin',
                   lm_checkpoint='exp/lm.bin')

train = AudioDataset(
    '/media/lytic/STORE/ru_open_stt_wav/public_youtube1120_hq.txt', labels)
test = AudioDataset(
    '/media/lytic/STORE/ru_open_stt_wav/public_youtube700_val.txt', labels)

train.filter_by_conv(model.encoder.conv)
train.filter_by_length(400)

test.filter_by_conv(model.encoder.conv)
test.filter_by_length(200)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-5)

model.cuda()

sampler = BucketingSampler(train, 32)
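The excerpt stops after constructing the sampler. Presumably the training script feeds it to a `DataLoader` through `batch_sampler`; a sketch of the standard PyTorch pattern (the project's `collate_fn` is omitted because this excerpt does not show it):

from torch.utils.data import DataLoader

# A batch_sampler yields complete batches of indices, so batch_size and
# shuffle must be left unset when it is supplied.
loader = DataLoader(train, batch_sampler=sampler, num_workers=4)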
Code example #19
def main(args):
    # Construct Solver
    # data
    tr_dataset = AudioDataset(args.train_json,
                              args.batch_size,
                              args.maxlen_in,
                              args.maxlen_out,
                              batch_frames=args.batch_frames)
    cv_dataset = AudioDataset(args.valid_json,
                              args.batch_size,
                              args.maxlen_in,
                              args.maxlen_out,
                              batch_frames=args.batch_frames)
    tr_loader = AudioDataLoader(tr_dataset,
                                batch_size=1,
                                num_workers=args.num_workers,
                                shuffle=args.shuffle,
                                LFR_m=args.LFR_m,
                                LFR_n=args.LFR_n)
    cv_loader = AudioDataLoader(cv_dataset,
                                batch_size=1,
                                num_workers=args.num_workers,
                                LFR_m=args.LFR_m,
                                LFR_n=args.LFR_n)
    # load dictionary and generate char_list, sos_id, eos_id
    char_list, sos_id, eos_id = process_dict(args.dict)
    vocab_size = len(char_list)
    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}
    # model
    encoder = Encoder(args.d_input * args.LFR_m,
                      args.n_layers_enc,
                      args.n_head,
                      args.d_k,
                      args.d_v,
                      args.d_model,
                      args.d_inner,
                      dropout=args.dropout,
                      pe_maxlen=args.pe_maxlen)
    decoder = Decoder(
        sos_id,
        eos_id,
        vocab_size,
        args.d_word_vec,
        args.n_layers_dec,
        args.n_head,
        args.d_k,
        args.d_v,
        args.d_model,
        args.d_inner,
        dropout=args.dropout,
        tgt_emb_prj_weight_sharing=args.tgt_emb_prj_weight_sharing,
        pe_maxlen=args.pe_maxlen)
    model = Transformer(encoder, decoder)
    print(model)
    model.cuda()
    # optimizer
    model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])
    optimizer = TransformerOptimizer(
        torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09),
        args.k, args.d_model, args.warmup_steps)

    # solver
    solver = Solver(data, model, optimizer, args)
    solver.train()
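`TransformerOptimizer` wraps Adam with what is presumably the warm-up schedule from "Attention Is All You Need", scaled by `args.k`. Under that assumption, the learning rate at step `s` would follow this sketch:

def noam_lr(step, k, d_model, warmup_steps):
    # lr = k * d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5):
    # rises linearly for warmup_steps steps, then decays as step^-0.5.
    return k * d_model ** -0.5 * min(step ** -0.5,
                                     step * warmup_steps ** -1.5)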
Code example #20
def iters_schedule_grid_search(model,
                               config,
                               n_iter=6,
                               betas_range=(1e-6, 1e-2),
                               test_batch_size=2,
                               step=1,
                               path_to_store_schedule=None,
                               save_stats_for_grid=True,
                               verbose=True,
                               n_jobs=1):
    """
    Performs a grid search over noise schedules of n_iter iterations. Run it only on GPU and only for a small number of iterations!
    :param model (torch.nn.Module): WaveGrad model
    :param config (ConfigWrapper): model configuration
    :param n_iter (int, optional): number of iterations to search for
    :param betas_range (tuple, optional): lower and upper bounds for the betas grid
    :param test_batch_size (int, optional): number of one-second samples to test grid candidates on
    :param step (int, optional): stride used to subsample the generated grid
    :param path_to_store_schedule (str, optional): path to store the best schedule. If not specified, it will not be saved and will just be returned.
    :param save_stats_for_grid (bool, optional): whether to save stats for the whole grid
    :param verbose (bool, optional): output all the process
    :param n_jobs (int, optional): number of parallel threads to use
    :return betas (list): list of betas which gives the lowest log10-mel-spectrogram absolute error
    :return stats (dict): dict of type {index: (betas, loss)} for the whole grid
    """
    device = next(model.parameters()).device
    if 'cpu' in str(device):
        show_message('WARNING: running grid search on CPU will be slow.')

    show_message('Initializing betas grid...', verbose=verbose)
    grid = generate_betas_grid(n_iter, betas_range, verbose=verbose)[::step]

    show_message('Initializing utils...', verbose=verbose)
    mel_fn = MelSpectrogramFixed(sample_rate=config.data_config.sample_rate,
                                 n_fft=config.data_config.n_fft,
                                 win_length=config.data_config.win_length,
                                 hop_length=config.data_config.hop_length,
                                 f_min=config.data_config.f_min,
                                 f_max=config.data_config.f_max,
                                 n_mels=config.data_config.n_mels,
                                 window_fn=torch.hann_window).to(device)
    dataset = AudioDataset(config, training=True)
    idx = np.random.choice(range(len(dataset)),
                           size=test_batch_size,
                           replace=False)
    test_batch = torch.stack([dataset[i] for i in idx]).to(device)
    test_mels = mel_fn(test_batch)

    show_message('Starting search...', verbose=verbose)
    with ThreadPool(processes=n_jobs) as pool:
        process_fn = partial(_betas_estimate,
                             model=model,
                             mels=test_mels,
                             mel_fn=mel_fn)
        stats = list(tqdm(pool.imap(process_fn, grid), total=len(grid)))
    stats = {i: (grid[i], stats[i]) for i in range(len(stats))}

    if save_stats_for_grid:
        tmp_stats_path = f'{os.path.dirname(path_to_store_schedule)}/{n_iter}stats.pt'
        show_message(
            f'Saving tmp stats for whole grid to `{tmp_stats_path}`...',
            verbose=verbose)
        torch.save(stats, tmp_stats_path)

    best_idx = np.argmin([value for _, value in stats.values()])
    best_betas = grid[best_idx]

    if path_to_store_schedule is not None:
        show_message(f'Saving best schedule to `{path_to_store_schedule}`...',
                     verbose=verbose)
        torch.save(best_betas, path_to_store_schedule)

    return best_betas, stats
Code example #21
def main(args):
    # Construct Solver
    # data
    tr_dataset = AudioDataset(args.train_json, args.batch_size, args.maxlen_in,
                              args.maxlen_out)
    cv_dataset = AudioDataset(args.valid_json, args.batch_size, args.maxlen_in,
                              args.maxlen_out)
    tr_loader = AudioDataLoader(tr_dataset,
                                batch_size=1,
                                num_workers=args.num_workers,
                                LFR_m=args.LFR_m,
                                LFR_n=args.LFR_n,
                                align_trun=args.align_trun)
    cv_loader = AudioDataLoader(cv_dataset,
                                batch_size=1,
                                num_workers=args.num_workers,
                                LFR_m=args.LFR_m,
                                LFR_n=args.LFR_n,
                                align_trun=args.align_trun)
    # load dictionary and generate char_list, sos_id, eos_id
    char_list, sos_id, eos_id = process_dict(args.dict)
    args.char_list = char_list
    vocab_size = len(char_list)
    data = {'tr_loader': tr_loader, 'cv_loader': cv_loader}
    # model
    #import pdb
    #pdb.set_trace()
    encoder = Encoder(args.einput * args.LFR_m,
                      args.ehidden,
                      args.elayer,
                      vocab_size,
                      dropout=args.edropout,
                      bidirectional=args.ebidirectional,
                      rnn_type=args.etype)
    decoder = Decoder(vocab_size,
                      args.dembed,
                      sos_id,
                      eos_id,
                      args.dhidden,
                      args.dlayer,
                      args.offset,
                      args.atype,
                      dropout=args.edropout,
                      bidirectional_encoder=args.ebidirectional)
    if args.ebidirectional:
        eprojs = args.ehidden * 2
    else:
        eprojs = args.ehidden
    ctc = CTC(odim=vocab_size, eprojs=eprojs, dropout_rate=args.edropout)
    #lstm_model = Lstmctc.load_model(args.continue_from)

    model = Seq2Seq(encoder, decoder, ctc, args)
    #model_dict = model.state_dict()
    print(model)
    #print(lstm_model)
    #pretrained_dict = torch.load(args.ctc_model)
    #pretrained_dict = {k: v for k, v in pretrained_dict['state_dict'].items() if k in model_dict}
    #pretrained_dict = {(k.replace('lstm','encoder')):v for k, v in pretrained_dict['state_dict'].items() if (k.replace('lstm','encoder')) in model_dict}
    #model_dict.update(pretrained_dict)
    #model.load_state_dict(model_dict)
    #for k,v in model.named_parameters():
    #    if k.startswith("encoder"):
    #        print(k)
    #        v.requires_grad=False
    model.cuda()
    # optimizer
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.l2)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.l2)
    else:
        print("Unsupported optimizer")
        return

    # solver
    ctc = 0
    solver = Solver(data, model, optimizer, args)
    solver.train()
Code example #22
def evaluate(args):
    total_SISNRi = 0
    total_SDRi = 0
    total_cnt = 0
    avg_SISNRiPitNum = 0
    length = torch.ones(1).int()
    numberEsti = []
    # Load model
    model = ConvTasNet.load_model(args.model_path)
    # print(model)
    model.eval()
    if args.use_cuda:
        model.cuda(0)

    # Load data
    dataset = AudioDataset(args.data_dir, args.batch_size,
                           sample_rate=args.sample_rate, segment=2)
    data_loader = AudioDataLoader(dataset, batch_size=1, num_workers=2)

    with torch.no_grad():
        for i, (data) in enumerate(data_loader):
            print(i)
            # Get batch data
            padded_mixture, mixture_lengths, padded_source = data
            if args.use_cuda:
                padded_mixture = padded_mixture.cuda(0)
                mixture_lengths = mixture_lengths.cuda(0)
            # Forward
            estimate_source, s_embed = model(padded_mixture)  # [B, C, T], [B, N, K, E]
          #  print(estimate_source.shape)
           # embid = (model.separator.network[2][7])(padded_mixture)
          #  print(embid)
            '''
            embeddings = s_embed[0].data.cpu().numpy()
            embedding = (embeddings.reshape((1,-1,20)))[0]
            number = sourceNumEsti2(embedding)
            numberEsti.append(number)
            '''
           # print(estimate_source)
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)
            # Remove padding and flatten
            mixture = remove_pad(padded_mixture, mixture_lengths)
            source = remove_pad(padded_source, mixture_lengths)
           # print(max_snr.item())
            # NOTE: use reorder estimate source
            estimate_source = remove_pad(reorder_estimate_source,
                                         mixture_lengths)
           # print((estimate_source[0].shape))
            # for each utterance
            for mix, src_ref, src_est in zip(mixture, source, estimate_source):
                print("Utt", total_cnt + 1)
                # Compute SDRi
                if args.cal_sdr:
                    avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                    total_SDRi += avg_SDRi
                    print("\tSDRi={0:.2f}".format(avg_SDRi))
                # Compute SI-SNRi
                avg_SISNRi = cal_SISNRi(src_ref, src_est, mix)
                #avg_SISNRiPit,a,b = cal_si_snr_with_pit(torch.from_numpy(src_ref), torch.from_numpy(src_est),length)
                print("\tSI-SNRi={0:.2f}".format(avg_SISNRi))
                total_SISNRi += avg_SISNRi
                #total_SNRiPitNum += avg_SISNRiPit.numpy()
                total_cnt += 1
            
    if args.cal_sdr:
        print("Average SDR improvement: {0:.2f}".format(total_SDRi / total_cnt))
    print("Average SISNR improvement: {0:.2f}".format(total_SISNRi / total_cnt))
    print("speaker:2 ./ClustertrainTFSE1New/final_paper_2_3_2chobatch6.pth.tar")
   
    return numberEsti
Code example #23
def evaluate(args):
    total_SISNRi = 0
    total_SDRi = 0
    total_cnt = 0

    # Load model
    model = DPTNet(args.N, args.C, args.L, args.H, args.K, args.B)

    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()

    # model.load_state_dict(torch.load(args.model_path, map_location='cpu'))

    model_info = torch.load(args.model_path)

    state_dict = OrderedDict()
    for k, v in model_info['model_state_dict'].items():
        name = k.replace("module.", "")  # remove 'module.'
        state_dict[name] = v
    model.load_state_dict(state_dict)

    print(model)

    # Load data
    dataset = AudioDataset(args.data_dir,
                           args.batch_size,
                           sample_rate=args.sample_rate,
                           segment=-1)
    data_loader = AudioDataLoader(dataset, batch_size=1, num_workers=2)

    with torch.no_grad():
        for i, (data) in enumerate(data_loader):
            # Get batch data
            padded_mixture, mixture_lengths, padded_source = data
            if args.use_cuda:
                padded_mixture = padded_mixture.cuda()
                mixture_lengths = mixture_lengths.cuda()
                padded_source = padded_source.cuda()
            # Forward
            estimate_source = model(padded_mixture)  # [B, C, T]
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)
            # Remove padding and flatten
            mixture = remove_pad(padded_mixture, mixture_lengths)
            source = remove_pad(padded_source, mixture_lengths)
            # NOTE: use reorder estimate source
            estimate_source = remove_pad(reorder_estimate_source,
                                         mixture_lengths)
            # for each utterance
            for mix, src_ref, src_est in zip(mixture, source, estimate_source):
                print("Utt", total_cnt + 1)
                # Compute SDRi
                if args.cal_sdr:
                    avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                    total_SDRi += avg_SDRi
                    print("\tSDRi={0:.2f}".format(avg_SDRi))
                # Compute SI-SNRi
                avg_SISNRi = cal_SISNRi(src_ref, src_est, mix)
                print("\tSI-SNRi={0:.2f}".format(avg_SISNRi))
                total_SISNRi += avg_SISNRi
                total_cnt += 1
    if args.cal_sdr:
        print("Average SDR improvement: {0:.2f}".format(total_SDRi /
                                                        total_cnt))
    print("Average SISNR improvement: {0:.2f}".format(total_SISNRi /
                                                      total_cnt))
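The `module.`-stripping loop above is the standard fix for loading a checkpoint saved from a `DataParallel`-wrapped model into a bare module. A reusable helper capturing the same idea (it removes only a leading prefix, which is slightly safer than the replace-everywhere call in the snippet):

from collections import OrderedDict

def strip_module_prefix(state_dict):
    # Checkpoints saved through DataParallel prefix every key with 'module.'.
    return OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in state_dict.items())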
Code example #24
def separate(args):

    # Load model
    model = FaSNet_TAC(enc_dim=args.enc_dim,
                       feature_dim=args.feature_dim,
                       hidden_dim=args.hidden_dim,
                       layer=args.layer,
                       segment_size=args.segment_size,
                       nspk=args.nspk,
                       win_len=args.win_len,
                       context_len=args.context_len,
                       sr=args.sample_rate)

    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()

    model_info = torch.load(args.model_path)
    try:
        model.load_state_dict(model_info['model_state_dict'])
    except KeyError:
        state_dict = OrderedDict()
        for k, v in model_info['model_state_dict'].items():
            name = k.replace("module.", "")  # remove 'module.'
            state_dict[name] = v
        model.load_state_dict(state_dict)

    print(model)
    model.eval()

    # Load data
    dataset = AudioDataset('test',
                           batch_size=1,
                           sample_rate=args.sample_rate,
                           nmic=args.mic)
    eval_loader = EvalAudioDataLoader(dataset, batch_size=1, num_workers=8)

    os.makedirs(args.out_dir, exist_ok=True)

    def write(inputs, filename, sr=args.sample_rate):
        sf.write(filename, inputs, sr)  # norm=True

    with torch.no_grad():
        #t = tqdm(total=len(eval_dataset), mininterval=0.5)
        for i, data in enumerate(eval_loader):

            padded_mixture, mixture_lengths, padded_source = data

            if args.use_cuda:
                padded_mixture = padded_mixture.cuda()
                mixture_lengths = mixture_lengths.cuda()
                padded_source = padded_source.cuda()

            # The random tensor only supplies a dtype for none_mic below.
            x = torch.rand(2, 6, 32000)
            none_mic = torch.zeros(1).type(x.type())
            # Forward
            estimate_source = model(padded_mixture,
                                    none_mic.long())  # [M, C, T]

            for j in range(estimate_source.size()[0]):

                scs = estimate_source[j].cpu().numpy()

                power = np.sqrt((padded_mixture.cpu().numpy()**2).sum() /
                                len(padded_mixture.cpu().numpy()))
                for k, src in enumerate(scs):
                    this_dir = os.path.join(args.out_dir,
                                            'utt{0}'.format(i + 1))
                    if not os.path.exists(this_dir):
                        os.makedirs(this_dir)
                    source = src * (power / np.sqrt(
                        (src**2).sum() / len(padded_mixture)))
                    write(source,
                          os.path.join(this_dir, 's{0}.wav'.format(k + 1)))
Code example #25
def evaluate(args):
    total_SISNRi = 0
    total_SDRi = 0
    total_cnt = 0

    # Load model

    model = FaSNet_TAC(enc_dim=args.enc_dim,
                       feature_dim=args.feature_dim,
                       hidden_dim=args.hidden_dim,
                       layer=args.layer,
                       segment_size=args.segment_size,
                       nspk=args.nspk,
                       win_len=args.win_len,
                       context_len=args.context_len,
                       sr=args.sample_rate)

    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()

    # model.load_state_dict(torch.load(args.model_path, map_location='cpu'))

    model_info = torch.load(args.model_path)
    try:
        model.load_state_dict(model_info['model_state_dict'])
    except RuntimeError:
        # key mismatch between a DataParallel and a plain checkpoint:
        # strip the 'module.' prefix and retry
        state_dict = OrderedDict()
        for k, v in model_info['model_state_dict'].items():
            name = k.replace("module.", "")
            state_dict[name] = v
        model.load_state_dict(state_dict)

    print(model)
    model.eval()

    # Load data
    dataset = AudioDataset('test',
                           batch_size=1,
                           sample_rate=args.sample_rate,
                           nmic=args.mic)
    data_loader = EvalAudioDataLoader(dataset, batch_size=1, num_workers=8)

    sisnr_array = []
    sdr_array = []
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            # Get batch data
            padded_mixture, mixture_lengths, padded_source = data

            if args.use_cuda:
                padded_mixture = padded_mixture.cuda()
                mixture_lengths = mixture_lengths.cuda()
                padded_source = padded_source.cuda()

            # zero flag: no microphone is marked as missing
            none_mic = torch.zeros(1).long()
            # Forward
            estimate_source = model(padded_mixture, none_mic)  # [M, C, T]

            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths)

            M, _, T = padded_mixture.shape
            mixture_ref = torch.chunk(padded_mixture, args.mic,
                                      dim=1)[0]  #[M, ch, T] -> [M, 1, T]
            mixture_ref = mixture_ref.view(M, T)  #[M, 1, T] -> [M, T]

            mixture = remove_pad(mixture_ref, mixture_lengths)
            source = remove_pad(padded_source, mixture_lengths)
            estimate_source = remove_pad(reorder_estimate_source,
                                         mixture_lengths)

            # for each utterance
            for mix, src_ref, src_est in zip(mixture, source, estimate_source):
                print("Utt", total_cnt + 1)
                # Compute SDRi
                if args.cal_sdr:
                    avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                    total_SDRi += avg_SDRi
                    sdr_array.append(avg_SDRi)
                    print("\tSDRi={0:.2f}".format(avg_SDRi))
                # Compute SI-SNRi
                avg_SISNRi = cal_SISNRi(src_ref, src_est, mix)
                print("\tSI-SNRi={0:.2f}".format(avg_SISNRi))
                total_SISNRi += avg_SISNRi
                sisnr_array.append(avg_SISNRi)
                total_cnt += 1
    if args.cal_sdr:
        print("Average SDR improvement: {0:.2f}".format(total_SDRi /
                                                        total_cnt))

    np.save('sisnr.npy', np.array(sisnr_array))
    np.save('sdr.npy', np.array(sdr_array))
    print("Average SISNR improvement: {0:.2f}".format(total_SISNRi /
                                                      total_cnt))
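For reference, cal_SISNRi reports the improvement in scale-invariant SNR over using the mixture itself as the estimate. A hedged sketch of the underlying SI-SNR metric (the repository's own implementation may differ in details such as epsilon handling):

import numpy as np

def si_snr(est, ref, eps=1e-8):
    # zero-mean both signals, then project the estimate onto the reference
    est = est - est.mean()
    ref = ref - ref.mean()
    s_target = (est * ref).sum() * ref / ((ref ** 2).sum() + eps)
    e_noise = est - s_target
    return 10 * np.log10(((s_target ** 2).sum() + eps) /
                         ((e_noise ** 2).sum() + eps))

# SI-SNRi is then si_snr(est, ref) - si_snr(mix, ref)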
Code example #26
File: train_new.py  Project: yhgon/WaveGrad
def run(config, args):

    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    else:
        local_rank = args.rank
        world_size = args.world_size
    distributed_run = world_size > 1

    torch.manual_seed(args.seed + local_rank)
    np.random.seed(args.seed + local_rank)

    # ensure the output directory for checkpoints exists
    if local_rank == 0:
        os.makedirs(args.output, exist_ok=True)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False

    if distributed_run:
        init_distributed(args, world_size, local_rank)

    device = torch.device('cuda' if args.cuda else 'cpu')

    if local_rank == 0:
        print("start training")
        print("args", args)
        print("config", config)

    #############################################
    # model
    if local_rank == 0:
        print("load model")
    model = WaveGrad(config).cuda()

    # optimizer amp config
    if local_rank == 0:
        print("configure optimizer and amp")
    kw = dict(lr=args.learning_rate,
              betas=(0.9, 0.98),
              eps=1e-9,
              weight_decay=args.weight_decay)

    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    elif args.optimizer == 'pytorch':
        optimizer = torch.optim.Adam(model.parameters(), **kw)
    else:
        raise ValueError('unknown optimizer: {}'.format(args.optimizer))

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if distributed_run:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)

    start_epoch = [1]
    start_iter = [0]

    ################
    #load checkpoint
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
        load_checkpoint(local_rank, model, optimizer, start_epoch, start_iter,
                        config, args.amp, ch_fpath, world_size)

    start_epoch = start_epoch[0]
    total_iter = start_iter[0]

    # dataloader
    ##########################################################
    if local_rank == 0:
        print("load dataset")
        print("prepare train dataset")
    train_dataset = AudioDataset(config, training=True)

    # distributed sampler
    if distributed_run:
        train_sampler, shuffle = DistributedSampler(train_dataset), False
    else:
        train_sampler, shuffle = None, True

    train_loader = DataLoader(train_dataset,
                              num_workers=1,
                              shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=False,
                              drop_last=True)

    # ground truth samples

    if local_rank == 0:
        print("prepare test_dataset")
    test_dataset = AudioDataset(config, training=False)
    test_loader = DataLoader(test_dataset, batch_size=1)
    test_batch = test_dataset.sample_test_batch(
        config.training_config.n_samples_to_test)

    # Log ground truth test batch
    if local_rank == 0:
        print("save truth wave and mel")
    mel_fn = MelSpectrogramFixed(sample_rate=config.data_config.sample_rate,
                                 n_fft=config.data_config.n_fft,
                                 win_length=config.data_config.win_length,
                                 hop_length=config.data_config.hop_length,
                                 f_min=config.data_config.f_min,
                                 f_max=config.data_config.f_max,
                                 n_mels=config.data_config.n_mels,
                                 window_fn=torch.hann_window).cuda()

    audios = {
        f'audio_{index}/gt': audio
        for index, audio in enumerate(test_batch)
    }
    specs = {
        f'mel_{index}/gt': mel_fn(audio.cuda()).cpu().squeeze()
        for index, audio in enumerate(test_batch)
    }

    # training loop
    iteration = total_iter  # continue the count when resuming from a checkpoint
    model.train()
    torch.cuda.synchronize()

    if local_rank == 0:
        print("epoch start")
    for epoch in range(start_epoch, args.epochs + 1):
        tic_epoch = time.time()
        epoch_loss = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        iter_loss = 0

        # unwrap DistributedDataParallel only when it was applied
        raw_model = model.module if distributed_run else model
        train_schedule = config.training_config.training_noise_schedule
        raw_model.set_new_noise_schedule(  # 1000 steps by default
            init=torch.linspace,
            init_kwargs={
                'steps': train_schedule.n_iter,
                'start': train_schedule.betas_range[0],
                'end': train_schedule.betas_range[1]
            })

        for i, batch in enumerate(train_loader):
            tic_iter = time.time()

            old_lr = optimizer.param_groups[0]['lr']
            adjust_learning_rate(iteration, optimizer, args.learning_rate,
                                 args.warmup_steps)
            new_lr = optimizer.param_groups[0]['lr']

            model.zero_grad()
            batch = batch.cuda()
            mels = mel_fn(batch)

            # Training step (gradients were already zeroed above)
            loss = raw_model.compute_loss(mels, batch)

            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, world_size).item()
            else:
                reduced_loss = loss.item()
            # if np.isnan(reduced_loss):
            #     raise Exception("loss is NaN")

            iter_loss += reduced_loss

            if args.amp:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), args.grad_clip_thresh)
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)

            optimizer.step()

            toc_iter = time.time()
            dur_iter = toc_iter - tic_iter
            epoch_loss += iter_loss
            iter_size = len(train_loader)
            dur_epoch_est = iter_size * dur_iter
            if local_rank == 0:
                print(
                    "\nepoch {:4d} | iter {:>12d}  {:>3d}/{:3d} | {:3.2f}s/iter est {:4.2f}s/epoch | losses {:>12.6f} {:>12.6f} LR {:e}--> {:e}"
                    .format(epoch, iteration, i, iter_size, dur_iter,
                            dur_epoch_est, iter_loss, grad_norm, old_lr,
                            new_lr),
                    end='')
            iter_loss = 0
            iteration += 1

        # Finished epoch
        toc_epoch = time.time()
        dur_epoch = toc_epoch - tic_epoch
        if local_rank == 0:
            print("for {}item,   {:4.2f}s/epoch  ".format(
                iter_size, dur_epoch))

        # Test step
        if epoch % config.training_config.test_interval == 0:
            test_schedule = config.training_config.test_noise_schedule
            raw_model.set_new_noise_schedule(  # 50 steps by default
                init=torch.linspace,
                init_kwargs={
                    'steps': test_schedule.n_iter,
                    'start': test_schedule.betas_range[0],
                    'end': test_schedule.betas_range[1]
                })

        if epoch % args.epochs_per_checkpoint == 0:
            ch_path = os.path.join(args.output,
                                   "WaveGrad_ch_{:d}.pt".format(epoch))
            save_checkpoint(local_rank, model, optimizer, epoch, iteration,
                            config, args.amp, ch_path)
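adjust_learning_rate is called with a warmup_steps argument but defined elsewhere in the project; a plausible linear-warmup implementation, stated as an assumption rather than the project's actual code:

def adjust_learning_rate(iteration, optimizer, base_lr, warmup_steps):
    # ramp the learning rate linearly from 0 to base_lr over warmup_steps
    scale = min(1.0, (iteration + 1) / max(1, warmup_steps))
    for group in optimizer.param_groups:
        group['lr'] = base_lr * scale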
Code example #27
# NOTE: this snippet starts mid-file; train_audio_data is presumably built the
# same way as test_audio_data below (a reconstruction, not the original code):
train_audio_data = AudioData(os.path.join(args.preprocessed_dataset_dir,
                                          "audio_train.hdf5"),
                             sr=args.sr,
                             channels=1)
if train_audio_data.is_empty():
    train_audio_data.add(
        glob.glob(os.path.join(args.dataset_dir, "train", "*.*")))
test_audio_data = AudioData(os.path.join(args.preprocessed_dataset_dir,
                                         "audio_test.hdf5"),
                            sr=args.sr,
                            channels=1)
if test_audio_data.is_empty():
    test_audio_data.add(
        glob.glob(os.path.join(args.dataset_dir, "test", "*.*")))

transform = lambda x: audio_to_onehot(x, model.output_size, NUM_CLASSES)
train_data = AudioDataset(
    train_audio_data,
    input_size=model.output_size,
    context_front=model.input_size - model.output_size + 1,
    hop_size=40000,
    random_hops=True,
    audio_transform=transform)
print('Training dataset has ' + str(len(train_data)) + ' items')

# TRAINING
trainer = Trainer(model=model,
                  lr=args.lr,
                  eps=args.eps,
                  snapshot_folder=args.snapshot_dir,
                  logger=writer,
                  dtype=dtype,
                  ltype=ltype,
                  gradient_clipping=args.clip,
                  cuda=args.cuda)  # closing parenthesis added; the original snippet was cut off here
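audio_to_onehot is not shown either; a WaveNet-style reading is that it quantizes samples in [-1, 1] into NUM_CLASSES bins and one-hot encodes them. A sketch under that assumption (the real function may use mu-law companding instead):

import numpy as np

def audio_to_onehot(audio, length, num_classes):
    # uniform quantization of samples in [-1, 1] into num_classes bins
    audio = np.clip(audio[:length], -1.0, 1.0)
    bins = ((audio + 1.0) / 2.0 * (num_classes - 1)).round().astype(np.int64)
    onehot = np.zeros((num_classes, len(bins)), dtype=np.float32)
    onehot[bins, np.arange(len(bins))] = 1.0
    return onehot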
Code example #28
File: train.py  Project: Vladimetr/docs
def train(model_dir=None,
          params='default',
          data_dir='data',
          epochs=15,
          batch_size=500,
          retrain=None,
          train_steps=None,
          test_steps=None,
          debug_mode=False):
    """
    :param model_dir: куда сохранять результаты обучения (при debug_mode=False)
                        if None, model_dir = date_time
    :param params: dict with train and feature params.
                    if params == 'default' take params from params.py
    :param data_dir: dir with: npy/ , data.csv
    :param retrain: path/to/model.pt that we need to re-train
    :param train_steps: сколько батчей прогонять в каждой эпохи
                    if None, all batches
    :param test_steps: сколько тестовых батчей прогонять после каждой эпохи
                        if None, all Test Set
    :param debug_mode: if True, without save model, summary and logs
    """

    # get train params
    if params == 'default':
        params = parametres  # see params.py

    if not debug_mode:
        if not model_dir:
            # create model_dir
            model_dir = datetime.now().strftime("%b%d-%H:%M_run")
            if retrain:
                model_dir = model_dir.replace('run', 'retrain')
        os.makedirs(os.path.join(model_dir, 'saves'))
        print('Model will store in: {}'.format(model_dir), flush=True)
        # -model_dir/saves
        # -model_dir/train.log
        # -model_dir/test.log
        # -model_dir/test.csv
        # -model_dir/train.csv
    else:
        print('Debug mode. No saves and no logs')

    # logging
    if not debug_mode:
        logfile = os.path.join(model_dir, 'train.log')
        print('\nTrain logs to: {}\n'.format(logfile), flush=True)
    else:
        logfile = None  # logs to console

    if logging.getLogger().hasHandlers():  # a logger already exists
        change_logger(logging, logfile)
    else:
        logging.basicConfig(filename=logfile,
                            format="%(message)s",
                            level=logging.INFO)

    # log the parameters
    logging.info('Parameters:\n {}\n'.format(params))

    # split train and test Sets
    logging.info('Split train and test Sets...')
    train_csv, test_csv = split_train_test(data_dir, model_dir)

    # load train data
    train = AudioDataset(train_csv, data_dir, params)
    input_shape = train.get_input_shape()
    logging.info('Input shape: {}\n'.format(input_shape))
    sampler1 = BucketingSampler(train, batch_size)
    train = DataLoaderCuda(train,
                           collate_fn=collate_audio,
                           batch_sampler=sampler1)

    # load test data (named test_loader so it does not shadow the test() function)
    test_dataset = AudioDataset(test_csv, data_dir, params)
    sampler2 = BucketingSampler(test_dataset, batch_size)
    test_loader = DataLoaderCuda(test_dataset,
                                 collate_fn=collate_audio,
                                 batch_sampler=sampler2)

    # init model
    model = model_init(params,
                       train=True,
                       model_path=retrain,
                       use_cuda=True,
                       logger=logging)

    # select optimizer
    if params['opt'] == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=params['lr'],  # `lr` was undefined; assumed to come from params
                                     weight_decay=params['weight_decay'])
    else:
        raise Exception('No optimizer: {}'.format(params['opt']))

    # reduce learning rate every 2 epochs
    scheduler = StepLR(optimizer, step_size=params['lr_reduce_ep'], gamma=0.1)

    # summary writer
    if not debug_mode:
        log_dir = os.path.join(params['logdir'], model_dir)
        writer_train = SummaryWriter(log_dir=os.path.join(log_dir, 'train'))
        writer_test = SummaryWriter(log_dir=os.path.join(log_dir, 'test'))
        writer_train.add_graph(model, torch.rand(1, *input_shape))
        logging.info('Logs for this model stored at {}'.format(log_dir))
    else:
        writer_test = writer_train = None

    loss = torch.nn.CrossEntropyLoss()

    # n batches
    n_train = train_steps if train_steps else len(train)
    n_test = test_steps if test_steps else len(test_loader)
    k = round(n_train / n_test)

    # train and test step
    train_step, test_step = 0, 0

    # best_metric init
    best_loss = 1000

    for ep in range(1, epochs + 1):
        logging.info('\n-------------- {} epoch --------------'.format(ep))
        print('{}/{} Epoch...'.format(ep, epochs))

        model.train()
        train.shuffle(ep)
        for i, (x, target) in enumerate(train):
            optimizer.zero_grad()  # clear gradients from the previous step

            logits, probs = model(x)
            # logits - before activation (for loss)
            # probs - after activation   (for acc)

            # CrossEntropy loss
            output = loss(logits, target)  # is graph (for backward)
            loss_value = output.item()  # is float32

            # abort if training has diverged (loss_value is a Python float)
            if math_isnan(loss_value):
                message = 'Loss is nan on {} train step. Learning crash!'.format(
                    train_step)
                logging.info(message)
                print(message)
                return

            # accuracy
            acc_value = accuracy(probs, target)

            # summary
            if not debug_mode and train_step % k == 0:
                writer_train.add_scalar('Loss/steps', loss_value, train_step)
                writer_train.add_scalar('Accuracy/steps', acc_value,
                                        train_step)

            # backpropagation: computes w.grad for every model parameter w;
            # the weights are NOT updated here
            output.backward()

            clip_grad_norm_(model.parameters(),
                            params['grad_norm'])  # prevent exploding gradient

            # weight update: w_new = w_old - lr * w.grad
            optimizer.step()

            logging.info('| Epoch {}: {}/{} | Loss {:.3f} | Acc {:.2f}'.format(
                ep, i + 1, n_train, loss_value, acc_value))

            train_step += 1

            # interrupt
            if train_steps and i + 1 == train_steps:
                break

        scheduler.step()
        new_lr = float(optimizer.param_groups[0]['lr'])
        logging.info('Updated learning rate: {}'.format(new_lr))

        # saving
        # model_dir/saves/ep_1.pt
        save_name = os.path.join(model_dir, 'saves', 'ep_{}.pt'.format(ep))
        if not debug_mode:
            save_weights(model, save_name, train_step)

        logging.info('\n------------- Test ---------------')
        # test logger setup
        if not debug_mode:
            test_logfile = os.path.join(model_dir, 'test.log')
            change_logger(logging, test_logfile)
            logging.info('Test results to: {}'.format(test_logfile))

        avg_metrics = test(model=model,
                           model_path=save_name,
                           params=params,
                           data_test=test_loader,
                           data_dir=data_dir,
                           test_csv=test_csv,
                           writer=writer_test,
                           step=test_step,
                           batch_size=batch_size,
                           steps=test_steps,  # was total_steps=, which test() does not accept
                           use_tb=not debug_mode,
                           logger=logging)  # was logfile=, which test() does not accept
        message = ''
        for name, value in avg_metrics.items():
            message += '{}: {}\n'.format(name, value)
        logging.info(message)

        # check whether these are the best metrics so far
        if avg_metrics['loss'] < best_loss:
            best_loss = avg_metrics['loss']
            message = 'New best results'
            logging.info(message)
            print(message)

    if not debug_mode:
        writer_train.close()
        writer_test.close()
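A hedged usage sketch for this entry point; the paths and values below are placeholders:

# debug run: nothing is written to disk
train(model_dir='runs/exp1',
      data_dir='data',
      epochs=15,
      batch_size=500,
      debug_mode=True)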
Code example #29
File: train.py  Project: janvainer/WaveGrad
def run_training(rank, config, args):
    if args.n_gpus > 1:
        init_distributed(rank, args.n_gpus, config.dist_config)
        torch.cuda.set_device(f'cuda:{rank}')

    show_message('Initializing logger...', verbose=args.verbose, rank=rank)
    logger = Logger(config, rank=rank)

    show_message('Initializing model...', verbose=args.verbose, rank=rank)
    model = WaveGrad(config).cuda()
    show_message(f'Number of WaveGrad parameters: {model.nparams}',
                 verbose=args.verbose,
                 rank=rank)
    mel_fn = MelSpectrogramFixed(**config.data_config).cuda()

    show_message('Initializing optimizer, scheduler and losses...',
                 verbose=args.verbose,
                 rank=rank)
    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=config.training_config.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.training_config.scheduler_step_size,
        gamma=config.training_config.scheduler_gamma)
    if config.training_config.use_fp16:
        scaler = torch.cuda.amp.GradScaler()

    show_message('Initializing data loaders...',
                 verbose=args.verbose,
                 rank=rank)
    train_dataset = AudioDataset(config, training=True)
    train_sampler = DistributedSampler(
        train_dataset) if args.n_gpus > 1 else None
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.training_config.batch_size,
                                  sampler=train_sampler,
                                  drop_last=True)

    if rank == 0:
        test_dataset = AudioDataset(config, training=False)
        test_dataloader = DataLoader(test_dataset, batch_size=1)
        test_batch = test_dataset.sample_test_batch(
            config.training_config.n_samples_to_test)

    if config.training_config.continue_training:
        show_message('Loading latest checkpoint to continue training...',
                     verbose=args.verbose,
                     rank=rank)
        model, optimizer, iteration = logger.load_latest_checkpoint(
            model, optimizer)
        epoch_size = len(train_dataset) // config.training_config.batch_size
        epoch_start = iteration // epoch_size
    else:
        iteration = 0
        epoch_start = 0

    # Log ground truth test batch
    if rank == 0:
        audios = {
            f'audio_{index}/gt': audio
            for index, audio in enumerate(test_batch)
        }
        logger.log_audios(0, audios)
        specs = {
            f'mel_{index}/gt': mel_fn(audio.cuda()).cpu().squeeze()
            for index, audio in enumerate(test_batch)
        }
        logger.log_specs(0, specs)

    if args.n_gpus > 1:
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[rank])
        show_message(f'INITIALIZATION IS DONE ON RANK {rank}.')

    show_message('Start training...', verbose=args.verbose, rank=rank)
    try:
        for epoch in range(epoch_start, config.training_config.n_epoch):
            # Training step
            model.train()
            raw_model = model if args.n_gpus == 1 else model.module
            train_schedule = config.training_config.training_noise_schedule
            raw_model.set_new_noise_schedule(
                init=torch.linspace,
                init_kwargs={
                    'steps': train_schedule.n_iter,
                    'start': train_schedule.betas_range[0],
                    'end': train_schedule.betas_range[1]
                })
            for batch in (
                tqdm(train_dataloader, leave=False) \
                if args.verbose and rank == 0 else train_dataloader
            ):
                model.zero_grad()

                batch = batch.cuda()
                mels = mel_fn(batch)

                if config.training_config.use_fp16:
                    with torch.cuda.amp.autocast():
                        loss = raw_model.compute_loss(mels, batch)
                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                else:
                    loss = raw_model.compute_loss(mels, batch)
                    loss.backward()

                grad_norm = torch.nn.utils.clip_grad_norm_(
                    parameters=model.parameters(),
                    max_norm=config.training_config.grad_clip_threshold)

                if config.training_config.use_fp16:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                loss_stats = {
                    'total_loss': loss.item(),
                    'grad_norm': grad_norm.item()
                }
                logger.log_training(iteration, loss_stats, verbose=False)

                iteration += 1

            # Test step after epoch on rank==0 GPU
            if epoch % config.training_config.test_interval == 0 and rank == 0:
                model.eval()
                test_schedule = config.training_config.test_noise_schedule
                raw_model.set_new_noise_schedule(
                    init=torch.linspace,
                    init_kwargs={
                        'steps': test_schedule.n_iter,
                        'start': test_schedule.betas_range[0],
                        'end': test_schedule.betas_range[1]
                    })
                with torch.no_grad():
                    # Calculate test set loss
                    test_loss = 0
                    for i, batch in enumerate(
                        tqdm(test_dataloader) \
                        if args.verbose and rank == 0 else test_dataloader
                    ):
                        batch = batch.cuda()
                        mels = mel_fn(batch)
                        test_loss_ = raw_model.compute_loss(mels, batch)
                        test_loss += test_loss_
                    test_loss /= (i + 1)
                    loss_stats = {'total_loss': test_loss.item()}

                    # Restore random batch from test dataset
                    audios = {}
                    specs = {}
                    test_l1_loss = 0
                    test_l1_spec_loss = 0
                    average_rtf = 0

                    for index, test_sample in enumerate(test_batch):
                        test_sample = test_sample[None].cuda()
                        test_mel = mel_fn(test_sample.cuda())

                        start = datetime.now()
                        y_0_hat = raw_model.forward(
                            test_mel, store_intermediate_states=False)
                        y_0_hat_mel = mel_fn(y_0_hat)
                        end = datetime.now()
                        generation_time = (end - start).total_seconds()
                        average_rtf += compute_rtf(
                            y_0_hat, generation_time,
                            config.data_config.sample_rate)

                        test_l1_loss += torch.nn.L1Loss()(y_0_hat,
                                                          test_sample).item()
                        test_l1_spec_loss += torch.nn.L1Loss()(
                            y_0_hat_mel, test_mel).item()

                        audios[f'audio_{index}/predicted'] = y_0_hat.cpu().squeeze()
                        specs[f'mel_{index}/predicted'] = y_0_hat_mel.cpu().squeeze()

                    average_rtf /= len(test_batch)
                    show_message(f'Device: GPU. average_rtf={average_rtf}',
                                 verbose=args.verbose)

                    test_l1_loss /= len(test_batch)
                    loss_stats['l1_test_batch_loss'] = test_l1_loss
                    test_l1_spec_loss /= len(test_batch)
                    loss_stats['l1_spec_test_batch_loss'] = test_l1_spec_loss

                    logger.log_test(iteration,
                                    loss_stats,
                                    verbose=args.verbose)
                    logger.log_audios(iteration, audios)
                    logger.log_specs(iteration, specs)

                logger.save_checkpoint(iteration, raw_model, optimizer)
            if epoch % (epoch // 10 + 1) == 0:
                scheduler.step()
    except KeyboardInterrupt:
        print('KeyboardInterrupt: training has been stopped.')
        cleanup()
        return
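compute_rtf is imported from elsewhere; the real-time factor is conventionally generation time divided by the duration of the generated audio, so a sketch under that assumption:

def compute_rtf(audio, generation_time, sample_rate):
    # RTF < 1 means generation is faster than real time
    duration_sec = audio.shape[-1] / sample_rate
    return generation_time / duration_sec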
Code example #30
File: test.py  Project: Vladimetr/docs
def test(model=None,
         model_path=None,
         params='default',
         data_test=None,
         data_dir='data',
         test_csv='data/data_test.csv',
         writer=None,
         step=0,
         batch_size=50,
         steps=None,
         use_tb=False,
         logger=False):
    """
    run model on the Test Dataset
    :param model: torch model. If None it will be init from models.py
    :param params: dict with all required params.
                    if 'default' it will be load from params.py
    :param data_test: torch DataLoader. if None it will be load from test_csv
    :param model_path: path/to/model.pt. if None it will be load from params['restore']
    :param data_dir: path/to/npy/
    :param test_csv: path/to/data_test.csv
    :param step: с какого шага ввести SummaryWriter
    :param steps: how many test batches to calc
    :param use_tb: save summary graph of not
    :param logger: True: print to model_path.log
                   False: print to console
                   logger object
    """
    if params == 'default':
        params = parametres

    if model_path is None:
        model_path = params['restore']
        assert model_path is not None, 'if default params are used, a model .pt must be defined'

    # base name for log files and summary dirs (assumed: derived from model_path)
    model_name = os.path.splitext(model_path)[0]

    # logging
    if isinstance(logger, bool):
        if logger:
            # e.g. ../saves/ep_20_test.log
            logfile = model_name + '_test.log'
            print('\nlogs to: {}\n'.format(logfile))
        else:
            logfile = None
        logging.basicConfig(filename=logfile,
                            format="%(message)s",
                            level=logging.INFO)
        logger = logging.getLogger()
    elif hasattr(logger, 'info'):
        # an existing logger (or the logging module) was passed in
        pass
    else:
        raise Exception('logger must be bool or a logger object')

    # log info about data
    logger.info('Info about data: {} \n -data_dir: {}\n'.format(
        test_csv, data_dir))

    # log info about features (assumed: feature settings live in params)
    logger.info('Info about features:\n -{}'.format(params))

    # test data
    if data_test:
        test = data_test
    else:
        test = AudioDataset(test_csv, data_dir, params)  # params added to match train.py
        sampler = BucketingSampler(test, batch_size=batch_size)
        test = DataLoaderCuda(test,
                              collate_fn=collate_audio,
                              batch_sampler=sampler)

    # summary writer
    if use_tb:
        if writer is None:
            summary_dir = os.path.join('dev/test_logs', model_name)
            logging.info('Summary writer: {}\n'.format(summary_dir))
            writer = SummaryWriter(log_dir=summary_dir, purge_step=step)
            close_tb = True
        else:
            # writer is already exists
            close_tb = False

    # init
    if model is None:
        model = model_init(params,
                           model_path=model_path,
                           train=True,
                           use_cuda=True,
                           logger=logging)

    loss = torch.nn.CrossEntropyLoss()

    n_test = len(test)
    test.shuffle(43)

    sum_loss = 0
    metrics = Metrics(acc=True, another_metrics=False)
    for x, target in test:
        with torch.no_grad():
            logits, probs = model(x)

        # loss
        loss_value = loss(logits, target).item()
        sum_loss += loss_value

        # accuracy (assumed: the same accuracy() helper as in train.py)
        acc_value = accuracy(probs, target)
        metrics(probs, target)

        # summary (`plot` was undefined; use_tb controls summary writing)
        if use_tb:
            writer.add_scalar('Loss/steps', loss_value, step)
            writer.add_scalar('Accuracy/steps', acc_value, step)

        logger.info('{}/{}: Test loss {:.3f} | Test acc {:.2f}'.format(
            step + 1, n_test, loss_value, acc_value))

        # interrupt (`total_steps` was undefined; the parameter is named steps)
        if steps and step + 1 == steps:
            break

        step += 1

    # get average metrics
    avg_loss = sum_loss / n_test
    avg_metrics = metrics.get_avg()
    avg_metrics['loss'] = avg_loss
    message = ''
    for name, value in avg_metrics.items():
        message += '{}: {}\n'.format(name, value)

    # Summary
    logger.info('{:-^10}'.format('Average Metrics'))
    logger.info(message)

    if use_tb and close_tb:
        writer.close()

    return avg_metrics
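A hedged usage sketch (paths are placeholders):

avg = test(model_path='saves/ep_20.pt',
           test_csv='data/data_test.csv',
           use_tb=False,
           logger=False)  # False: log to console
print('test loss:', avg['loss'])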