Esempio n. 1
0
    def _init_iterator(self, gpu_id, dataset_path, config_data, config_features, json_names: list, symbols: list,
                       train_pipeline: bool):
        """
        Build the DALI-backed data iterator for the manifests in ``json_names``.

        Each manifest is parsed with ``_parse_json``, which receives a predicate
        accepting entries whose ``original_duration`` is at most
        ``config_data['max_duration']``. The merged file map is written to a file
        list under ``/tmp`` that is passed to the DALI pipeline, and
        ``self.dataset_size`` is set to the number of collected files.
        """
        max_duration = config_data['max_duration']

        def within_max_duration(json):
            return json['original_duration'] <= max_duration

        output_files = {}
        transcripts = {}
        for jname in json_names:
            # Names starting with '/' are used as given; others are joined onto dataset_path.
            if jname[0] == '/':
                manifest_path = jname
            else:
                manifest_path = os.path.join(dataset_path, jname)
            parsed_files, parsed_transcripts = _parse_json(manifest_path, len(output_files),
                                                           predicate=within_max_duration)
            output_files.update(parsed_files)
            transcripts.update(parsed_transcripts)

        # The file-list name is derived from the hash of the concatenated manifest names.
        names_digest = str(abs(hash(''.join(json_names))))
        file_list_path = os.path.join("/tmp", "jasper_dali.file_list." + names_digest)
        _dict_to_file(output_files, file_list_path)

        self.dataset_size = len(output_files)
        print_once(f"Dataset read by DALI. Number of samples: {self.dataset_size}")

        pipeline = DaliPipeline.from_config(config_data=config_data, config_features=config_features, device_id=gpu_id,
                                            file_root=dataset_path, file_list=file_list_path,
                                            device_type=self.device_type, batch_size=self.batch_size,
                                            train_pipeline=train_pipeline)

        iterator = DaliJasperIterator([pipeline], transcripts=transcripts, symbols=symbols,
                                      batch_size=self.batch_size, reader_name="file_reader",
                                      train_iterator=train_pipeline)
        return iterator
Esempio n. 2
0
def normalize_string(s, symbols, punct_map):
    """
    Normalizes string, keeping only the characters that are listed in ``symbols``.

    Example:
        'call me at 8:00 pm!' -> 'call me at eight zero pm'

    Args:
        s: raw text to normalize.
        symbols: iterable of allowed characters.
        punct_map: punctuation mapping forwarded to ``_clean_text``.

    Returns:
        The cleaned, stripped text with characters outside ``symbols`` removed,
        or ``None`` when cleaning raises (a warning is printed in that case).
    """
    labels = set(symbols)
    try:
        text = _clean_text(s, ["english_cleaners"], punct_map).strip()
        # `text` is a string, so the filter is applied character by character.
        return ''.join(ch for ch in text if ch in labels)
    except Exception as e:
        # Best effort: report the offending input and the error, then signal
        # the failure to the caller with None instead of raising.
        print_once(f"WARNING: Normalizing failed: {s} {e}")
        return None
Esempio n. 3
0
    def _init_iterator(self, gpu_id, dataset_path, config_data,
                       config_features, json_names: list, tokenizer: list,
                       pipeline_type):
        """
        Build the DALI-backed RNN-T iterator for the manifests in ``json_names``.

        Each manifest is parsed with ``_parse_json``, which receives a predicate
        accepting entries whose ``original_duration`` is at most
        ``config_data['max_duration']``. The merged file map is handed to
        ``self.sampler`` and ``self.dataset_size`` is read back from the sampler.
        """
        max_duration = config_data['max_duration']

        def within_max_duration(json):
            return json['original_duration'] <= max_duration

        output_files = {}
        transcripts = {}
        for jname in json_names:
            # Names starting with '/' are used as given; others are joined
            # onto dataset_path.
            if jname[0] == '/':
                manifest_path = jname
            else:
                manifest_path = os.path.join(dataset_path, jname)
            parsed_files, parsed_transcripts = _parse_json(
                manifest_path,
                len(output_files),
                predicate=within_max_duration)
            output_files.update(parsed_files)
            transcripts.update(parsed_transcripts)

        self.sampler.make_file_list(output_files, json_names)
        self.dataset_size = self.sampler.get_dataset_size()
        print_once(
            f"Dataset read by DALI. Number of samples: {self.dataset_size}")

        pipeline_kwargs = {
            'config_data': config_data,
            'config_features': config_features,
            'device_id': gpu_id,
            'file_root': dataset_path,
            'sampler': self.sampler,
            'device_type': self.device_type,
            'batch_size': self.batch_size,
            'pipeline_type': pipeline_type,
        }
        pipeline = DaliPipeline.from_config(**pipeline_kwargs)

        iterator_kwargs = {
            'transcripts': transcripts,
            'tokenizer': tokenizer,
            'batch_size': self.batch_size,
            'shard_size': self._shard_size(),
            'pipeline_type': pipeline_type,
        }
        return DaliRnntIterator([pipeline], **iterator_kwargs)
Esempio n. 4
0
def main():
    """
    Run RNN-T inference over the validation manifest and report the results.

    Parses the command-line arguments, sets up dllogger, builds the DALI data
    loader, the feature post-processing stack and the RNNT model (optionally
    restoring a checkpoint), greedily decodes every batch and then either
    prints the predictions or logs the WER. Latency percentiles are logged
    when at least 20 timing samples were collected.
    """
    parser = get_parser()
    args = parser.parse_args()

    log_fpath = args.log_file or str(Path(args.output_dir, 'nvlog_infer.json'))
    log_fpath = unique_log_fpath(log_fpath)
    dllogger.init(backends=[
        JSONStreamBackend(Verbosity.DEFAULT, log_fpath),
        StdOutBackend(Verbosity.VERBOSE, metric_format=stdout_metric_format)
    ])

    # Record every command-line argument in the log.
    for k, v in vars(args).items():
        dllogger.log("PARAMETER", {k: v})

    for step in ['DNN', 'data+DNN', 'data']:
        for c in [0.99, 0.95, 0.9, 0.5]:
            cs = 'avg' if c == 0.5 else f'{int(100*c)}%'
            # Use the same key spelling as the latency logging at the end of
            # this function (dot replaced by underscore) so the metadata is
            # registered for the metric names that are actually logged.
            ck = str(c).replace('.', '_')
            dllogger.metadata(f'{step.lower()}_latency_{ck}', {
                'name': f'{step} latency {cs}',
                'format': ':>7.2f',
                'unit': 'ms'
            })
    dllogger.metadata('eval_wer', {
        'name': 'WER',
        'format': ':>3.3f',
        'unit': '%'
    })

    if args.cpu:
        device = torch.device('cpu')
    else:
        assert torch.cuda.is_available()
        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if args.seed is not None:
        # Offset the seed by the local rank so every process gets its own stream.
        torch.manual_seed(args.seed + args.local_rank)
        np.random.seed(args.seed + args.local_rank)
        random.seed(args.seed + args.local_rank)

    # set up distributed inference
    multi_gpu = not args.cpu and int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        distrib.init_process_group(backend='nccl', init_method='env://')
        print_once(f'Inference with {distrib.get_world_size()} GPUs')

    cfg = config.load(args.model_config)

    # Command-line overrides of the validation input configuration.
    if args.max_duration is not None:
        cfg['input_val']['audio_dataset']['max_duration'] = args.max_duration
        cfg['input_val']['filterbank_features'][
            'max_duration'] = args.max_duration

    if args.pad_to_max_duration:
        assert cfg['input_val']['audio_dataset']['max_duration'] > 0
        cfg['input_val']['audio_dataset']['pad_to_max_duration'] = True
        cfg['input_val']['filterbank_features']['pad_to_max_duration'] = True

    use_dali = args.dali_device in ('cpu', 'gpu')

    (dataset_kw, features_kw, splicing_kw, _, _) = config.input(cfg, 'val')

    tokenizer_kw = config.tokenizer(cfg)
    tokenizer = Tokenizer(**tokenizer_kw)

    optim_level = 3 if args.amp else 0

    feature_proc = torch.nn.Sequential(
        torch.nn.Identity(),
        torch.nn.Identity(),
        features.FrameSplicing(optim_level=optim_level, **splicing_kw),
        features.FillPadding(optim_level=optim_level, ),
    )

    # dataset

    data_loader = DaliDataLoader(gpu_id=args.local_rank or 0,
                                 dataset_path=args.dataset_dir,
                                 config_data=dataset_kw,
                                 config_features=features_kw,
                                 json_names=[args.val_manifest],
                                 batch_size=args.batch_size,
                                 sampler=dali_sampler.SimpleSampler(),
                                 pipeline_type="val",
                                 device_type=args.dali_device,
                                 tokenizer=tokenizer)

    # One extra class on top of the tokenizer labels (the blank index used
    # by the decoder below equals tokenizer.num_labels).
    model = RNNT(n_classes=tokenizer.num_labels + 1, **config.rnnt(cfg))

    if args.ckpt is not None:
        print(f'Loading the model from {args.ckpt} ...')
        checkpoint = torch.load(args.ckpt, map_location="cpu")
        key = 'ema_state_dict' if args.ema else 'state_dict'
        state_dict = checkpoint[key]
        model.load_state_dict(state_dict, strict=True)

    model.to(device)
    model.eval()

    if feature_proc is not None:
        feature_proc.to(device)
        feature_proc.eval()

    if args.amp:
        model = amp.initialize(model, opt_level='O3')

    if multi_gpu:
        model = DistributedDataParallel(model)

    agg = {'txts': [], 'preds': [], 'logits': []}
    dur = {'data': [], 'dnn': [], 'data+dnn': []}

    # Iterate over the loader `args.repeats` times in a row.
    rep_loader = chain(*repeat(data_loader, args.repeats))
    rep_len = args.repeats * len(data_loader)

    blank_idx = tokenizer.num_labels
    greedy_decoder = RNNTGreedyDecoder(blank_idx=blank_idx)

    def sync_time():
        """Return a timestamp taken after pending CUDA work has finished."""
        if device.type == 'cuda':
            torch.cuda.synchronize()
        return time.perf_counter()

    with torch.no_grad():

        for it, batch in enumerate(tqdm.tqdm(rep_loader, total=rep_len)):

            if use_dali:
                feats, feat_lens, txt, txt_lens = batch
                if feature_proc is not None:
                    feats, feat_lens = feature_proc([feats, feat_lens])
            else:
                # Move to the selected device (also works for --cpu runs).
                batch = [t.to(device, non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feats, feat_lens = feature_proc([audio, audio_lens])
            feats = feats.permute(2, 0, 1)
            if args.amp:
                feats = feats.half()

            # The forward pass is executed for timing; its outputs are not
            # used further (decoding below runs on the features directly).
            t1 = sync_time()
            log_probs, log_prob_lens = model(feats, feat_lens, txt, txt_lens)
            t2 = sync_time()

            # burn-in period; wait for a new loader due to num_workers
            # (t0 is set at the end of the previous iteration, hence it >= 1)
            if it >= 1 and (args.repeats == 1 or it >= len(data_loader)):
                dur['data'].append(t1 - t0)
                dur['dnn'].append(t2 - t1)
                dur['data+dnn'].append(t2 - t0)

            if txt is not None:
                agg['txts'] += helpers.gather_transcripts([txt], [txt_lens],
                                                          tokenizer.detokenize)

            preds = greedy_decoder.decode(model, feats, feat_lens)

            agg['preds'] += helpers.gather_predictions([preds],
                                                       tokenizer.detokenize)

            if 0 < args.steps < it:
                break

            t0 = sync_time()

        # communicate the results
        if args.transcribe_wav:
            for idx, p in enumerate(agg['preds']):
                print_once(f'Prediction {idx+1: >3}: {p}')

        elif args.transcribe_filelist:
            pass

        else:
            wer, _ = process_evaluation_epoch(agg)

            if not multi_gpu or distrib.get_rank() == 0:
                dllogger.log(step=(), data={'eval_wer': 100 * wer})

        if args.save_predictions:
            with open(args.save_predictions, 'w') as f:
                f.write('\n'.join(agg['preds']))

    # report timings
    if len(dur['data']) >= 20:
        ratios = [0.9, 0.95, 0.99]

        for stage in dur:
            lat = durs_to_percentiles(dur[stage], ratios)
            for k in [0.99, 0.95, 0.9, 0.5]:
                kk = str(k).replace('.', '_')
                dllogger.log(step=(),
                             data={f'{stage.lower()}_latency_{kk}': lat[k]})

    else:
        # TODO measure at least avg latency
        print_once('Not enough samples to measure latencies.')
Esempio n. 5
0
def main():
    """
    Run QuartzNet inference and report predictions, WER and latencies.

    Parses the command-line arguments, sets up dllogger, builds the data
    loader that matches the requested input (single wav, file list, DALI
    manifests or plain audio dataset), restores the model, decodes batches
    greedily and finally prints predictions or logs the WER. Predictions and
    logits can be saved to files, and latency percentiles are logged when at
    least 20 timing samples were collected.
    """

    parser = get_parser()
    args = parser.parse_args()

    log_fpath = args.log_file or str(Path(args.output_dir, 'nvlog_infer.json'))
    log_fpath = unique_log_fpath(log_fpath)
    dllogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT, log_fpath),
                            StdOutBackend(Verbosity.VERBOSE,
                                          metric_format=stdout_metric_format)])

    # Record every command-line argument in the log.
    for k, v in vars(args).items():
        dllogger.log("PARAMETER", {k: v})

    for step in ['DNN', 'data+DNN', 'data']:
        for c in [0.99, 0.95, 0.9, 0.5]:
            cs = 'avg' if c == 0.5 else f'{int(100*c)}%'
            # Use the same key spelling as the latency logging at the end of
            # this function (dot replaced by underscore) so the metadata is
            # registered for the metric names that are actually logged.
            ck = str(c).replace('.', '_')
            dllogger.metadata(f'{step.lower()}_latency_{ck}',
                              {'name': f'{step} latency {cs}',
                               'format': ':>7.2f', 'unit': 'ms'})
    dllogger.metadata(
        'eval_wer', {'name': 'WER', 'format': ':>3.2f', 'unit': '%'})

    if args.cpu:
        device = torch.device('cpu')
    else:
        assert torch.cuda.is_available()
        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if args.seed is not None:
        # Offset the seed by the local rank so every process gets its own stream.
        torch.manual_seed(args.seed + args.local_rank)
        np.random.seed(args.seed + args.local_rank)
        random.seed(args.seed + args.local_rank)

    # set up distributed inference
    multi_gpu = not args.cpu and int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        distrib.init_process_group(backend='nccl', init_method='env://')
        print_once(f'Inference with {distrib.get_world_size()} GPUs')

    cfg = config.load(args.model_config)
    config.apply_config_overrides(cfg, args)

    symbols = helpers.add_ctc_blank(cfg['labels'])

    use_dali = args.dali_device in ('cpu', 'gpu')
    dataset_kw, features_kw = config.input(cfg, 'val')

    # A positive --steps switches the run into performance-measurement mode.
    measure_perf = args.steps > 0

    # dataset
    if args.transcribe_wav or args.transcribe_filelist:

        if use_dali:
            print("DALI supported only with input .json files; disabling")
            use_dali = False

        assert not args.pad_to_max_duration
        assert not (args.transcribe_wav and args.transcribe_filelist)

        if args.transcribe_wav:
            dataset = SingleAudioDataset(args.transcribe_wav)
        else:
            dataset = FilelistDataset(args.transcribe_filelist)

        data_loader = get_data_loader(dataset,
                                      batch_size=1,
                                      multi_gpu=multi_gpu,
                                      shuffle=False,
                                      num_workers=0,
                                      drop_last=measure_perf)

        _, features_kw = config.input(cfg, 'val')
        feat_proc = FilterbankFeatures(**features_kw)

    elif use_dali:
        # pad_to_max_duration is not supported by DALI - have simple padders
        if features_kw['pad_to_max_duration']:
            feat_proc = BaseFeatures(
                pad_align=features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=features_kw['max_duration'],
                sample_rate=features_kw['sample_rate'],
                window_size=features_kw['window_size'],
                window_stride=features_kw['window_stride'])
            features_kw['pad_to_max_duration'] = False
        else:
            feat_proc = None

        data_loader = DaliDataLoader(
            gpu_id=args.local_rank or 0,
            dataset_path=args.dataset_dir,
            config_data=dataset_kw,
            config_features=features_kw,
            json_names=args.val_manifests,
            batch_size=args.batch_size,
            pipeline_type=("train" if measure_perf else "val"),  # no drop_last
            device_type=args.dali_device,
            symbols=symbols)

    else:
        dataset = AudioDataset(args.dataset_dir,
                               args.val_manifests,
                               symbols,
                               **dataset_kw)

        data_loader = get_data_loader(dataset,
                                      args.batch_size,
                                      multi_gpu=multi_gpu,
                                      shuffle=False,
                                      num_workers=4,
                                      drop_last=False)

        feat_proc = FilterbankFeatures(**features_kw)

    model = QuartzNet(encoder_kw=config.encoder(cfg),
                      decoder_kw=config.decoder(cfg, n_classes=len(symbols)))

    if args.ckpt is not None:
        print(f'Loading the model from {args.ckpt} ...')
        checkpoint = torch.load(args.ckpt, map_location="cpu")
        key = 'ema_state_dict' if args.ema else 'state_dict'
        state_dict = checkpoint[key]
        model.load_state_dict(state_dict, strict=True)

    model.to(device)
    model.eval()

    if feat_proc is not None:
        feat_proc.to(device)
        feat_proc.eval()

    if args.amp:
        model = model.half()

    if args.torchscript:
        greedy_decoder = GreedyCTCDecoder()

        feat_proc, model, greedy_decoder = torchscript_export(
            data_loader, feat_proc, model, greedy_decoder, args.output_dir,
            use_amp=args.amp, use_conv_masks=True, model_toml=args.model_toml,
            device=device, save=args.torchscript_export)

    if multi_gpu:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)

    # DistributedDataParallel keeps the wrapped network under `.module`, so
    # model attributes such as `encoder` must be read from there.
    core_model = model.module if multi_gpu else model

    agg = {'txts': [], 'preds': [], 'logits': []}
    dur = {'data': [], 'dnn': [], 'data+dnn': []}

    # Endless stream of batches; the loop below stops after `steps` iterations.
    looped_loader = chain.from_iterable(repeat(data_loader))
    # NOTE(review): this replaces the decoder returned by torchscript_export
    # when --torchscript is set - confirm that this is intended.
    greedy_decoder = GreedyCTCDecoder()

    def sync():
        """Wait for pending CUDA work when running on a CUDA device."""
        if device.type == 'cuda':
            torch.cuda.synchronize()

    # Falls back to one pass over the loader when steps + warmup_steps is 0.
    steps = args.steps + args.warmup_steps or len(data_loader)
    with torch.no_grad():

        for it, batch in enumerate(tqdm(looped_loader, initial=1, total=steps)):

            if use_dali:
                feats, feat_lens, txt, txt_lens = batch
                if feat_proc is not None:
                    feats, feat_lens = feat_proc(feats, feat_lens)
            else:
                batch = [t.to(device, non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feats, feat_lens = feat_proc(audio, audio_lens)

            sync()
            t1 = time.perf_counter()

            if args.amp:
                feats = feats.half()

            if core_model.encoder.use_conv_masks:
                log_probs, log_prob_lens = model(feats, feat_lens)
            else:
                log_probs = model(feats, feat_lens)

            preds = greedy_decoder(log_probs)

            sync()
            t2 = time.perf_counter()

            # burn-in period; wait for a new loader due to num_workers
            # (t0 is set at the end of the previous iteration, hence it >= 1)
            if it >= 1 and (args.steps == 0 or it >= args.warmup_steps):
                dur['data'].append(t1 - t0)
                dur['dnn'].append(t2 - t1)
                dur['data+dnn'].append(t2 - t0)

            if txt is not None:
                agg['txts'] += helpers.gather_transcripts([txt], [txt_lens],
                                                          symbols)
            agg['preds'] += helpers.gather_predictions([preds], symbols)
            agg['logits'].append(log_probs)

            if it + 1 == steps:
                break

            sync()
            t0 = time.perf_counter()

        # communicate the results
        if args.transcribe_wav:
            for idx, p in enumerate(agg['preds']):
                print_once(f'Prediction {idx+1: >3}: {p}')

        elif args.transcribe_filelist:
            pass

        elif not multi_gpu or distrib.get_rank() == 0:
            wer, _ = process_evaluation_epoch(agg)

            dllogger.log(step=(), data={'eval_wer': 100 * wer})

        if args.save_predictions:
            with open(args.save_predictions, 'w') as f:
                f.write('\n'.join(agg['preds']))

        if args.save_logits:
            logits = torch.cat(agg['logits'], dim=0).cpu()
            torch.save(logits, args.save_logits)

    # report timings
    if len(dur['data']) >= 20:
        ratios = [0.9, 0.95, 0.99]
        for stage in dur:
            lat = durs_to_percentiles(dur[stage], ratios)
            for k in [0.99, 0.95, 0.9, 0.5]:
                kk = str(k).replace('.', '_')
                dllogger.log(step=(), data={f'{stage.lower()}_latency_{kk}': lat[k]})

    else:
        print_once('Not enough samples to measure latencies.')
Esempio n. 6
0
def main():
    """
    Train the Jasper model with CTC loss.

    Parses the arguments, sets up (optionally distributed) training, builds the
    train/validation loaders (DALI or plain PyTorch), the model, optimizer,
    gradient scaler and LR schedule, optionally restores a checkpoint, and then
    runs the epoch loop with gradient accumulation, periodic logging,
    evaluation and checkpointing. A final evaluation and checkpoint are
    produced when the last configured epoch has been reached.
    """
    args = parse_args()

    assert (torch.cuda.is_available())
    assert args.prediction_frequency % args.log_frequency == 0

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    # set up distributed training
    multi_gpu = int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        world_size = dist.get_world_size()
        print_once(f'Distributed training with {world_size} GPUs\n')
    else:
        world_size = 1

    # Offset the seed by the local rank so every process gets its own stream.
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    random.seed(args.seed + args.local_rank)

    init_log(args)

    cfg = config.load(args.model_config)
    config.apply_config_overrides(cfg, args)

    symbols = helpers.add_ctc_blank(cfg['labels'])

    # args.batch_size is split evenly across the accumulation steps.
    assert args.grad_accumulation_steps >= 1
    assert args.batch_size % args.grad_accumulation_steps == 0
    batch_size = args.batch_size // args.grad_accumulation_steps

    print_once('Setting up datasets...')
    train_dataset_kw, train_features_kw = config.input(cfg, 'train')
    val_dataset_kw, val_features_kw = config.input(cfg, 'val')

    use_dali = args.dali_device in ('cpu', 'gpu')
    if use_dali:
        assert train_dataset_kw['ignore_offline_speed_perturbation'], \
            "DALI doesn't support offline speed perturbation"

        # pad_to_max_duration is not supported by DALI - have simple padders
        if train_features_kw['pad_to_max_duration']:
            train_feat_proc = BaseFeatures(
                pad_align=train_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=train_features_kw['max_duration'],
                sample_rate=train_features_kw['sample_rate'],
                window_size=train_features_kw['window_size'],
                window_stride=train_features_kw['window_stride'])
            train_features_kw['pad_to_max_duration'] = False
        else:
            train_feat_proc = None

        if val_features_kw['pad_to_max_duration']:
            val_feat_proc = BaseFeatures(
                pad_align=val_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=val_features_kw['max_duration'],
                sample_rate=val_features_kw['sample_rate'],
                window_size=val_features_kw['window_size'],
                window_stride=val_features_kw['window_stride'])
            val_features_kw['pad_to_max_duration'] = False
        else:
            val_feat_proc = None

        train_loader = DaliDataLoader(
            gpu_id=args.local_rank,
            dataset_path=args.dataset_dir,
            config_data=train_dataset_kw,
            config_features=train_features_kw,
            json_names=args.train_manifests,
            batch_size=batch_size,
            grad_accumulation_steps=args.grad_accumulation_steps,
            pipeline_type="train",
            device_type=args.dali_device,
            symbols=symbols)

        val_loader = DaliDataLoader(gpu_id=args.local_rank,
                                    dataset_path=args.dataset_dir,
                                    config_data=val_dataset_kw,
                                    config_features=val_features_kw,
                                    json_names=args.val_manifests,
                                    batch_size=batch_size,
                                    pipeline_type="val",
                                    device_type=args.dali_device,
                                    symbols=symbols)
    else:
        train_dataset_kw, train_features_kw = config.input(cfg, 'train')
        train_dataset = AudioDataset(args.dataset_dir, args.train_manifests,
                                     symbols, **train_dataset_kw)
        train_loader = get_data_loader(train_dataset,
                                       batch_size,
                                       multi_gpu=multi_gpu,
                                       shuffle=True,
                                       num_workers=4)
        train_feat_proc = FilterbankFeatures(**train_features_kw)

        val_dataset_kw, val_features_kw = config.input(cfg, 'val')
        val_dataset = AudioDataset(args.dataset_dir, args.val_manifests,
                                   symbols, **val_dataset_kw)
        val_loader = get_data_loader(val_dataset,
                                     batch_size,
                                     multi_gpu=multi_gpu,
                                     shuffle=False,
                                     num_workers=4,
                                     drop_last=False)
        val_feat_proc = FilterbankFeatures(**val_features_kw)

        # Durations are divided by 3600 and reported as hours.
        dur = train_dataset.duration / 3600
        dur_f = train_dataset.duration_filtered / 3600
        nsampl = len(train_dataset)
        print_once(f'Training samples: {nsampl} ({dur:.1f}h, '
                   f'filtered {dur_f:.1f}h)')

    if train_feat_proc is not None:
        train_feat_proc.cuda()
    if val_feat_proc is not None:
        val_feat_proc.cuda()

    steps_per_epoch = len(train_loader) // args.grad_accumulation_steps

    # set up the model
    model = Jasper(encoder_kw=config.encoder(cfg),
                   decoder_kw=config.decoder(cfg, n_classes=len(symbols)))
    model.cuda()
    ctc_loss = CTCLossNM(n_classes=len(symbols))
    greedy_decoder = GreedyCTCDecoder()

    print_once(f'Model size: {num_weights(model) / 10**6:.1f}M params\n')

    # optimization
    kw = {'lr': args.lr, 'weight_decay': args.weight_decay}
    if args.optimizer == "novograd":
        optimizer = Novograd(model.parameters(), **kw)
    elif args.optimizer == "adamw":
        optimizer = AdamW(model.parameters(), **kw)
    else:
        raise ValueError(f'Invalid optimizer "{args.optimizer}"')

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    def adjust_lr(step, epoch, optimizer):
        """Apply the configured learning-rate policy for the given step/epoch."""
        return lr_policy(
            step,
            epoch,
            args.lr,
            optimizer,
            steps_per_epoch=steps_per_epoch,
            warmup_epochs=args.warmup_epochs,
            hold_epochs=args.hold_epochs,
            num_epochs=args.epochs,
            policy=args.lr_policy,
            min_lr=args.min_lr,
            exp_gamma=args.lr_exp_gamma)

    # The EMA copy is taken before the model is wrapped for distributed training.
    if args.ema > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if multi_gpu:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
    if args.pyprof:
        pyprof.init(enable_function_stack=True)

    # load checkpoint
    meta = {'best_wer': 10**6, 'start_epoch': 0}
    checkpointer = Checkpointer(args.output_dir, 'Jasper',
                                args.keep_milestones)
    if args.resume:
        args.ckpt = checkpointer.last_checkpoint() or args.ckpt

    if args.ckpt is not None:
        checkpointer.load(args.ckpt, model, ema_model, optimizer, scaler, meta)

    start_epoch = meta['start_epoch']
    best_wer = meta['best_wer']
    # Default value used after the loop when no epoch is executed.
    epoch = 1
    step = start_epoch * steps_per_epoch + 1

    if args.pyprof:
        torch.autograd.profiler.emit_nvtx().__enter__()
        profiler.start()

    # training loop
    model.train()

    # pre-allocate
    if args.pre_allocate_range is not None:
        n_feats = train_features_kw['n_filt']
        pad_align = train_features_kw['pad_align']
        a, b = args.pre_allocate_range
        # Run forward/backward passes on random inputs of increasing length.
        for n_frames in range(a, b + pad_align, pad_align):
            print_once(
                f'Pre-allocation ({batch_size}x{n_feats}x{n_frames})...')

            feat = torch.randn(batch_size, n_feats, n_frames, device='cuda')
            feat_lens = torch.ones(batch_size, device='cuda').fill_(n_frames)
            txt = torch.randint(high=len(symbols) - 1,
                                size=(batch_size, 100),
                                device='cuda')
            txt_lens = torch.ones(batch_size, device='cuda').fill_(100)
            with torch.cuda.amp.autocast(enabled=args.amp):
                log_probs, enc_lens = model(feat, feat_lens)
                del feat
                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
            loss.backward()
            model.zero_grad()
    torch.cuda.empty_cache()

    bmark_stats = BenchmarkStats()

    for epoch in range(start_epoch + 1, args.epochs + 1):
        if multi_gpu and not use_dali:
            train_loader.sampler.set_epoch(epoch)

        epoch_utts = 0
        epoch_loss = 0
        accumulated_batches = 0
        epoch_start_time = time.time()
        epoch_eval_time = 0

        for batch in train_loader:

            # Start of a new optimizer step: reset the per-step statistics.
            if accumulated_batches == 0:
                step_loss = 0
                step_utts = 0
                step_start_time = time.time()

            if use_dali:
                # with DALI, the data is already on GPU
                feat, feat_lens, txt, txt_lens = batch
                if train_feat_proc is not None:
                    feat, feat_lens = train_feat_proc(feat, feat_lens)
            else:
                batch = [t.cuda(non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = train_feat_proc(audio, audio_lens)

            # Use context manager to prevent redundant accumulation of gradients
            if (multi_gpu and
                    accumulated_batches + 1 < args.grad_accumulation_steps):
                ctx = model.no_sync()
            else:
                ctx = empty_context()

            with ctx:
                with torch.cuda.amp.autocast(enabled=args.amp):
                    log_probs, enc_lens = model(feat, feat_lens)

                    loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
                    loss /= args.grad_accumulation_steps

                if multi_gpu:
                    reduced_loss = reduce_tensor(loss.data, world_size)
                else:
                    reduced_loss = loss

                if torch.isnan(reduced_loss).any():
                    print_once('WARNING: loss is NaN; skipping update')
                    continue
                else:
                    step_loss += reduced_loss.item()
                    step_utts += batch[0].size(0) * world_size
                    epoch_utts += batch[0].size(0) * world_size
                    accumulated_batches += 1

                    scaler.scale(loss).backward()

            if accumulated_batches % args.grad_accumulation_steps == 0:
                epoch_loss += step_loss
                scaler.step(optimizer)
                scaler.update()

                adjust_lr(step, epoch, optimizer)
                optimizer.zero_grad()

                apply_ema(model, ema_model, args.ema)

                if step % args.log_frequency == 0:
                    preds = greedy_decoder(log_probs)
                    wer, pred_utt, ref = greedy_wer(preds, txt, txt_lens,
                                                    symbols)

                    if step % args.prediction_frequency == 0:
                        print_once(f'  Decoded:   {pred_utt[:90]}')
                        print_once(f'  Reference: {ref[:90]}')

                    step_time = time.time() - step_start_time
                    log(
                        (epoch, step % steps_per_epoch
                         or steps_per_epoch, steps_per_epoch), step, 'train', {
                             'loss': step_loss,
                             'wer': 100.0 * wer,
                             'throughput': step_utts / step_time,
                             'took': step_time,
                             'lrate': optimizer.param_groups[0]['lr']
                         })

                step_start_time = time.time()

                if step % args.eval_frequency == 0:
                    tik = time.time()
                    wer = evaluate(epoch, step, val_loader, val_feat_proc,
                                   symbols, model, ema_model, ctc_loss,
                                   greedy_decoder, args.amp, use_dali)

                    if wer < best_wer and epoch >= args.save_best_from:
                        # Record the new best before saving so the checkpoint
                        # stores the WER of the model it contains rather than
                        # the previous best value.
                        best_wer = wer
                        checkpointer.save(model,
                                          ema_model,
                                          optimizer,
                                          scaler,
                                          epoch,
                                          step,
                                          best_wer,
                                          is_best=True)
                    epoch_eval_time += time.time() - tik

                step += 1
                accumulated_batches = 0
                # end of step

            # DALI iterator need to be exhausted;
            # if not using DALI, simulate drop_last=True with grad accumulation
            if not use_dali and step > steps_per_epoch * epoch:
                break

        epoch_time = time.time() - epoch_start_time
        epoch_loss /= steps_per_epoch
        log(
            (epoch, ), None, 'train_avg', {
                'throughput': epoch_utts / epoch_time,
                'took': epoch_time,
                'loss': epoch_loss
            })
        bmark_stats.update(epoch_utts, epoch_time, epoch_loss)

        if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
            checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                              best_wer)

        if 0 < args.epochs_this_job <= epoch - start_epoch:
            print_once(f'Finished after {args.epochs_this_job} epochs.')
            break
        # end of epoch

    if args.pyprof:
        profiler.stop()
        torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)

    log((), None, 'train_avg', bmark_stats.get(args.benchmark_epochs_num))

    # Final evaluation and checkpoint only once the last epoch was reached.
    if epoch == args.epochs:
        evaluate(epoch, step, val_loader, val_feat_proc, symbols, model,
                 ema_model, ctc_loss, greedy_decoder, args.amp, use_dali)

        checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                          best_wer)
    flush_log()
Esempio n. 7
0
def main():
    """Train an RNN-T model end to end and log benchmark events along the way.

    The function parses the command-line arguments, sets up (optionally
    distributed) CUDA training, builds the DALI train/val loaders and the
    feature augmentation pipelines, creates the model, loss, greedy decoder
    and FusedLAMB optimizer, optionally restores a checkpoint, and then runs
    the epoch loop with gradient accumulation, periodic evaluation and
    checkpointing.

    The epoch loop ends when one of the following happens:
      * the last evaluated WER is at or below ``args.target``;
      * ``args.epochs_this_job`` (when positive) epochs ran in this job;
      * ``args.epochs`` is reached.

    Raises:
        ValueError: if ``args.num_buckets`` is set while ``args.seed`` is
            ``None``; the bucketing sampler needs an RNG seeded identically
            on every worker.
    """
    logging.configure_logger('RNNT')
    logging.log_start(logging.constants.INIT_START)

    args = parse_args()

    assert torch.cuda.is_available()
    assert args.prediction_frequency is None or args.prediction_frequency % args.log_frequency == 0

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    # set up distributed training
    multi_gpu = int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        world_size = dist.get_world_size()
        print_once(f'Distributed training with {world_size} GPUs\n')
    else:
        world_size = 1

    # Stays None when no seed is given so the bucketing branch below can fail
    # with a clear message instead of an UnboundLocalError.
    np_rng = None
    if args.seed is not None:
        logging.log_event(logging.constants.SEED, value=args.seed)
        # per-rank offset: each worker gets a different random stream
        torch.manual_seed(args.seed + args.local_rank)
        np.random.seed(args.seed + args.local_rank)
        random.seed(args.seed + args.local_rank)
        # np_rng is used for buckets generation, and needs the same seed on every worker
        np_rng = np.random.default_rng(seed=args.seed)

    init_log(args)

    cfg = config.load(args.model_config)
    config.apply_duration_flags(cfg, args.max_duration)

    assert args.grad_accumulation_steps >= 1
    assert args.batch_size % args.grad_accumulation_steps == 0, f'{args.batch_size} % {args.grad_accumulation_steps} != 0'
    logging.log_event(logging.constants.GRADIENT_ACCUMULATION_STEPS, value=args.grad_accumulation_steps)
    # per-iteration (micro) batch size; args.batch_size is the per-GPU size of a full optimizer step
    batch_size = args.batch_size // args.grad_accumulation_steps

    logging.log_event(logging.constants.SUBMISSION_BENCHMARK, value=logging.constants.RNNT)
    logging.log_event(logging.constants.SUBMISSION_ORG, value='my-organization')
    logging.log_event(logging.constants.SUBMISSION_DIVISION, value=logging.constants.CLOSED) # closed or open
    logging.log_event(logging.constants.SUBMISSION_STATUS, value=logging.constants.ONPREM) # on-prem/cloud/research
    logging.log_event(logging.constants.SUBMISSION_PLATFORM, value='my platform')

    logging.log_end(logging.constants.INIT_STOP)
    # barriers on both sides of RUN_START so all workers log it in lockstep
    if multi_gpu:
        torch.distributed.barrier()
    logging.log_start(logging.constants.RUN_START)
    if multi_gpu:
        torch.distributed.barrier()

    print_once('Setting up datasets...')
    (
        train_dataset_kw,
        train_features_kw,
        train_splicing_kw,
        train_specaugm_kw,
    ) = config.input(cfg, 'train')
    (
        val_dataset_kw,
        val_features_kw,
        val_splicing_kw,
        val_specaugm_kw,
    ) = config.input(cfg, 'val')

    logging.log_event(logging.constants.DATA_TRAIN_MAX_DURATION,
                      value=train_dataset_kw['max_duration'])
    logging.log_event(logging.constants.DATA_SPEED_PERTURBATON_MAX,
                      value=train_dataset_kw['speed_perturbation']['max_rate'])
    logging.log_event(logging.constants.DATA_SPEED_PERTURBATON_MIN,
                      value=train_dataset_kw['speed_perturbation']['min_rate'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_FREQ_N,
                      value=train_specaugm_kw['freq_masks'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_FREQ_MIN,
                      value=train_specaugm_kw['min_freq'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_FREQ_MAX,
                      value=train_specaugm_kw['max_freq'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_TIME_N,
                      value=train_specaugm_kw['time_masks'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_TIME_MIN,
                      value=train_specaugm_kw['min_time'])
    logging.log_event(logging.constants.DATA_SPEC_AUGMENT_TIME_MAX,
                      value=train_specaugm_kw['max_time'])
    logging.log_event(logging.constants.GLOBAL_BATCH_SIZE,
                      value=batch_size * world_size * args.grad_accumulation_steps)

    tokenizer_kw = config.tokenizer(cfg)
    tokenizer = Tokenizer(**tokenizer_kw)

    class PermuteAudio(torch.nn.Module):
        """Permute the first element of the input tuple with dims (2, 0, 1); pass the rest through."""

        def forward(self, x):
            return (x[0].permute(2, 0, 1), *x[1:])

    # SpecAugment is applied only when its config section is non-empty
    train_augmentations = torch.nn.Sequential(
        (features.SpecAugment(optim_level=args.amp, **train_specaugm_kw)
         if train_specaugm_kw else torch.nn.Identity()),
        features.FrameSplicing(optim_level=args.amp, **train_splicing_kw),
        PermuteAudio(),
    )
    val_augmentations = torch.nn.Sequential(
        (features.SpecAugment(optim_level=args.amp, **val_specaugm_kw)
         if val_specaugm_kw else torch.nn.Identity()),
        features.FrameSplicing(optim_level=args.amp, **val_splicing_kw),
        PermuteAudio(),
    )

    logging.log_event(logging.constants.DATA_TRAIN_NUM_BUCKETS, value=args.num_buckets)

    if args.num_buckets is not None:
        if np_rng is None:
            raise ValueError(
                'args.num_buckets requires args.seed to be set: bucket '
                'generation needs the same RNG seed on every worker')
        sampler = dali_sampler.BucketingSampler(
            args.num_buckets,
            batch_size,
            world_size,
            args.epochs,
            np_rng
        )
    else:
        sampler = dali_sampler.SimpleSampler()

    train_loader = DaliDataLoader(gpu_id=args.local_rank,
                                  dataset_path=args.dataset_dir,
                                  config_data=train_dataset_kw,
                                  config_features=train_features_kw,
                                  json_names=args.train_manifests,
                                  batch_size=batch_size,
                                  sampler=sampler,
                                  grad_accumulation_steps=args.grad_accumulation_steps,
                                  pipeline_type="train",
                                  device_type=args.dali_device,
                                  tokenizer=tokenizer)

    val_loader = DaliDataLoader(gpu_id=args.local_rank,
                                dataset_path=args.dataset_dir,
                                config_data=val_dataset_kw,
                                config_features=val_features_kw,
                                json_names=args.val_manifests,
                                batch_size=args.val_batch_size,
                                sampler=dali_sampler.SimpleSampler(),
                                pipeline_type="val",
                                device_type=args.dali_device,
                                tokenizer=tokenizer)

    train_feat_proc = train_augmentations
    val_feat_proc = val_augmentations

    train_feat_proc.cuda()
    val_feat_proc.cuda()

    # one optimizer step consumes grad_accumulation_steps loader batches
    steps_per_epoch = len(train_loader) // args.grad_accumulation_steps

    logging.log_event(logging.constants.TRAIN_SAMPLES, value=train_loader.dataset_size)
    logging.log_event(logging.constants.EVAL_SAMPLES, value=val_loader.dataset_size)

    # set up the model
    rnnt_config = config.rnnt(cfg)
    logging.log_event(logging.constants.MODEL_WEIGHTS_INITIALIZATION_SCALE, value=args.weights_init_scale)
    if args.weights_init_scale is not None:
        rnnt_config['weights_init_scale'] = args.weights_init_scale
    if args.hidden_hidden_bias_scale is not None:
        rnnt_config['hidden_hidden_bias_scale'] = args.hidden_hidden_bias_scale
    # one extra class for the blank symbol, which takes the last index
    model = RNNT(n_classes=tokenizer.num_labels + 1, **rnnt_config)
    model.cuda()
    blank_idx = tokenizer.num_labels
    loss_fn = RNNTLoss(blank_idx=blank_idx)
    logging.log_event(logging.constants.EVAL_MAX_PREDICTION_SYMBOLS, value=args.max_symbol_per_sample)
    greedy_decoder = RNNTGreedyDecoder(blank_idx=blank_idx,
                                       max_symbol_per_sample=args.max_symbol_per_sample)

    print_once(f'Model size: {num_weights(model) / 10**6:.1f}M params\n')

    opt_eps = 1e-9
    logging.log_event(logging.constants.OPT_NAME, value='lamb')
    logging.log_event(logging.constants.OPT_BASE_LR, value=args.lr)
    logging.log_event(logging.constants.OPT_LAMB_EPSILON, value=opt_eps)
    logging.log_event(logging.constants.OPT_LAMB_LR_DECAY_POLY_POWER, value=args.lr_exp_gamma)
    logging.log_event(logging.constants.OPT_LR_WARMUP_EPOCHS, value=args.warmup_epochs)
    logging.log_event(logging.constants.OPT_LAMB_LR_HOLD_EPOCHS, value=args.hold_epochs)
    logging.log_event(logging.constants.OPT_LAMB_BETA_1, value=args.beta1)
    logging.log_event(logging.constants.OPT_LAMB_BETA_2, value=args.beta2)
    logging.log_event(logging.constants.OPT_GRADIENT_CLIP_NORM, value=args.clip_norm)
    logging.log_event(logging.constants.OPT_LR_ALT_DECAY_FUNC, value=True)
    logging.log_event(logging.constants.OPT_LR_ALT_WARMUP_FUNC, value=True)
    logging.log_event(logging.constants.OPT_LAMB_LR_MIN, value=args.min_lr)
    logging.log_event(logging.constants.OPT_WEIGHT_DECAY, value=args.weight_decay)

    # optimization
    kw = {'params': model.param_groups(args.lr), 'lr': args.lr,
          'weight_decay': args.weight_decay}

    # remembered so the LR policy can rescale each group from its starting value
    initial_lrs = [group['lr'] for group in kw['params']]

    print_once(f'Starting with LRs: {initial_lrs}')
    optimizer = FusedLAMB(betas=(args.beta1, args.beta2), eps=opt_eps, **kw)

    def adjust_lr(step, epoch):
        """Apply the LR policy for the given step/epoch to the current optimizer."""
        # `optimizer` is looked up at call time, so the AMP-wrapped optimizer
        # assigned below is the one that gets adjusted.
        return lr_policy(
            step, epoch, initial_lrs, optimizer, steps_per_epoch=steps_per_epoch,
            warmup_epochs=args.warmup_epochs, hold_epochs=args.hold_epochs,
            min_lr=args.min_lr, exp_gamma=args.lr_exp_gamma)

    if args.amp:
        model, optimizer = amp.initialize(
            models=model,
            optimizers=optimizer,
            opt_level='O1',
            max_loss_scale=512.0)

    if args.ema > 0:
        ema_model = copy.deepcopy(model).cuda()
    else:
        ema_model = None
    logging.log_event(logging.constants.MODEL_EVAL_EMA_FACTOR, value=args.ema)

    if multi_gpu:
        model = DistributedDataParallel(model)

    # load checkpoint
    meta = {'best_wer': 10**6, 'start_epoch': 0}
    checkpointer = Checkpointer(args.output_dir, 'RNN-T',
                                args.keep_milestones, args.amp)
    if args.resume:
        args.ckpt = checkpointer.last_checkpoint() or args.ckpt

    if args.ckpt is not None:
        checkpointer.load(args.ckpt, model, ema_model, optimizer, meta)

    start_epoch = meta['start_epoch']
    best_wer = meta['best_wer']
    last_wer = meta['best_wer']
    # Defaults for the post-loop code in case the epoch loop body never runs
    # (start_epoch >= args.epochs).
    epoch = 1
    epoch_utts = 0
    epoch_time = 0.0
    step = start_epoch * steps_per_epoch + 1
    # Defined up front so a NaN loss on the very first batch cannot leave the
    # logging code below without a value to report.
    loss_item = float('nan')

    # training loop
    model.train()
    for epoch in range(start_epoch + 1, args.epochs + 1):

        logging.log_start(logging.constants.BLOCK_START,
                          metadata=dict(first_epoch_num=epoch,
                                        epoch_count=1))
        logging.log_start(logging.constants.EPOCH_START,
                          metadata=dict(epoch_num=epoch))

        epoch_utts = 0
        accumulated_batches = 0
        epoch_start_time = time.time()

        for batch in train_loader:

            # start of a new optimizer step: reset per-step state
            if accumulated_batches == 0:
                adjust_lr(step, epoch)
                optimizer.zero_grad()
                step_utts = 0
                step_start_time = time.time()
                all_feat_lens = []

            audio, audio_lens, txt, txt_lens = batch

            feats, feat_lens = train_feat_proc([audio, audio_lens])
            all_feat_lens += feat_lens

            log_probs, log_prob_lens = model(feats, feat_lens, txt, txt_lens)
            loss = loss_fn(log_probs[:, :log_prob_lens.max().item()],
                           log_prob_lens, txt, txt_lens)

            # scale so the accumulated gradient averages over the micro-batches
            loss /= args.grad_accumulation_steps

            del log_probs, log_prob_lens

            # NOTE(review): on a NaN loss no backward pass runs and
            # accumulated_batches is not advanced, but the step block below
            # still executes whenever accumulated_batches is a multiple of
            # grad_accumulation_steps (including 0) - confirm this is intended.
            if torch.isnan(loss).any():
                print_once('WARNING: loss is NaN; skipping update')
            else:
                if args.amp:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                loss_item = loss.item()
                del loss
                step_utts += batch[0].size(0) * world_size
                epoch_utts += batch[0].size(0) * world_size
                accumulated_batches += 1

            if accumulated_batches % args.grad_accumulation_steps == 0:

                if args.clip_norm is not None:
                    torch.nn.utils.clip_grad_norm_(
                        getattr(model, 'module', model).parameters(),
                        max_norm=args.clip_norm,
                        norm_type=2)

                total_norm = 0.0

                try:
                    if args.log_norm:
                        for p in getattr(model, 'module', model).parameters():
                            param_norm = p.grad.data.norm(2)
                            total_norm += param_norm.item() ** 2
                        total_norm = total_norm ** (1. / 2)
                except AttributeError as e:
                    # raised when a parameter has no gradient (p.grad is None)
                    print_once(f'Exception happened: {e}')
                    total_norm = 0.0

                optimizer.step()
                apply_ema(model, ema_model, args.ema)

                if step % args.log_frequency == 0:

                    if args.prediction_frequency is None or step % args.prediction_frequency == 0:
                        preds = greedy_decoder.decode(model, feats, feat_lens)
                        wer, pred_utt, ref = greedy_wer(
                                preds,
                                txt,
                                txt_lens,
                                tokenizer.detokenize)
                        print_once(f'  Decoded:   {pred_utt[:90]}')
                        print_once(f'  Reference: {ref[:90]}')
                        wer = {'wer': 100 * wer}
                    else:
                        wer = {}

                    step_time = time.time() - step_start_time

                    # `step % steps_per_epoch or steps_per_epoch` maps the
                    # global step to a 1-based position within the epoch
                    log((epoch, step % steps_per_epoch or steps_per_epoch, steps_per_epoch),
                        step, 'train',
                        {'loss': loss_item,
                         **wer,  # optional entry
                         'throughput': step_utts / step_time,
                         'took': step_time,
                         'grad-norm': total_norm,
                         'seq-len-min': min(all_feat_lens).item(),
                         'seq-len-max': max(all_feat_lens).item(),
                         'lrate': optimizer.param_groups[0]['lr']})

                step_start_time = time.time()

                step += 1
                accumulated_batches = 0
                # end of step

        logging.log_end(logging.constants.EPOCH_STOP,
                        metadata=dict(epoch_num=epoch))

        epoch_time = time.time() - epoch_start_time
        log((epoch,), None, 'train_avg', {'throughput': epoch_utts / epoch_time,
                                          'took': epoch_time})

        if epoch % args.val_frequency == 0:
            wer = evaluate(epoch, step, val_loader, val_feat_proc,
                           tokenizer.detokenize, ema_model, loss_fn,
                           greedy_decoder, args.amp)

            last_wer = wer
            if wer < best_wer and epoch >= args.save_best_from:
                # Update first so the value handed to the checkpointer is the
                # WER of the model being saved, not the previous best.
                best_wer = wer
                checkpointer.save(model, ema_model, optimizer, epoch,
                                  step, best_wer, is_best=True)

        save_this_epoch = (args.save_frequency is not None and epoch % args.save_frequency == 0) \
                       or (epoch in args.keep_milestones)
        if save_this_epoch:
            checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)

        logging.log_end(logging.constants.BLOCK_STOP, metadata=dict(first_epoch_num=epoch))

        if last_wer <= args.target:
            logging.log_end(logging.constants.RUN_STOP, metadata={'status': 'success'})
            # report the epochs actually run in this job; args.epochs_this_job
            # may be non-positive (meaning "no limit") on this path
            print_once(f'Finished after {epoch - start_epoch} epochs.')
            break
        if 0 < args.epochs_this_job <= epoch - start_epoch:
            print_once(f'Finished after {args.epochs_this_job} epochs.')
            break
        # end of epoch

    # skip the summary when no epoch ran: there is no throughput to report
    if epoch_time > 0:
        log((), None, 'train_avg', {'throughput': epoch_utts / epoch_time})

    if last_wer > args.target:
        logging.log_end(logging.constants.RUN_STOP, metadata={'status': 'aborted'})

    if epoch == args.epochs:
        evaluate(epoch, step, val_loader, val_feat_proc, tokenizer.detokenize,
                 ema_model, loss_fn, greedy_decoder, args.amp)

    flush_log()
    if args.save_at_the_end:
        checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)