Ejemplo n.º 1
0
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, early_stopping

from data_loading.data_module import DataModule
from nnunet.nn_unet import NNUnet
from utils.args import get_main_args
from utils.gpu_affinity import set_affinity
from utils.logger import LoggingCallback
from utils.utils import make_empty_dir, set_cuda_devices, verify_ckpt_path

# Script entry point: parse the CLI arguments, optionally enable DLProf NVTX
# profiling and CPU/GPU affinity, cap the OpenMP thread count, and raise a
# CUDA device limit through the CUDA runtime library.
# NOTE(review): `os`, `ctypes` and `nvidia_dlprof_pytorch_nvtx` are used below
# but are not among the imports visible in this snippet - confirm they are
# imported elsewhere in the real file.
if __name__ == "__main__":
    args = get_main_args()

    # Profiling is opt-in; the NVTX hooks are only initialised on request.
    if args.profile:
        nvidia_dlprof_pytorch_nvtx.init()
        print("Profiling enabled")

    # Apply the requested affinity mode for this process; the process index
    # comes from the LOCAL_RANK environment variable (defaults to 0).
    if args.affinity != "disabled":
        affinity = set_affinity(int(os.getenv("LOCAL_RANK", "0")),
                                args.gpus,
                                mode=args.affinity)

    # Limit number of CPU threads
    os.environ["OMP_NUM_THREADS"] = "1"
    # Set device limit on the current device cudaLimitMaxL2FetchGranularity = 0x05
    _libcudart = ctypes.CDLL("libcudart.so")
    # One-element int buffer that receives the value read back from the device.
    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
    # Sanity check that the limit was actually applied.
    # NOTE(review): `assert` is stripped under `python -O`, so this check
    # disappears in optimised runs.
    assert pValue.contents.value == 128
Ejemplo n.º 2
0
def main(argv):
    """Build the model and run the workflow selected by ``FLAGS.mode``.

    Supported modes are ``'test'`` (single evaluation pass),
    ``'inference_benchmark'`` (latency/throughput per batch size),
    ``'train'`` and ``'prof-train'`` (training wrapped in NVTX emission).

    Args:
        argv: Positional command-line arguments; not used by this function.
    """
    validate_flags()
    torch.manual_seed(FLAGS.seed)

    utils.init_logging(log_path=FLAGS.log_path)
    dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER')

    train_loader, test_loader = get_data_loaders(FLAGS)

    # Under AMP the learning rate is divided by the loss scale.
    if FLAGS.amp:
        scaled_lr = FLAGS.lr / FLAGS.loss_scale
    else:
        scaled_lr = FLAGS.lr

    model = create_model()
    optimizer = torch.optim.SGD(model.parameters(), lr=scaled_lr)

    if FLAGS.mode == 'prof-train':
        nvidia_dlprof_pytorch_nvtx.init(enable_function_stack=True)

    if FLAGS.amp:
        if FLAGS.mode in ('train', 'prof-train'):
            # Mixed precision for training: only the top model and the bottom
            # MLP are handed to amp.
            amp_modules = [model.top_model, model.bottom_model.mlp]
            amp_modules, optimizer = amp.initialize(
                amp_modules,
                optimizer,
                opt_level="O2",
                loss_scale=1)
            model.top_model, model.bottom_model.mlp = amp_modules
        else:
            # Non-training modes run the whole model in half precision.
            model = model.half()

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")

    if FLAGS.mode == 'test':
        loss, auc, test_step_time = evaluate(model, loss_fn, test_loader)

        throughput = FLAGS.batch_size / test_step_time
        dllogger.log(
            data={
                'auc': auc,
                'avg_inference_latency': test_step_time,
                'average_test_throughput': throughput,
            },
            step=tuple())

        print(f"Finished testing. Test Loss {loss:.4f}, auc {auc:.4f}")
        return

    if FLAGS.mode == 'inference_benchmark':
        results = {}

        if FLAGS.amp:
            # can use pure FP16 for inference
            model = model.half()

        for requested_size in FLAGS.inference_benchmark_batch_sizes:
            bench_batch_size = int(requested_size)
            # The loaders read the test batch size from FLAGS, so rebuild them
            # for every benchmarked size.
            FLAGS.test_batch_size = bench_batch_size
            _, bench_loader = get_data_loaders(FLAGS)

            latencies = inference_benchmark(
                model=model,
                data_loader=bench_loader,
                num_batches=FLAGS.inference_benchmark_steps)

            print(f"All inference latencies: {latencies}")

            mean_latency = np.mean(latencies)
            mean_throughput = bench_batch_size / mean_latency
            results[f'mean_inference_latency_batch_{bench_batch_size}'] = \
                mean_latency
            results[f'mean_inference_throughput_batch_{bench_batch_size}'] = \
                mean_throughput

        dllogger.log(data=results, step=tuple())

        print("Finished inference benchmark.")
        return

    if FLAGS.mode == 'train':
        train(model, loss_fn, optimizer, train_loader, test_loader, scaled_lr)
    elif FLAGS.mode == 'prof-train':
        # Same training call, with autograd ops annotated as NVTX ranges.
        with torch.autograd.profiler.emit_nvtx():
            train(model, loss_fn, optimizer, train_loader, test_loader,
                  scaled_lr)
Ejemplo n.º 3
0
def main():
    """Train a Jasper model with CTC loss, periodic evaluation and checkpoints.

    All settings come from ``parse_args()`` and the model config file. The
    function builds the data loaders (DALI-based or ``AudioDataset``-based),
    the model, optimizer and gradient scaler, optionally restores a
    checkpoint, and runs the training loop with gradient accumulation.

    Raises:
        AssertionError: if CUDA is unavailable, or the frequency / batch-size
            arguments are inconsistent, or DALI is requested together with
            offline speed perturbation.
        ValueError: if ``args.optimizer`` is neither "novograd" nor "adamw".
    """
    args = parse_args()

    assert (torch.cuda.is_available())
    assert args.prediction_frequency % args.log_frequency == 0

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    # set up distributed training
    multi_gpu = int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        world_size = dist.get_world_size()
        print_once(f'Distributed training with {world_size} GPUs\n')
    else:
        world_size = 1

    # Offset the seed by the local rank so each process gets its own stream.
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    random.seed(args.seed + args.local_rank)

    init_log(args)

    cfg = config.load(args.model_config)
    config.apply_config_overrides(cfg, args)

    symbols = helpers.add_ctc_blank(cfg['labels'])

    # args.batch_size is the per-step batch; each accumulated micro-batch is
    # that size divided by the number of accumulation steps.
    assert args.grad_accumulation_steps >= 1
    assert args.batch_size % args.grad_accumulation_steps == 0
    batch_size = args.batch_size // args.grad_accumulation_steps

    print_once('Setting up datasets...')
    train_dataset_kw, train_features_kw = config.input(cfg, 'train')
    val_dataset_kw, val_features_kw = config.input(cfg, 'val')

    use_dali = args.dali_device in ('cpu', 'gpu')
    if use_dali:
        assert train_dataset_kw['ignore_offline_speed_perturbation'], \
            "DALI doesn't support offline speed perturbation"

        # pad_to_max_duration is not supported by DALI - have simple padders
        if train_features_kw['pad_to_max_duration']:
            train_feat_proc = BaseFeatures(
                pad_align=train_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=train_features_kw['max_duration'],
                sample_rate=train_features_kw['sample_rate'],
                window_size=train_features_kw['window_size'],
                window_stride=train_features_kw['window_stride'])
            train_features_kw['pad_to_max_duration'] = False
        else:
            train_feat_proc = None

        if val_features_kw['pad_to_max_duration']:
            val_feat_proc = BaseFeatures(
                pad_align=val_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=val_features_kw['max_duration'],
                sample_rate=val_features_kw['sample_rate'],
                window_size=val_features_kw['window_size'],
                window_stride=val_features_kw['window_stride'])
            val_features_kw['pad_to_max_duration'] = False
        else:
            val_feat_proc = None

        train_loader = DaliDataLoader(
            gpu_id=args.local_rank,
            dataset_path=args.dataset_dir,
            config_data=train_dataset_kw,
            config_features=train_features_kw,
            json_names=args.train_manifests,
            batch_size=batch_size,
            grad_accumulation_steps=args.grad_accumulation_steps,
            pipeline_type="train",
            device_type=args.dali_device,
            symbols=symbols)

        val_loader = DaliDataLoader(gpu_id=args.local_rank,
                                    dataset_path=args.dataset_dir,
                                    config_data=val_dataset_kw,
                                    config_features=val_features_kw,
                                    json_names=args.val_manifests,
                                    batch_size=batch_size,
                                    pipeline_type="val",
                                    device_type=args.dali_device,
                                    symbols=symbols)
    else:
        train_dataset_kw, train_features_kw = config.input(cfg, 'train')
        train_dataset = AudioDataset(args.dataset_dir, args.train_manifests,
                                     symbols, **train_dataset_kw)
        train_loader = get_data_loader(train_dataset,
                                       batch_size,
                                       multi_gpu=multi_gpu,
                                       shuffle=True,
                                       num_workers=4)
        train_feat_proc = FilterbankFeatures(**train_features_kw)

        val_dataset_kw, val_features_kw = config.input(cfg, 'val')
        val_dataset = AudioDataset(args.dataset_dir, args.val_manifests,
                                   symbols, **val_dataset_kw)
        val_loader = get_data_loader(val_dataset,
                                     batch_size,
                                     multi_gpu=multi_gpu,
                                     shuffle=False,
                                     num_workers=4,
                                     drop_last=False)
        val_feat_proc = FilterbankFeatures(**val_features_kw)

        dur = train_dataset.duration / 3600
        dur_f = train_dataset.duration_filtered / 3600
        nsampl = len(train_dataset)
        print_once(f'Training samples: {nsampl} ({dur:.1f}h, '
                   f'filtered {dur_f:.1f}h)')

    if train_feat_proc is not None:
        train_feat_proc.cuda()
    if val_feat_proc is not None:
        val_feat_proc.cuda()

    steps_per_epoch = len(train_loader) // args.grad_accumulation_steps

    # set up the model
    model = Jasper(encoder_kw=config.encoder(cfg),
                   decoder_kw=config.decoder(cfg, n_classes=len(symbols)))
    model.cuda()
    ctc_loss = CTCLossNM(n_classes=len(symbols))
    greedy_decoder = GreedyCTCDecoder()

    print_once(f'Model size: {num_weights(model) / 10**6:.1f}M params\n')

    # optimization
    kw = {'lr': args.lr, 'weight_decay': args.weight_decay}
    if args.optimizer == "novograd":
        optimizer = Novograd(model.parameters(), **kw)
    elif args.optimizer == "adamw":
        optimizer = AdamW(model.parameters(), **kw)
    else:
        raise ValueError(f'Invalid optimizer "{args.optimizer}"')

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    # Bind the schedule hyper-parameters once; the loop only passes the
    # current step/epoch and the optimizer.
    adjust_lr = lambda step, epoch, optimizer: lr_policy(
        step,
        epoch,
        args.lr,
        optimizer,
        steps_per_epoch=steps_per_epoch,
        warmup_epochs=args.warmup_epochs,
        hold_epochs=args.hold_epochs,
        num_epochs=args.epochs,
        policy=args.lr_policy,
        min_lr=args.min_lr,
        exp_gamma=args.lr_exp_gamma)

    # The EMA copy is taken before the DDP wrapper is applied.
    if args.ema > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if multi_gpu:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
    if args.pyprof:
        pyprof.init(enable_function_stack=True)

    # load checkpoint
    meta = {'best_wer': 10**6, 'start_epoch': 0}
    checkpointer = Checkpointer(args.output_dir, 'Jasper',
                                args.keep_milestones)
    if args.resume:
        # Prefer the latest checkpoint in the output dir, else the explicit one.
        args.ckpt = checkpointer.last_checkpoint() or args.ckpt

    if args.ckpt is not None:
        checkpointer.load(args.ckpt, model, ema_model, optimizer, scaler, meta)

    start_epoch = meta['start_epoch']
    best_wer = meta['best_wer']
    # Default used by the final `epoch == args.epochs` check when the epoch
    # loop below does not execute at all.
    epoch = 1
    step = start_epoch * steps_per_epoch + 1

    # Keep the entered context manager so the same instance is exited after
    # training (the original entered one instance and exited another).
    nvtx_ctx = None
    if args.pyprof:
        nvtx_ctx = torch.autograd.profiler.emit_nvtx()
        nvtx_ctx.__enter__()
        profiler.start()

    # training loop
    model.train()

    # pre-allocate
    if args.pre_allocate_range is not None:
        n_feats = train_features_kw['n_filt']
        pad_align = train_features_kw['pad_align']
        a, b = args.pre_allocate_range
        # Run forward/backward on random inputs for each frame count in the
        # range, then discard the gradients.
        for n_frames in range(a, b + pad_align, pad_align):
            print_once(
                f'Pre-allocation ({batch_size}x{n_feats}x{n_frames})...')

            feat = torch.randn(batch_size, n_feats, n_frames, device='cuda')
            feat_lens = torch.ones(batch_size, device='cuda').fill_(n_frames)
            txt = torch.randint(high=len(symbols) - 1,
                                size=(batch_size, 100),
                                device='cuda')
            txt_lens = torch.ones(batch_size, device='cuda').fill_(100)
            with torch.cuda.amp.autocast(enabled=args.amp):
                log_probs, enc_lens = model(feat, feat_lens)
                del feat
                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
            loss.backward()
            model.zero_grad()
    torch.cuda.empty_cache()

    bmark_stats = BenchmarkStats()

    for epoch in range(start_epoch + 1, args.epochs + 1):
        if multi_gpu and not use_dali:
            train_loader.sampler.set_epoch(epoch)

        epoch_utts = 0
        epoch_loss = 0
        accumulated_batches = 0
        epoch_start_time = time.time()
        epoch_eval_time = 0

        for batch in train_loader:

            # First micro-batch of a new optimizer step: reset step counters.
            if accumulated_batches == 0:
                step_loss = 0
                step_utts = 0
                step_start_time = time.time()

            if use_dali:
                # with DALI, the data is already on GPU
                feat, feat_lens, txt, txt_lens = batch
                if train_feat_proc is not None:
                    feat, feat_lens = train_feat_proc(feat, feat_lens)
            else:
                batch = [t.cuda(non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = train_feat_proc(audio, audio_lens)

            # Use context manager to prevent redundant accumulation of gradients
            if (multi_gpu and
                    accumulated_batches + 1 < args.grad_accumulation_steps):
                ctx = model.no_sync()
            else:
                ctx = empty_context()

            with ctx:
                with torch.cuda.amp.autocast(enabled=args.amp):
                    log_probs, enc_lens = model(feat, feat_lens)

                    loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
                    loss /= args.grad_accumulation_steps

                if multi_gpu:
                    reduced_loss = reduce_tensor(loss.data, world_size)
                else:
                    reduced_loss = loss

                if torch.isnan(reduced_loss).any():
                    # Skip this micro-batch entirely: no backward pass and no
                    # progress towards the next optimizer step.
                    print_once('WARNING: loss is NaN; skipping update')
                    continue
                else:
                    step_loss += reduced_loss.item()
                    step_utts += batch[0].size(0) * world_size
                    epoch_utts += batch[0].size(0) * world_size
                    accumulated_batches += 1

                    scaler.scale(loss).backward()

            if accumulated_batches % args.grad_accumulation_steps == 0:
                epoch_loss += step_loss
                scaler.step(optimizer)
                scaler.update()

                adjust_lr(step, epoch, optimizer)
                optimizer.zero_grad()

                apply_ema(model, ema_model, args.ema)

                if step % args.log_frequency == 0:
                    # WER is computed on the last micro-batch of the step only.
                    preds = greedy_decoder(log_probs)
                    wer, pred_utt, ref = greedy_wer(preds, txt, txt_lens,
                                                    symbols)

                    if step % args.prediction_frequency == 0:
                        print_once(f'  Decoded:   {pred_utt[:90]}')
                        print_once(f'  Reference: {ref[:90]}')

                    step_time = time.time() - step_start_time
                    log(
                        (epoch, step % steps_per_epoch
                         or steps_per_epoch, steps_per_epoch), step, 'train', {
                             'loss': step_loss,
                             'wer': 100.0 * wer,
                             'throughput': step_utts / step_time,
                             'took': step_time,
                             'lrate': optimizer.param_groups[0]['lr']
                         })

                step_start_time = time.time()

                if step % args.eval_frequency == 0:
                    tik = time.time()
                    wer = evaluate(epoch, step, val_loader, val_feat_proc,
                                   symbols, model, ema_model, ctc_loss,
                                   greedy_decoder, args.amp, use_dali)

                    if wer < best_wer and epoch >= args.save_best_from:
                        # Record the new best *before* saving so the "best"
                        # checkpoint stores the WER it actually achieved
                        # rather than the previous best.
                        best_wer = wer
                        checkpointer.save(model,
                                          ema_model,
                                          optimizer,
                                          scaler,
                                          epoch,
                                          step,
                                          best_wer,
                                          is_best=True)
                    epoch_eval_time += time.time() - tik

                step += 1
                accumulated_batches = 0
                # end of step

            # DALI iterator need to be exhausted;
            # if not using DALI, simulate drop_last=True with grad accumulation
            if not use_dali and step > steps_per_epoch * epoch:
                break

        epoch_time = time.time() - epoch_start_time
        epoch_loss /= steps_per_epoch
        log(
            (epoch, ), None, 'train_avg', {
                'throughput': epoch_utts / epoch_time,
                'took': epoch_time,
                'loss': epoch_loss
            })
        bmark_stats.update(epoch_utts, epoch_time, epoch_loss)

        if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
            checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                              best_wer)

        # Optional cap on the number of epochs run by this invocation.
        if 0 < args.epochs_this_job <= epoch - start_epoch:
            print_once(f'Finished after {args.epochs_this_job} epochs.')
            break
        # end of epoch

    if args.pyprof:
        profiler.stop()
        nvtx_ctx.__exit__(None, None, None)

    log((), None, 'train_avg', bmark_stats.get(args.benchmark_epochs_num))

    # Final evaluation and checkpoint only once the last epoch was reached.
    if epoch == args.epochs:
        evaluate(epoch, step, val_loader, val_feat_proc, symbols, model,
                 ema_model, ctc_loss, greedy_decoder, args.amp, use_dali)

        checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                          best_wer)
    flush_log()
Ejemplo n.º 4
0
def main():
    """Train an OftNet model on KITTI objects and validate it periodically.

    Parses the command line, builds the datasets, loaders, model, optimizer,
    scheduler and AMP gradient scaler, then runs ``args.epochs`` epochs of
    training. Every ``args.val_interval`` epochs the model is validated and a
    checkpoint is saved. With ``args.dlprof`` set, NVTX emission is enabled
    and the run is limited to a single epoch.
    """
    # Parse command line arguments
    args = parse_args()

    # DLProf - Init PyProf
    if args.dlprof:
        nvtx.init(enable_function_stack=True)
        # Set num epochs to 1 if DLProf is enabled
        args.epochs = 1

    # Create experiment
    summary = _make_experiment(args)

    # Create datasets
    train_data = KittiObjectDataset(
        args.root, 'train', args.grid_size, args.grid_res, args.yoffset)
    val_data = KittiObjectDataset(
        args.root, 'val', args.grid_size, args.grid_res, args.yoffset)

    # Apply data augmentation
    # train_data = oft.AugmentedObjectDataset(
    #     train_data, args.train_image_size, args.train_grid_size,
    #     jitter=args.grid_jitter)

    # Create dataloaders
    train_loader = DataLoader(train_data, args.batch_size, shuffle=True,
                              num_workers=args.workers,
                              collate_fn=oft.utils.collate)
    val_loader = DataLoader(val_data, args.batch_size, shuffle=False,
                            num_workers=args.workers,
                            collate_fn=oft.utils.collate)

    # Build model
    model = OftNet(num_classes=1, frontend=args.frontend,
                   topdown_layers=args.topdown, grid_res=args.grid_res,
                   grid_height=args.grid_height)
    if len(args.gpu) > 0:
        torch.cuda.set_device(args.gpu[0])
        model = nn.DataParallel(model, args.gpu).cuda()

    # Create encoder
    encoder = ObjectEncoder()

    # Setup optimizer.
    # Keyword arguments are required here: SGD's positional order is
    # (params, lr, momentum, dampening, weight_decay), so passing the weight
    # decay as the fourth positional argument set `dampening` instead and
    # left weight decay at zero.
    optimizer = optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.lr_decay)

    # Creates a GradScaler once at the beginning of training for AMP. Created even if not being used.
    scaler = GradScaler()

    # NVTX ranges are emitted only when DLProf profiling was requested.
    profiling = bool(args.dlprof)

    for epoch in range(1, args.epochs+1):

        print('\n=== Beginning epoch {} of {} ==='.format(epoch, args.epochs))
        # Update and log learning rate
        # NOTE(review): recent PyTorch releases deprecate passing an explicit
        # epoch to scheduler.step(); kept to preserve the existing schedule
        # and checkpointed scheduler state.
        scheduler.step(epoch-1)
        summary.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        # Train model
        with torch.autograd.profiler.emit_nvtx(enabled=profiling):
            train(args, train_loader, model, encoder, optimizer, summary,
                  epoch, scaler)

        # Run validation every N epochs
        if epoch % args.val_interval == 0:
            with torch.autograd.profiler.emit_nvtx(enabled=profiling):
                validate(args, val_loader, model, encoder, summary, epoch)

            # Save model checkpoint
            save_checkpoint(args, epoch, model, optimizer, scheduler)
Ejemplo n.º 5
0
def main():
    """Train a FastPitch model with optional AMP, EMA and distributed runs.

    Parses the command line in two passes (generic options, then the
    FastPitch model options), builds the model, optimizer, gradient scaler
    and data loader, optionally restores a checkpoint, and runs the training
    loop with gradient accumulation. After every epoch the model (and the EMA
    model, when enabled) is validated and a checkpoint may be saved.

    Raises:
        ValueError: on unrecognised command-line options or an unsupported
            ``args.optimizer``.
        AssertionError: if both ``args.checkpoint_path`` and ``args.resume``
            are given.
        Exception: if the training loss becomes NaN.
    """
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    if args.p_arpabet > 0.0:
        cmudict.initialize(args.cmudict_path, keep_ambiguous=True)

    distributed_run = args.world_size > 1

    # Offset the seed by the local rank so each process gets its own stream.
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)

    if args.local_rank == 0:
        # exist_ok avoids the check-then-create race of testing
        # os.path.exists() first.
        os.makedirs(args.output, exist_ok=True)

    log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
    tb_subsets = ['train', 'val']
    if args.ema_decay > 0.0:
        tb_subsets.append('val_ema')

    logger.init(log_fpath,
                args.output,
                enabled=(args.local_rank == 0),
                tb_subsets=tb_subsets)
    logger.parameters(vars(args), tb_subset='train')

    # Second parsing pass: add the model-specific options; anything still
    # unknown after this is an error.
    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, args.world_size, args.local_rank)

    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)

    attention_kl_loss = AttentionBinarizationLoss()

    # Store pitch mean/std as params to translate from Hz during inference
    model.pitch_mean[0] = args.pitch_mean
    model.pitch_std[0] = args.pitch_std

    kw = dict(lr=args.learning_rate,
              betas=(0.9, 0.98),
              eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise ValueError(f'Invalid optimizer "{args.optimizer}"')

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    # The EMA copy is taken before the DDP wrapper is applied.
    if args.ema_decay > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if distributed_run:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)

    if args.pyprof:
        pyprof.init(enable_function_stack=True)

    # One-element lists act as out-parameters filled in by load_checkpoint.
    start_epoch = [1]
    start_iter = [0]

    assert args.checkpoint_path is None or args.resume is False, (
        "Specify a single checkpoint source")
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
    elif args.resume:
        ch_fpath = last_checkpoint(args.output)
    else:
        ch_fpath = None

    if ch_fpath is not None:
        load_checkpoint(args, model, ema_model, optimizer, scaler, start_epoch,
                        start_iter, model_config, ch_fpath)

    start_epoch = start_epoch[0]
    total_iter = start_iter[0]

    criterion = FastPitchLoss(
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale,
        attn_loss_scale=args.attn_loss_scale)

    collate_fn = TTSCollate()

    if args.local_rank == 0:
        prepare_tmp(args.pitch_online_dir)

    trainset = TTSDataset(audiopaths_and_text=args.training_files,
                          **vars(args))
    valset = TTSDataset(audiopaths_and_text=args.validation_files,
                        **vars(args))

    # With a sampler present, the loader itself is not asked to shuffle.
    if distributed_run:
        train_sampler, shuffle = DistributedSampler(trainset), False
    else:
        train_sampler, shuffle = None, True

    # 4 workers are optimal on DGX-1 (from epoch 2 onwards)
    train_loader = DataLoader(trainset,
                              num_workers=4,
                              shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=True,
                              persistent_workers=True,
                              drop_last=True,
                              collate_fn=collate_fn)

    # Use the same `> 0` test as every other EMA check, so the EMA parameter
    # lists are only built when an EMA model actually exists.
    if args.ema_decay > 0:
        mt_ema_params = init_multi_tensor_ema(model, ema_model)

    model.train()

    # Keep the entered context manager so the same instance is exited after
    # training (the original entered one instance and exited another).
    nvtx_ctx = None
    if args.pyprof:
        nvtx_ctx = torch.autograd.profiler.emit_nvtx()
        nvtx_ctx.__enter__()
        profiler.start()

    # Per-epoch statistics; one entry is appended at the start of each epoch.
    epoch_loss = []
    epoch_mel_loss = []
    epoch_num_frames = []
    epoch_frames_per_sec = []
    epoch_time = []

    torch.cuda.synchronize()
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.perf_counter()

        epoch_loss += [0.0]
        epoch_mel_loss += [0.0]
        epoch_num_frames += [0]
        epoch_frames_per_sec += [0.0]

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        accumulated_steps = 0
        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}
        iter_start_time = None

        epoch_iter = 0
        num_iters = len(train_loader) // args.grad_accumulation
        for batch in train_loader:

            # First micro-batch of a new optimizer iteration.
            if accumulated_steps == 0:
                # Stop once all full accumulation groups were consumed;
                # leftover batches that cannot fill a group are dropped.
                if epoch_iter == num_iters:
                    break
                total_iter += 1
                epoch_iter += 1
                if iter_start_time is None:
                    iter_start_time = time.perf_counter()

                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                                     args.warmup_steps)

                model.zero_grad(set_to_none=True)

            x, y, num_frames = batch_to_gpu(batch)

            with torch.cuda.amp.autocast(enabled=args.amp):
                y_pred = model(x)
                loss, meta = criterion(y_pred, y)

                if (args.kl_loss_start_epoch is not None
                        and epoch >= args.kl_loss_start_epoch):

                    if args.kl_loss_start_epoch == epoch and epoch_iter == 1:
                        print('Begin hard_attn loss')

                    # Soft and hard attention are the 9th and 10th entries of
                    # the 12-element model output.
                    _, _, _, _, _, _, _, _, attn_soft, attn_hard, _, _ = y_pred
                    binarization_loss = attention_kl_loss(attn_hard, attn_soft)
                    # Ramp the KL weight linearly over the warmup epochs,
                    # capped at args.kl_loss_weight.
                    kl_weight = min(
                        (epoch - args.kl_loss_start_epoch) /
                        args.kl_loss_warmup_epochs, 1.0) * args.kl_loss_weight
                    meta['kl_loss'] = binarization_loss.clone().detach(
                    ) * kl_weight
                    loss += kl_weight * binarization_loss

                else:
                    meta['kl_loss'] = torch.zeros_like(loss)
                    kl_weight = 0
                    binarization_loss = 0

                loss /= args.grad_accumulation

            meta = {k: v / args.grad_accumulation for k, v in meta.items()}

            if args.amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                # NOTE(review): a world size of 1 is passed here, unlike the
                # loss above - presumably so the frame count is summed rather
                # than averaged across ranks; confirm against reduce_tensor.
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {
                    k: reduce_tensor(v, args.world_size)
                    for k, v in meta.items()
                }
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            accumulated_steps += 1
            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}

            if accumulated_steps % args.grad_accumulation == 0:

                logger.log_grads_tb(total_iter, model)
                if args.amp:
                    # Gradients must be unscaled before clipping.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.grad_clip_thresh)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.grad_clip_thresh)
                    optimizer.step()

                if args.ema_decay > 0.0:
                    apply_multi_tensor_ema(args.ema_decay, *mt_ema_params)

                iter_time = time.perf_counter() - iter_start_time
                iter_mel_loss = iter_meta['mel_loss'].item()
                iter_kl_loss = iter_meta['kl_loss'].item()
                epoch_frames_per_sec[-1] += iter_num_frames / iter_time
                epoch_loss[-1] += iter_loss
                epoch_num_frames[-1] += iter_num_frames
                epoch_mel_loss[-1] += iter_mel_loss

                logger.log(
                    (epoch, epoch_iter, num_iters),
                    tb_total_steps=total_iter,
                    subset='train',
                    data=OrderedDict([
                        ('loss', iter_loss), ('mel_loss', iter_mel_loss),
                        ('kl_loss', iter_kl_loss), ('kl_weight', kl_weight),
                        ('frames/s', iter_num_frames / iter_time),
                        ('took', iter_time),
                        ('lrate', optimizer.param_groups[0]['lr'])
                    ]),
                )

                accumulated_steps = 0
                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}
                iter_start_time = time.perf_counter()

        # Finished epoch
        # Guard against an epoch with no completed iteration (loader shorter
        # than one accumulation group), which previously raised
        # ZeroDivisionError; the accumulated losses are 0.0 in that case.
        if epoch_iter > 0:
            epoch_loss[-1] /= epoch_iter
            epoch_mel_loss[-1] /= epoch_iter
        epoch_time += [time.perf_counter() - epoch_start_time]
        iter_start_time = None

        logger.log(
            (epoch, ),
            tb_total_steps=None,
            subset='train_avg',
            data=OrderedDict([('loss', epoch_loss[-1]),
                              ('mel_loss', epoch_mel_loss[-1]),
                              ('frames/s',
                               epoch_num_frames[-1] / epoch_time[-1]),
                              ('took', epoch_time[-1])]),
        )

        validate(model, epoch, total_iter, criterion, valset, args.batch_size,
                 collate_fn, distributed_run, batch_to_gpu)

        if args.ema_decay > 0:
            validate(ema_model,
                     epoch,
                     total_iter,
                     criterion,
                     valset,
                     args.batch_size,
                     collate_fn,
                     distributed_run,
                     batch_to_gpu,
                     ema=True)

        maybe_save_checkpoint(args, model, ema_model, optimizer, scaler, epoch,
                              total_iter, model_config)
        logger.flush()

    # Finished training
    if args.pyprof:
        profiler.stop()
        nvtx_ctx.__exit__(None, None, None)

    if len(epoch_loss) > 0:
        # Was trained - average the last 20 measurements
        last_ = lambda l: np.asarray(l[-20:])
        epoch_loss = last_(epoch_loss)
        epoch_mel_loss = last_(epoch_mel_loss)
        epoch_num_frames = last_(epoch_num_frames)
        epoch_time = last_(epoch_time)
        logger.log(
            (),
            tb_total_steps=None,
            subset='train_avg',
            data=OrderedDict([('loss', epoch_loss.mean()),
                              ('mel_loss', epoch_mel_loss.mean()),
                              ('frames/s',
                               epoch_num_frames.sum() / epoch_time.sum()),
                              ('took', epoch_time.mean())]),
        )

    # Final validation pass, run even when no epoch was trained.
    validate(model, None, total_iter, criterion, valset, args.batch_size,
             collate_fn, distributed_run, batch_to_gpu)
Ejemplo n.º 6
0
def main():
    """Train a FastPitch model as configured by command-line arguments.

    The function parses arguments, initializes logging, builds the model and
    optimizer (optionally wrapped by AMP, mirrored by an EMA copy and wrapped
    in DistributedDataParallel), optionally restores a checkpoint, and then
    runs the epoch loop with gradient accumulation, per-epoch validation and
    periodic checkpointing. After the loop it logs a final summary, runs one
    more validation pass and saves a final checkpoint when the last epoch was
    not already checkpointed.

    When no epoch is executed (``start_epoch`` restored from a checkpoint is
    already past ``args.epochs``), the final summary and the final checkpoint
    are skipped and only the final validation runs.

    Raises:
        ValueError: if unknown command-line options are given or
            ``args.optimizer`` is neither ``'adam'`` nor ``'lamb'``.
        AssertionError: if both ``args.checkpoint_path`` and ``args.resume``
            are specified.
        Exception: if the training loss becomes NaN.
    """
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    # First pass: model-specific options are not registered yet, so unknown
    # arguments are tolerated here and validated after the second pass below.
    args, _ = parser.parse_known_args()

    distributed_run = args.world_size > 1

    # Offset the seed by the local rank so each process gets its own stream.
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)

    if args.local_rank == 0:
        # exist_ok avoids the check-then-create race of a separate exists().
        os.makedirs(args.output, exist_ok=True)

    log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
    tb_subsets = ['train', 'val']
    if args.ema_decay > 0.0:
        tb_subsets.append('val_ema')

    logger.init(log_fpath, args.output, enabled=(args.local_rank == 0),
                tb_subsets=tb_subsets)
    logger.parameters(vars(args), tb_subset='train')

    # Second pass: with the model arguments added, anything still unknown
    # is a genuine error.
    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, args.world_size, args.local_rank)

    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)

    # Store pitch mean/std as params to translate from Hz during inference
    with open(args.pitch_mean_std_file, 'r') as f:
        stats = json.load(f)
    model.pitch_mean[0] = stats['mean']
    model.pitch_std[0] = stats['std']

    kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise ValueError(f'Unsupported optimizer: {args.optimizer}')

    if args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    # The EMA copy is taken before the DDP wrapper is applied to the model.
    if args.ema_decay > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if distributed_run:
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank,
            find_unused_parameters=True)

    if args.pyprof:
        pyprof.init(enable_function_stack=True)

    # Single-element lists; presumably load_checkpoint updates them in place
    # (verify against load_checkpoint).
    start_epoch = [1]
    start_iter = [0]

    assert args.checkpoint_path is None or args.resume is False, (
        "Specify a single checkpoint source")
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
    elif args.resume:
        ch_fpath = last_checkpoint(args.output)
    else:
        ch_fpath = None

    if ch_fpath is not None:
        load_checkpoint(args.local_rank, model, ema_model, optimizer, start_epoch,
                        start_iter, model_config, args.amp, ch_fpath,
                        args.world_size)

    start_epoch = start_epoch[0]
    total_iter = start_iter[0]

    criterion = loss_functions.get_loss_function('FastPitch',
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale)

    collate_fn = data_functions.get_collate_function('FastPitch')
    trainset = data_functions.get_data_loader('FastPitch',
                                              audiopaths_and_text=args.training_files,
                                              **vars(args))
    valset = data_functions.get_data_loader('FastPitch',
                                            audiopaths_and_text=args.validation_files,
                                            **vars(args))
    # DataLoader does not accept shuffle=True together with a sampler, so
    # shuffling is delegated to the DistributedSampler in distributed runs.
    if distributed_run:
        train_sampler, shuffle = DistributedSampler(trainset), False
    else:
        train_sampler, shuffle = None, True

    train_loader = DataLoader(trainset, num_workers=16, shuffle=shuffle,
                              sampler=train_sampler, batch_size=args.batch_size,
                              pin_memory=False, drop_last=True,
                              collate_fn=collate_fn)

    batch_to_gpu = data_functions.get_batch_to_gpu('FastPitch')

    if args.ema_decay:
        ema_model_weight_list, model_weight_list, overflow_buf_for_ema = init_multi_tensor_ema(model, ema_model)
    else:
        ema_model_weight_list, model_weight_list, overflow_buf_for_ema = None, None, None

    model.train()

    if args.pyprof:
        torch.autograd.profiler.emit_nvtx().__enter__()
        profiler.start()

    # Stays None when the epoch loop below does not execute at all, which
    # lets the post-training code detect that nothing was trained.
    epoch = None

    torch.cuda.synchronize()
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.perf_counter()

        epoch_loss = 0.0
        epoch_mel_loss = 0.0
        epoch_num_frames = 0
        epoch_frames_per_sec = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        accumulated_steps = 0
        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}

        epoch_iter = 0
        # Number of optimizer steps per epoch; trailing batches that do not
        # fill a whole accumulation window are dropped by the break below.
        num_iters = len(train_loader) // args.gradient_accumulation_steps
        for batch in train_loader:

            # Start of a new accumulation window (i.e. a new optimizer step).
            if accumulated_steps == 0:
                if epoch_iter == num_iters:
                    break
                total_iter += 1
                epoch_iter += 1
                iter_start_time = time.perf_counter()

                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                                     args.warmup_steps)

                model.zero_grad()

            x, y, num_frames = batch_to_gpu(batch)
            y_pred = model(x, use_gt_durations=True)
            loss, meta = criterion(y_pred, y)

            # Scale so that gradients summed over the window form an average.
            loss /= args.gradient_accumulation_steps
            meta = {k: v / args.gradient_accumulation_steps
                    for k, v in meta.items()}

            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {k: reduce_tensor(v, args.world_size) for k,v in meta.items()}
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            accumulated_steps += 1
            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}

            # Accumulation window complete: clip, step, update EMA and log.
            if accumulated_steps % args.gradient_accumulation_steps == 0:

                logger.log_grads_tb(total_iter, model)
                if args.amp:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.grad_clip_thresh)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.grad_clip_thresh)

                optimizer.step()
                apply_multi_tensor_ema(model_weight_list, ema_model_weight_list, args.ema_decay, overflow_buf_for_ema)

                iter_time = time.perf_counter() - iter_start_time
                iter_mel_loss = iter_meta['mel_loss'].item()
                epoch_frames_per_sec += iter_num_frames / iter_time
                epoch_loss += iter_loss
                epoch_num_frames += iter_num_frames
                epoch_mel_loss += iter_mel_loss

                logger.log((epoch, epoch_iter, num_iters),
                           tb_total_steps=total_iter,
                           subset='train',
                           data=OrderedDict([
                               ('loss', iter_loss),
                               ('mel_loss', iter_mel_loss),
                               ('frames/s', iter_num_frames / iter_time),
                               ('took', iter_time),
                               ('lrate', optimizer.param_groups[0]['lr'])]),
                )

                accumulated_steps = 0
                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}

        # Finished epoch
        epoch_time = time.perf_counter() - epoch_start_time

        logger.log((epoch,),
                   tb_total_steps=None,
                   subset='train_avg',
                   data=OrderedDict([
                       ('loss', epoch_loss / epoch_iter),
                       ('mel_loss', epoch_mel_loss / epoch_iter),
                       ('frames/s', epoch_num_frames / epoch_time),
                       ('took', epoch_time)]),
        )

        validate(model, epoch, total_iter, criterion, valset, args.batch_size,
                 collate_fn, distributed_run, batch_to_gpu,
                 use_gt_durations=True)

        if args.ema_decay > 0:
            validate(ema_model, epoch, total_iter, criterion, valset,
                     args.batch_size, collate_fn, distributed_run, batch_to_gpu,
                     use_gt_durations=True, ema=True)

        if (epoch > 0 and args.epochs_per_checkpoint > 0 and
            (epoch % args.epochs_per_checkpoint == 0) and args.local_rank == 0):

            checkpoint_path = os.path.join(
                args.output, f"FastPitch_checkpoint_{epoch}.pt")
            save_checkpoint(args.local_rank, model, ema_model, optimizer, epoch,
                            total_iter, model_config, args.amp, checkpoint_path)
        logger.flush()

    # Finished training
    if args.pyprof:
        profiler.stop()
        torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)

    # The per-epoch statistics exist only if at least one epoch was run;
    # without this guard a resume past the last epoch raised NameError.
    was_trained = epoch is not None

    if was_trained:
        logger.log((),
                   tb_total_steps=None,
                   subset='train_avg',
                   data=OrderedDict([
                       ('loss', epoch_loss / epoch_iter),
                       ('mel_loss', epoch_mel_loss / epoch_iter),
                       ('frames/s', epoch_num_frames / epoch_time),
                       ('took', epoch_time)]),
        )
    validate(model, None, total_iter, criterion, valset, args.batch_size,
             collate_fn, distributed_run, batch_to_gpu, use_gt_durations=True)

    # Save the last epoch only if the periodic checkpoint inside the loop
    # did not already cover it.
    if (was_trained and epoch > 0 and args.epochs_per_checkpoint > 0 and
        (epoch % args.epochs_per_checkpoint != 0) and args.local_rank == 0):
        checkpoint_path = os.path.join(
            args.output, f"FastPitch_checkpoint_{epoch}.pt")
        save_checkpoint(args.local_rank, model, ema_model, optimizer, epoch,
                        total_iter, model_config, args.amp, checkpoint_path)