Code example #1
File: inference.py  Project: dodoproptit99/WaveGrad
def get_mel(config, model):
    mel_fn = MelSpectrogramFixed(sample_rate=config.data_config.sample_rate,
                                 n_fft=config.data_config.n_fft,
                                 win_length=config.data_config.win_length,
                                 hop_length=config.data_config.hop_length,
                                 f_min=config.data_config.f_min,
                                 f_max=config.data_config.f_max,
                                 n_mels=config.data_config.n_mels,
                                 window_fn=torch.hann_window).cuda()

    dataset = AudioDataset(config, training=False)
    test_batch = dataset.sample_test_batch(1)

    # n_iter = 25
    # path_to_store_schedule = f'schedules/default/{n_iter}iters.pt'

    # iters_best_schedule, stats = iters_schedule_grid_search(
    #     model, config,
    #     n_iter=n_iter,
    #     betas_range=(1e-06, 0.01),
    #     test_batch_size=1, step=1,
    #     path_to_store_schedule=path_to_store_schedule,
    #     save_stats_for_grid=True,
    #     verbose=True, n_jobs=4
    # )

    for i, test in enumerate(tqdm(test_batch)):
        mel = mel_fn(test[None].cuda())
        # Time the vocoder forward pass
        start = datetime.now()
        outputs = model.forward(mel, store_intermediate_states=False)
        end = datetime.now()
        inference_time = (end - start).total_seconds()
        print("Time infer: ", inference_time)
        outputs = outputs.cpu().squeeze()
        save_path = f'{i}.wav'
        torchaudio.save(save_path,
                        outputs,
                        sample_rate=config.data_config.sample_rate)
        # Real-time factor: generation time relative to the generated audio duration
        rtf = compute_rtf(outputs,
                          inference_time,
                          sample_rate=config.data_config.sample_rate)
        show_message(f'Done. RTF estimate: {rtf}')
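The real-time factor (RTF) reported above is the wall-clock generation time divided by the duration of the generated audio. The compute_rtf helper itself is not part of this excerpt; a minimal sketch consistent with how it is called here (the actual implementation in the WaveGrad repositories may differ) would be:

def compute_rtf(sample, generation_time, sample_rate=22050):
    # RTF = wall-clock generation time / duration of the generated audio in seconds
    audio_duration_sec = sample.shape[-1] / sample_rate
    return generation_time / audio_duration_sec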
Code example #2
File: train_org.py  Project: yhgon/WaveGrad
def run(config, args):
    show_message('Initializing logger...', verbose=args.verbose)
    logger = Logger(config)

    show_message('Initializing model...', verbose=args.verbose)
    model = WaveGrad(config).cuda()
    show_message(f'Number of parameters: {model.nparams}',
                 verbose=args.verbose)
    mel_fn = MelSpectrogramFixed(sample_rate=config.data_config.sample_rate,
                                 n_fft=config.data_config.n_fft,
                                 win_length=config.data_config.win_length,
                                 hop_length=config.data_config.hop_length,
                                 f_min=config.data_config.f_min,
                                 f_max=config.data_config.f_max,
                                 n_mels=config.data_config.n_mels,
                                 window_fn=torch.hann_window).cuda()

    show_message('Initializing optimizer, scheduler and losses...',
                 verbose=args.verbose)
    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=config.training_config.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.training_config.scheduler_step_size,
        gamma=config.training_config.scheduler_gamma)

    show_message('Initializing data loaders...', verbose=args.verbose)
    train_dataset = AudioDataset(config, training=True)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.training_config.batch_size,
                                  drop_last=True)
    test_dataset = AudioDataset(config, training=False)
    test_dataloader = DataLoader(test_dataset, batch_size=1)
    test_batch = test_dataset.sample_test_batch(
        config.training_config.n_samples_to_test)

    if config.training_config.continue_training:
        show_message('Loading latest checkpoint to continue training...',
                     verbose=args.verbose)
        model, optimizer, iteration = logger.load_latest_checkpoint(
            model, optimizer)
        epoch_size = len(train_dataset) // config.training_config.batch_size
        epoch_start = iteration // epoch_size
    else:
        iteration = 0
        epoch_start = 0

    # Log ground truth test batch
    audios = {
        f'audio_{index}/gt': audio
        for index, audio in enumerate(test_batch)
    }
    logger.log_audios(0, audios)
    specs = {
        f'mel_{index}/gt': mel_fn(audio.cuda()).cpu().squeeze()
        for index, audio in enumerate(test_batch)
    }
    logger.log_specs(0, specs)

    show_message('Start training...', verbose=args.verbose)
    try:
        for epoch in range(epoch_start, config.training_config.n_epoch):
            # Training step
            model.set_new_noise_schedule(
                init=torch.linspace,
                init_kwargs={
                    'steps': config.training_config.training_noise_schedule.n_iter,
                    'start': config.training_config.training_noise_schedule.betas_range[0],
                    'end': config.training_config.training_noise_schedule.betas_range[1]
                })
            for i, batch in enumerate(train_dataloader):
                tic_iter = time.time()
                batch = batch.cuda()
                mels = mel_fn(batch)

                # Training step
                model.zero_grad()
                loss = model.compute_loss(mels, batch)
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    parameters=model.parameters(),
                    max_norm=config.training_config.grad_clip_threshold)
                optimizer.step()
                toc_iter = time.time()
                dur_iter = toc_iter - tic_iter
                dur_iter = np.round(dur_iter, 4)
                loss_stats = {
                    'total_loss': loss.item(),
                    'grad_norm': grad_norm.item()
                }
                iter_size = len(train_dataloader)
                logger.log_training(epoch,
                                    iteration,
                                    i,
                                    iter_size,
                                    loss_stats,
                                    dur_iter,
                                    int(dur_iter * iter_size),
                                    verbose=args.verbose)

                iteration += 1

            # Test step
            if epoch % config.training_config.test_interval == 0:
                model.set_new_noise_schedule(
                    init=torch.linspace,
                    init_kwargs={
                        'steps': config.training_config.test_noise_schedule.n_iter,
                        'start': config.training_config.test_noise_schedule.betas_range[0],
                        'end': config.training_config.test_noise_schedule.betas_range[1]
                    })
                with torch.no_grad():
                    # Calculate test set loss
                    test_loss = 0
                    for i, batch in enumerate(test_dataloader):
                        batch = batch.cuda()
                        mels = mel_fn(batch)
                        test_loss_ = model.compute_loss(mels, batch)
                        test_loss += test_loss_
                    test_loss /= (i + 1)
                    loss_stats = {'total_loss': test_loss.item()}

                    # Restore random batch from test dataset
                    audios = {}
                    specs = {}
                    test_l1_loss = 0
                    test_l1_spec_loss = 0
                    average_rtf = 0

                    for index, test_sample in enumerate(test_batch):
                        test_sample = test_sample[None].cuda()
                        test_mel = mel_fn(test_sample.cuda())

                        start = datetime.now()
                        y_0_hat = model.forward(
                            test_mel, store_intermediate_states=False)
                        y_0_hat_mel = mel_fn(y_0_hat)
                        end = datetime.now()
                        generation_time = (end - start).total_seconds()
                        average_rtf += compute_rtf(
                            y_0_hat, generation_time,
                            config.data_config.sample_rate)

                        test_l1_loss += torch.nn.L1Loss()(y_0_hat,
                                                          test_sample).item()
                        test_l1_spec_loss += torch.nn.L1Loss()(
                            y_0_hat_mel, test_mel).item()

                        audios[f'audio_{index}/predicted'] = y_0_hat.cpu(
                        ).squeeze()
                        specs[f'mel_{index}/predicted'] = y_0_hat_mel.cpu(
                        ).squeeze()

                    average_rtf /= len(test_batch)
                    show_message(f'Device: GPU. average_rtf={average_rtf}',
                                 verbose=args.verbose)

                    test_l1_loss /= len(test_batch)
                    loss_stats['l1_test_batch_loss'] = test_l1_loss
                    test_l1_spec_loss /= len(test_batch)
                    loss_stats['l1_spec_test_batch_loss'] = test_l1_spec_loss

                    logger.log_test(epoch, loss_stats, verbose=args.verbose)
                    #logger.log_audios(epoch, audios)
                    #logger.log_specs(epoch, specs)

                logger.save_checkpoint(iteration, model, optimizer)
            if epoch % (epoch // 10 + 1) == 0:
                scheduler.step()
    except KeyboardInterrupt:
        print('KeyboardInterrupt: training has been stopped.')
        return
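Both of these training scripts read their hyperparameters from a nested config object with data_config and training_config sections. The actual config class is not shown in these excerpts; the sketch below only lists the fields accessed above, built with types.SimpleNamespace, and the values are illustrative placeholders (except the 1000/50-step noise schedules and the (1e-6, 0.01) betas range, which appear in the code and comments above).

from types import SimpleNamespace

# Illustrative only: field names are taken from the code above, values are placeholders.
config = SimpleNamespace(
    data_config=SimpleNamespace(
        sample_rate=22050, n_fft=1024, win_length=1024, hop_length=256,
        f_min=0.0, f_max=8000.0, n_mels=80),
    training_config=SimpleNamespace(
        lr=1e-4, scheduler_step_size=1, scheduler_gamma=0.9,
        batch_size=48, n_epoch=10000, n_samples_to_test=4,
        test_interval=1, grad_clip_threshold=1.0, continue_training=False,
        training_noise_schedule=SimpleNamespace(n_iter=1000, betas_range=(1e-6, 0.01)),
        test_noise_schedule=SimpleNamespace(n_iter=50, betas_range=(1e-6, 0.01))))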
Code example #3
File: train_new.py  Project: yhgon/WaveGrad
def run(config, args):

    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    else:
        local_rank = args.rank
        world_size = args.world_size
    distributed_run = world_size > 1

    torch.manual_seed(args.seed + local_rank)
    np.random.seed(args.seed + local_rank)

    #    if local_rank == 0:
    #        if not os.path.exists(args.output):
    #            os.makedirs(args.output)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False

    if distributed_run:
        init_distributed(args, world_size, local_rank)

    device = torch.device('cuda' if args.cuda else 'cpu')

    if local_rank == 0:
        print("start training")
        print("args", args)
        print("config", config)

    #############################################
    # model
    if local_rank == 0:
        print("load model")
    model = WaveGrad(config).cuda()

    # optimizer amp config
    if local_rank == 0:
        print("configure optimizer and amp")
    kw = dict(lr=args.learning_rate,
              betas=(0.9, 0.98),
              eps=1e-9,
              weight_decay=args.weight_decay)

    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    elif args.optimizer == 'pytorch':
        optimizer = torch.optim.Adam(model.parameters(), **kw)
    else:
        raise ValueError(f'Unknown optimizer: {args.optimizer}')

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if distributed_run:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)

    start_epoch = [1]
    start_iter = [0]

    ################
    #load checkpoint
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
        load_checkpoint(local_rank, model, optimizer, start_epoch, start_iter,
                        config, args.amp, ch_fpath, world_size)

    start_epoch = start_epoch[0]
    total_iter = start_iter[0]

    # dataloader
    ##########################################################
    if local_rank == 0:
        print("load dataset")

    if local_rank == 0:
        print("prepare train dataset")
    train_dataset = AudioDataset(config, training=True)

    # distributed sampler
    if distributed_run:
        train_sampler, shuffle = DistributedSampler(train_dataset), False
    else:
        train_sampler, shuffle = None, True

    train_loader = DataLoader(train_dataset,
                              num_workers=1,
                              shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=False,
                              drop_last=True)

    # ground truth samples

    if local_rank == 0:
        print("prepare test_dataset")
    test_dataset = AudioDataset(config, training=False)
    test_loader = DataLoader(test_dataset, batch_size=1)
    test_batch = test_dataset.sample_test_batch(
        config.training_config.n_samples_to_test)

    # Log ground truth test batch
    if local_rank == 0:
        print("save truth wave and mel")
    mel_fn = MelSpectrogramFixed(sample_rate=config.data_config.sample_rate,
                                 n_fft=config.data_config.n_fft,
                                 win_length=config.data_config.win_length,
                                 hop_length=config.data_config.hop_length,
                                 f_min=config.data_config.f_min,
                                 f_max=config.data_config.f_max,
                                 n_mels=config.data_config.n_mels,
                                 window_fn=torch.hann_window).cuda()

    audios = {
        f'audio_{index}/gt': audio
        for index, audio in enumerate(test_batch)
    }
    specs = {
        f'mel_{index}/gt': mel_fn(audio.cuda()).cpu().squeeze()
        for index, audio in enumerate(test_batch)
    }

    ####### loop start
    #epoch
    iteration = 0
    model.train()
    val_loss = 0.0
    torch.cuda.synchronize()

    if local_rank == 0:
        print("epoch start")
    for epoch in range(start_epoch, args.epochs + 1):
        tic_epoch = time.time()
        epoch_loss = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        accumulated_steps = 0
        iter_loss = 0
        epoch_iter = 0
        #iteration = 0
        num_iters = len(train_loader) // args.gradient_accumulation_steps

        # The model is wrapped in DistributedDataParallel only when distributed_run is
        # True, so unwrap .module only in that case.
        (model.module if distributed_run else model).set_new_noise_schedule(  # 1000 steps by default
            init=torch.linspace,
            init_kwargs={
                'steps': config.training_config.training_noise_schedule.n_iter,
                'start': config.training_config.training_noise_schedule.betas_range[0],
                'end': config.training_config.training_noise_schedule.betas_range[1]
            }
        )

        for i, batch in enumerate(train_loader):
            tic_iter = time.time()

            old_lr = optimizer.param_groups[0]['lr']
            adjust_learning_rate(iteration, optimizer, args.learning_rate,
                                 args.warmup_steps)
            new_lr = optimizer.param_groups[0]['lr']

            batch = batch.cuda()
            mels = mel_fn(batch)

            # Training step
            model.zero_grad()
            loss = (model.module if distributed_run else model).compute_loss(mels, batch)

            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, world_size).item()
            else:
                reduced_loss = loss.item()
            # if np.isnan(reduced_loss):
            #     raise Exception("loss is NaN")

            iter_loss += reduced_loss

            if args.amp:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), args.grad_clip_thresh)
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)

            optimizer.step()

            toc_iter = time.time()
            dur_iter = toc_iter - tic_iter
            epoch_loss += iter_loss
            iter_size = len(train_loader)
            dur_epoch_est = iter_size * dur_iter
            if local_rank == 0:
                print(
                    "\nepoch {:4d} | iter {:>12d}  {:>3d}/{:3d} | {:3.2f}s/iter est {:4.2f}s/epoch | losses {:>12.6f} {:>12.6f} LR {:e}--> {:e}"
                    .format(epoch, iteration, i, iter_size, dur_iter,
                            dur_epoch_est, iter_loss, grad_norm, old_lr,
                            new_lr),
                    end='')
            iter_loss = 0
            iteration += 1

        # Finished epoch
        toc_epoch = time.time()
        dur_epoch = toc_epoch - tic_epoch
        if local_rank == 0:
            print("for {}item,   {:4.2f}s/epoch  ".format(
                iter_size, dur_epoch))

        # Test step
        if epoch % config.training_config.test_interval == 0:
            (model.module if distributed_run else model).set_new_noise_schedule(  # 50 steps by default
                init=torch.linspace,
                init_kwargs={
                    'steps': config.training_config.test_noise_schedule.n_iter,
                    'start': config.training_config.test_noise_schedule.betas_range[0],
                    'end': config.training_config.test_noise_schedule.betas_range[1]
                })

        if (epoch % args.epochs_per_checkpoint == 0):
            ch_path = os.path.join(args.output,
                                   "WaveGrad_ch_{:d}.pt".format(epoch))
            save_checkpoint(local_rank, model, optimizer, epoch, iteration,
                            config, args.amp, ch_path)
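Code example #3 calls adjust_learning_rate(iteration, optimizer, args.learning_rate, args.warmup_steps), which suggests a step-based warmup, but the helper is not included in the excerpt. A hypothetical linear-warmup implementation matching that call signature:

def adjust_learning_rate(iteration, optimizer, base_lr, warmup_steps):
    # Hypothetical sketch: linearly ramp the LR from 0 to base_lr over the first
    # warmup_steps iterations, then hold it constant. The real helper may differ.
    if warmup_steps > 0 and iteration < warmup_steps:
        lr = base_lr * (iteration + 1) / warmup_steps
    else:
        lr = base_lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr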
Code example #4
File: train.py  Project: janvainer/WaveGrad
def run_training(rank, config, args):
    if args.n_gpus > 1:
        init_distributed(rank, args.n_gpus, config.dist_config)
        torch.cuda.set_device(f'cuda:{rank}')

    show_message('Initializing logger...', verbose=args.verbose, rank=rank)
    logger = Logger(config, rank=rank)

    show_message('Initializing model...', verbose=args.verbose, rank=rank)
    model = WaveGrad(config).cuda()
    show_message(f'Number of WaveGrad parameters: {model.nparams}',
                 verbose=args.verbose,
                 rank=rank)
    mel_fn = MelSpectrogramFixed(**config.data_config).cuda()

    show_message('Initializing optimizer, scheduler and losses...',
                 verbose=args.verbose,
                 rank=rank)
    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=config.training_config.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.training_config.scheduler_step_size,
        gamma=config.training_config.scheduler_gamma)
    if config.training_config.use_fp16:
        scaler = torch.cuda.amp.GradScaler()

    show_message('Initializing data loaders...',
                 verbose=args.verbose,
                 rank=rank)
    train_dataset = AudioDataset(config, training=True)
    train_sampler = DistributedSampler(
        train_dataset) if args.n_gpus > 1 else None
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.training_config.batch_size,
                                  sampler=train_sampler,
                                  drop_last=True)

    if rank == 0:
        test_dataset = AudioDataset(config, training=False)
        test_dataloader = DataLoader(test_dataset, batch_size=1)
        test_batch = test_dataset.sample_test_batch(
            config.training_config.n_samples_to_test)

    if config.training_config.continue_training:
        show_message('Loading latest checkpoint to continue training...',
                     verbose=args.verbose,
                     rank=rank)
        model, optimizer, iteration = logger.load_latest_checkpoint(
            model, optimizer)
        epoch_size = len(train_dataset) // config.training_config.batch_size
        epoch_start = iteration // epoch_size
    else:
        iteration = 0
        epoch_start = 0

    # Log ground truth test batch
    if rank == 0:
        audios = {
            f'audio_{index}/gt': audio
            for index, audio in enumerate(test_batch)
        }
        logger.log_audios(0, audios)
        specs = {
            f'mel_{index}/gt': mel_fn(audio.cuda()).cpu().squeeze()
            for index, audio in enumerate(test_batch)
        }
        logger.log_specs(0, specs)

    if args.n_gpus > 1:
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[rank])
        show_message(f'INITIALIZATION IS DONE ON RANK {rank}.')

    show_message('Start training...', verbose=args.verbose, rank=rank)
    try:
        for epoch in range(epoch_start, config.training_config.n_epoch):
            # Training step
            model.train()
            (model if args.n_gpus == 1 else model.module).set_new_noise_schedule(
                init=torch.linspace,
                init_kwargs={
                    'steps': config.training_config.training_noise_schedule.n_iter,
                    'start': config.training_config.training_noise_schedule.betas_range[0],
                    'end': config.training_config.training_noise_schedule.betas_range[1]
                })
            for batch in (tqdm(train_dataloader, leave=False)
                          if args.verbose and rank == 0 else train_dataloader):
                model.zero_grad()

                batch = batch.cuda()
                mels = mel_fn(batch)

                if config.training_config.use_fp16:
                    with torch.cuda.amp.autocast():
                        loss = (model if args.n_gpus == 1 else
                                model.module).compute_loss(mels, batch)
                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                else:
                    loss = (model if args.n_gpus == 1 else
                            model.module).compute_loss(mels, batch)
                    loss.backward()

                grad_norm = torch.nn.utils.clip_grad_norm_(
                    parameters=model.parameters(),
                    max_norm=config.training_config.grad_clip_threshold)

                if config.training_config.use_fp16:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                loss_stats = {
                    'total_loss': loss.item(),
                    'grad_norm': grad_norm.item()
                }
                logger.log_training(iteration, loss_stats, verbose=False)

                iteration += 1

            # Test step after epoch on rank==0 GPU
            if epoch % config.training_config.test_interval == 0 and rank == 0:
                model.eval()
                (model if args.n_gpus == 1 else model.module).set_new_noise_schedule(
                    init=torch.linspace,
                    init_kwargs={
                        'steps': config.training_config.test_noise_schedule.n_iter,
                        'start': config.training_config.test_noise_schedule.betas_range[0],
                        'end': config.training_config.test_noise_schedule.betas_range[1]
                    })
                with torch.no_grad():
                    # Calculate test set loss
                    test_loss = 0
                    for i, batch in enumerate(
                            tqdm(test_dataloader)
                            if args.verbose and rank == 0 else test_dataloader):
                        batch = batch.cuda()
                        mels = mel_fn(batch)
                        test_loss_ = (model if args.n_gpus == 1 else
                                      model.module).compute_loss(mels, batch)
                        test_loss += test_loss_
                    test_loss /= (i + 1)
                    loss_stats = {'total_loss': test_loss.item()}

                    # Restore random batch from test dataset
                    audios = {}
                    specs = {}
                    test_l1_loss = 0
                    test_l1_spec_loss = 0
                    average_rtf = 0

                    for index, test_sample in enumerate(test_batch):
                        test_sample = test_sample[None].cuda()
                        test_mel = mel_fn(test_sample.cuda())

                        start = datetime.now()
                        y_0_hat = (model if args.n_gpus == 1 else
                                   model.module).forward(
                                       test_mel,
                                       store_intermediate_states=False)
                        y_0_hat_mel = mel_fn(y_0_hat)
                        end = datetime.now()
                        generation_time = (end - start).total_seconds()
                        average_rtf += compute_rtf(
                            y_0_hat, generation_time,
                            config.data_config.sample_rate)

                        test_l1_loss += torch.nn.L1Loss()(y_0_hat,
                                                          test_sample).item()
                        test_l1_spec_loss += torch.nn.L1Loss()(
                            y_0_hat_mel, test_mel).item()

                        audios[f'audio_{index}/predicted'] = y_0_hat.cpu(
                        ).squeeze()
                        specs[f'mel_{index}/predicted'] = y_0_hat_mel.cpu(
                        ).squeeze()

                    average_rtf /= len(test_batch)
                    show_message(f'Device: GPU. average_rtf={average_rtf}',
                                 verbose=args.verbose)

                    test_l1_loss /= len(test_batch)
                    loss_stats['l1_test_batch_loss'] = test_l1_loss
                    test_l1_spec_loss /= len(test_batch)
                    loss_stats['l1_spec_test_batch_loss'] = test_l1_spec_loss

                    logger.log_test(iteration,
                                    loss_stats,
                                    verbose=args.verbose)
                    logger.log_audios(iteration, audios)
                    logger.log_specs(iteration, specs)

                logger.save_checkpoint(
                    iteration, model if args.n_gpus == 1 else model.module,
                    optimizer)
            if epoch % (epoch // 10 + 1) == 0:
                scheduler.step()
    except KeyboardInterrupt:
        print('KeyboardInterrupt: training has been stopped.')
        cleanup()
        return
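run_training(rank, config, args) takes the process rank as its first argument, which matches the calling convention of torch.multiprocessing.spawn. A minimal launch sketch, assuming config and args are already constructed by the repository's entry point (not shown in this excerpt):

import torch.multiprocessing as mp

def main(config, args):
    if args.n_gpus > 1:
        # spawn() starts args.n_gpus worker processes and passes each one its rank
        # as the first positional argument of run_training
        mp.spawn(run_training, args=(config, args), nprocs=args.n_gpus)
    else:
        run_training(0, config, args)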