Example #1
def evaluate(epoch, step, val_loader, val_feat_proc, detokenize,
             ema_model, loss_fn, greedy_decoder, use_amp):

    ema_model.eval()

    start_time = time.time()
    agg = {'losses': [], 'preds': [], 'txts': [], 'idx': []}
    logging.log_start(logging.constants.EVAL_START, metadata=dict(epoch_num=epoch))
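    # aggregate per-batch losses, predictions and reference transcripts; WER is computed once at the end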
    for i, batch in enumerate(val_loader):
        print(f'{val_loader.pipeline_type} evaluation: {i:>10}/{len(val_loader):<10}', end='\r')

        audio, audio_lens, txt, txt_lens = batch

        feats, feat_lens = val_feat_proc([audio, audio_lens])

        log_probs, log_prob_lens = ema_model(feats, feat_lens, txt, txt_lens)
        loss = loss_fn(log_probs[:, :log_prob_lens.max().item()],
                       log_prob_lens, txt, txt_lens)

        pred = greedy_decoder.decode(ema_model, feats, feat_lens)

        agg['losses'] += helpers.gather_losses([loss.cpu()])
        agg['preds'] += helpers.gather_predictions([pred], detokenize)
        agg['txts'] += helpers.gather_transcripts([txt.cpu()], [txt_lens.cpu()], detokenize)

    wer, loss = process_evaluation_epoch(agg)

    logging.log_event(logging.constants.EVAL_ACCURACY, value=wer, metadata=dict(epoch_num=epoch))
    logging.log_end(logging.constants.EVAL_STOP, metadata=dict(epoch_num=epoch))

    log((epoch,), step, 'dev_ema', {'loss': loss, 'wer': 100.0 * wer, 'took': time.time() - start_time})
    ema_model.train()
    return wer
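
The evaluation loop above only aggregates predictions and transcripts; the WER itself comes from process_evaluation_epoch, whose implementation is not shown in these examples. A minimal sketch of what it plausibly computes (assuming the gathered preds/txts are plain strings; loss averaging is omitted): corpus-level WER is the total word-level edit distance divided by the total number of reference words.

def word_errors(ref, hyp):
    # word-level Levenshtein distance via dynamic programming
    r, h = ref.split(), hyp.split()
    prev = list(range(len(h) + 1))
    for i, rw in enumerate(r, 1):
        cur = [i] + [0] * len(h)
        for j, hw in enumerate(h, 1):
            cur[j] = min(prev[j] + 1,               # deletion
                         cur[j - 1] + 1,            # insertion
                         prev[j - 1] + (rw != hw))  # substitution
        prev = cur
    return prev[-1]

def corpus_wer(agg):
    # hypothetical stand-in for the WER part of process_evaluation_epoch
    errors = sum(word_errors(t, p) for t, p in zip(agg['txts'], agg['preds']))
    words = sum(len(t.split()) for t in agg['txts'])
    return errors / max(words, 1)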
Example #2
def evaluate(epoch, step, val_loader, val_feat_proc, labels, model,
             ema_model, ctc_loss, greedy_decoder, use_amp, use_dali=False):

    # evaluate both the raw model and its EMA copy; the loop variable
    # intentionally rebinds `model` for each pass
    for model, subset in [(model, 'dev'), (ema_model, 'dev_ema')]:
        if model is None:
            continue

        model.eval()
        start_time = time.time()
        agg = {'losses': [], 'preds': [], 'txts': []}

        for batch in val_loader:
            if use_dali:
                # with DALI, the data is already on GPU
                feat, feat_lens, txt, txt_lens = batch
                if val_feat_proc is not None:
                    feat, feat_lens = val_feat_proc(feat, feat_lens, use_amp)
            else:
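                # host-side loader: move the batch to the GPU before computing features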
                batch = [t.cuda(non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = val_feat_proc(audio, audio_lens, use_amp)

            log_probs, enc_lens = model(feat, feat_lens)
            loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
            pred = greedy_decoder(log_probs)

            agg['losses'] += helpers.gather_losses([loss])
            agg['preds'] += helpers.gather_predictions([pred], labels)
            agg['txts'] += helpers.gather_transcripts([txt], [txt_lens], labels)

        wer, loss = process_evaluation_epoch(agg)
        log((epoch,), step, subset, {'loss': loss, 'wer': 100.0 * wer,
                                     'took': time.time() - start_time})
        model.train()
    # WER of the last subset evaluated (dev_ema when an EMA model is present)
    return wer
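
Example #2's greedy_decoder implements greedy CTC decoding (the same GreedyCTCDecoder appears in Example #4). The core algorithm is standard: take the per-frame argmax, collapse consecutive repeats, and drop the blank symbol. A minimal PyTorch sketch, assuming log_probs has shape (batch, time, num_classes) and the blank is the last class id (as produced by add_ctc_blank):

import torch

def greedy_ctc_decode(log_probs, blank):
    # log_probs: (batch, time, num_classes) -> list of label-id sequences
    best = log_probs.argmax(dim=-1)
    results = []
    for seq in best:
        prev, out = blank, []
        for idx in seq.tolist():
            if idx != prev and idx != blank:  # collapse repeats, then drop blanks
                out.append(idx)
            prev = idx
        results.append(out)
    return results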
Example #3
def main():
    parser = get_parser()
    args = parser.parse_args()

    log_fpath = args.log_file or str(Path(args.output_dir, 'nvlog_infer.json'))
    log_fpath = unique_log_fpath(log_fpath)
    dllogger.init(backends=[
        JSONStreamBackend(Verbosity.DEFAULT, log_fpath),
        StdOutBackend(Verbosity.VERBOSE, metric_format=stdout_metric_format)
    ])

    for k, v in vars(args).items():
        dllogger.log("PARAMETER", {k: v})

    for step in ['DNN', 'data+DNN', 'data']:
        for c in [0.99, 0.95, 0.9, 0.5]:
            cs = 'avg' if c == 0.5 else f'{int(100*c)}%'
            dllogger.metadata(f'{step.lower()}_latency_{c}', {
                'name': f'{step} latency {cs}',
                'format': ':>7.2f',
                'unit': 'ms'
            })
    dllogger.metadata('eval_wer', {
        'name': 'WER',
        'format': ':>3.3f',
        'unit': '%'
    })

    if args.cpu:
        device = torch.device('cpu')
    else:
        assert torch.cuda.is_available()
        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if args.seed is not None:
        torch.manual_seed(args.seed + args.local_rank)
        np.random.seed(args.seed + args.local_rank)
        random.seed(args.seed + args.local_rank)

    # set up distributed inference
    multi_gpu = not args.cpu and int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        distrib.init_process_group(backend='nccl', init_method='env://')
        print_once(f'Inference with {distrib.get_world_size()} GPUs')

    cfg = config.load(args.model_config)

    if args.max_duration is not None:
        cfg['input_val']['audio_dataset']['max_duration'] = args.max_duration
        cfg['input_val']['filterbank_features'][
            'max_duration'] = args.max_duration

    if args.pad_to_max_duration:
        assert cfg['input_val']['audio_dataset']['max_duration'] > 0
        cfg['input_val']['audio_dataset']['pad_to_max_duration'] = True
        cfg['input_val']['filterbank_features']['pad_to_max_duration'] = True

    use_dali = args.dali_device in ('cpu', 'gpu')

    (dataset_kw, features_kw, splicing_kw, _, _) = config.input(cfg, 'val')

    tokenizer_kw = config.tokenizer(cfg)
    tokenizer = Tokenizer(**tokenizer_kw)

    optim_level = 3 if args.amp else 0

    # the two Identity modules appear to be placeholders for train-only
    # pipeline stages, keeping the module layout aligned with training
    feature_proc = torch.nn.Sequential(
        torch.nn.Identity(),
        torch.nn.Identity(),
        features.FrameSplicing(optim_level=optim_level, **splicing_kw),
        features.FillPadding(optim_level=optim_level),
    )

    # dataset

    data_loader = DaliDataLoader(gpu_id=args.local_rank or 0,
                                 dataset_path=args.dataset_dir,
                                 config_data=dataset_kw,
                                 config_features=features_kw,
                                 json_names=[args.val_manifest],
                                 batch_size=args.batch_size,
                                 sampler=dali_sampler.SimpleSampler(),
                                 pipeline_type="val",
                                 device_type=args.dali_device,
                                 tokenizer=tokenizer)

    model = RNNT(n_classes=tokenizer.num_labels + 1, **config.rnnt(cfg))

    if args.ckpt is not None:
        print(f'Loading the model from {args.ckpt} ...')
        checkpoint = torch.load(args.ckpt, map_location="cpu")
        key = 'ema_state_dict' if args.ema else 'state_dict'
        state_dict = checkpoint[key]
        model.load_state_dict(state_dict, strict=True)

    model.to(device)
    model.eval()

    if feature_proc is not None:
        feature_proc.to(device)
        feature_proc.eval()

    if args.amp:
        model = amp.initialize(model, opt_level='O3')

    if multi_gpu:
        model = DistributedDataParallel(model)

    agg = {'txts': [], 'preds': [], 'logits': []}
    dur = {'data': [], 'dnn': [], 'data+dnn': []}

    rep_loader = chain(*repeat(data_loader, args.repeats))
    rep_len = args.repeats * len(data_loader)

    # the last label id is reserved for the RNN-T blank (hence n_classes=num_labels+1 above)
    blank_idx = tokenizer.num_labels
    greedy_decoder = RNNTGreedyDecoder(blank_idx=blank_idx)

    def sync_time():
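        # drain any queued CUDA kernels so the timestamp reflects completed work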
        if device.type == 'cuda':
            torch.cuda.synchronize()
        return time.perf_counter()

    sz = []
    with torch.no_grad():

        for it, batch in enumerate(tqdm.tqdm(rep_loader, total=rep_len)):

            if use_dali:
                feats, feat_lens, txt, txt_lens = batch
                if feature_proc is not None:
                    feats, feat_lens = feature_proc([feats, feat_lens])
            else:
                batch = [t.cuda(non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feats, feat_lens = feature_proc([audio, audio_lens])
            feats = feats.permute(2, 0, 1)
            if args.amp:
                feats = feats.half()

            sz.append(feats.size(0))

            t1 = sync_time()
            log_probs, log_prob_lens = model(feats, feat_lens, txt, txt_lens)
            t2 = sync_time()

            # burn-in period; wait for a new loader due to num_workers
            if it >= 1 and (args.repeats == 1 or it >= len(data_loader)):
                dur['data'].append(t1 - t0)
                dur['dnn'].append(t2 - t1)
                dur['data+dnn'].append(t2 - t0)

            if txt is not None:
                agg['txts'] += helpers.gather_transcripts([txt], [txt_lens],
                                                          tokenizer.detokenize)

            preds = greedy_decoder.decode(model, feats, feat_lens)

            agg['preds'] += helpers.gather_predictions([preds],
                                                       tokenizer.detokenize)

            if 0 < args.steps < it:
                break

            t0 = sync_time()

        # communicate the results
        if args.transcribe_wav:
            for idx, p in enumerate(agg['preds']):
                print_once(f'Prediction {idx+1: >3}: {p}')

        elif args.transcribe_filelist:
            pass

        else:
            wer, loss = process_evaluation_epoch(agg)

            if not multi_gpu or distrib.get_rank() == 0:
                dllogger.log(step=(), data={'eval_wer': 100 * wer})

        if args.save_predictions:
            with open(args.save_predictions, 'w') as f:
                f.write('\n'.join(agg['preds']))

    # report timings
    if len(dur['data']) >= 20:
        ratios = [0.9, 0.95, 0.99]

        for stage in dur:
            lat = durs_to_percentiles(dur[stage], ratios)
            for k in [0.99, 0.95, 0.9, 0.5]:
                # key must match the metadata registered at startup, e.g. 'dnn_latency_0.99'
                dllogger.log(step=(),
                             data={f'{stage.lower()}_latency_{k}': lat[k]})

    else:
        # TODO measure at least avg latency
        print_once('Not enough samples to measure latencies.')
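
Both inference scripts report percentile latencies through durs_to_percentiles, whose implementation is not shown. A plausible minimal version follows, assuming per-step durations are collected in seconds and reported in milliseconds (per the dllogger metadata), and that the returned dict always contains a 0.5 entry, since the logging loop reads lat[0.5] even though ratios only lists 0.9, 0.95 and 0.99.

import numpy as np

def durs_to_percentiles(durations, ratios):
    # hypothetical sketch: seconds in, milliseconds out, keyed by ratio
    durs_ms = np.sort(np.asarray(durations, dtype=np.float64)) * 1000.0
    out = {0.5: float(np.percentile(durs_ms, 50.0))}  # logged as 'avg'
    for r in ratios:
        out[r] = float(np.percentile(durs_ms, r * 100.0))
    return out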
Example #4
def main():

    parser = get_parser()
    args = parser.parse_args()

    log_fpath = args.log_file or str(Path(args.output_dir, 'nvlog_infer.json'))
    log_fpath = unique_log_fpath(log_fpath)
    dllogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT, log_fpath),
                            StdOutBackend(Verbosity.VERBOSE,
                                          metric_format=stdout_metric_format)])

    for k, v in vars(args).items():
        dllogger.log("PARAMETER", {k: v})

    for step in ['DNN', 'data+DNN', 'data']:
        for c in [0.99, 0.95, 0.9, 0.5]:
            cs = 'avg' if c == 0.5 else f'{int(100*c)}%'
            dllogger.metadata(f'{step.lower()}_latency_{c}',
                              {'name': f'{step} latency {cs}',
                               'format': ':>7.2f', 'unit': 'ms'})
    dllogger.metadata(
        'eval_wer', {'name': 'WER', 'format': ':>3.2f', 'unit': '%'})

    if args.cpu:
        device = torch.device('cpu')
    else:
        assert torch.cuda.is_available()
        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if args.seed is not None:
        torch.manual_seed(args.seed + args.local_rank)
        np.random.seed(args.seed + args.local_rank)
        random.seed(args.seed + args.local_rank)

    # set up distributed inference
    multi_gpu = not args.cpu and int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        distrib.init_process_group(backend='nccl', init_method='env://')
        print_once(f'Inference with {distrib.get_world_size()} GPUs')

    cfg = config.load(args.model_config)
    config.apply_config_overrides(cfg, args)

    symbols = helpers.add_ctc_blank(cfg['labels'])

    use_dali = args.dali_device in ('cpu', 'gpu')
    dataset_kw, features_kw = config.input(cfg, 'val')

    measure_perf = args.steps > 0

    # dataset
    if args.transcribe_wav or args.transcribe_filelist:

        if use_dali:
            print("DALI supported only with input .json files; disabling")
            use_dali = False

        assert not args.pad_to_max_duration
        assert not (args.transcribe_wav and args.transcribe_filelist)

        if args.transcribe_wav:
            dataset = SingleAudioDataset(args.transcribe_wav)
        else:
            dataset = FilelistDataset(args.transcribe_filelist)

        data_loader = get_data_loader(dataset,
                                      batch_size=1,
                                      multi_gpu=multi_gpu,
                                      shuffle=False,
                                      num_workers=0,
                                      drop_last=measure_perf)

        _, features_kw = config.input(cfg, 'val')
        feat_proc = FilterbankFeatures(**features_kw)

    elif use_dali:
        # pad_to_max_duration is not supported by DALI - have simple padders
        if features_kw['pad_to_max_duration']:
            feat_proc = BaseFeatures(
                pad_align=features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=features_kw['max_duration'],
                sample_rate=features_kw['sample_rate'],
                window_size=features_kw['window_size'],
                window_stride=features_kw['window_stride'])
            features_kw['pad_to_max_duration'] = False
        else:
            feat_proc = None

        data_loader = DaliDataLoader(
            gpu_id=args.local_rank or 0,
            dataset_path=args.dataset_dir,
            config_data=dataset_kw,
            config_features=features_kw,
            json_names=args.val_manifests,
            batch_size=args.batch_size,
            pipeline_type=("train" if measure_perf else "val"),  # no drop_last
            device_type=args.dali_device,
            symbols=symbols)

    else:
        dataset = AudioDataset(args.dataset_dir,
                               args.val_manifests,
                               symbols,
                               **dataset_kw)

        data_loader = get_data_loader(dataset,
                                      args.batch_size,
                                      multi_gpu=multi_gpu,
                                      shuffle=False,
                                      num_workers=4,
                                      drop_last=False)

        feat_proc = FilterbankFeatures(**features_kw)

    model = QuartzNet(encoder_kw=config.encoder(cfg),
                      decoder_kw=config.decoder(cfg, n_classes=len(symbols)))

    if args.ckpt is not None:
        print(f'Loading the model from {args.ckpt} ...')
        checkpoint = torch.load(args.ckpt, map_location="cpu")
        key = 'ema_state_dict' if args.ema else 'state_dict'
        state_dict = checkpoint[key]
        model.load_state_dict(state_dict, strict=True)

    model.to(device)
    model.eval()

    if feat_proc is not None:
        feat_proc.to(device)
        feat_proc.eval()

    if args.amp:
        model = model.half()

    if args.torchscript:
        greedy_decoder = GreedyCTCDecoder()

        feat_proc, model, greedy_decoder = torchscript_export(
            data_loader, feat_proc, model, greedy_decoder, args.output_dir,
            use_amp=args.amp, use_conv_masks=True, model_toml=args.model_toml,
            device=device, save=args.torchscript_export)

    if multi_gpu:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)

    agg = {'txts': [], 'preds': [], 'logits': []}
    dur = {'data': [], 'dnn': [], 'data+dnn': []}

    looped_loader = chain.from_iterable(repeat(data_loader))
    if not args.torchscript:
        # avoid discarding the TorchScript-exported decoder created above
        greedy_decoder = GreedyCTCDecoder()

    sync = lambda: torch.cuda.synchronize() if device.type == 'cuda' else None
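    # flush outstanding CUDA work before reading timestamps, so the data/DNN split is accurate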

    # (steps + warmup_steps) when benchmarking; falls back to one full pass when both are 0
    steps = args.steps + args.warmup_steps or len(data_loader)
    with torch.no_grad():

        for it, batch in enumerate(tqdm(looped_loader, initial=1, total=steps)):

            if use_dali:
                feats, feat_lens, txt, txt_lens = batch
                if feat_proc is not None:
                    feats, feat_lens = feat_proc(feats, feat_lens)
            else:
                batch = [t.to(device, non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feats, feat_lens = feat_proc(audio, audio_lens)

            sync()
            t1 = time.perf_counter()

            if args.amp:
                feats = feats.half()

            if model.encoder.use_conv_masks:
                log_probs, log_prob_lens = model(feats, feat_lens)
            else:
                log_probs = model(feats, feat_lens)

            preds = greedy_decoder(log_probs)

            sync()
            t2 = time.perf_counter()

            # burn-in period; wait for a new loader due to num_workers
            if it >= 1 and (args.steps == 0 or it >= args.warmup_steps):
                dur['data'].append(t1 - t0)
                dur['dnn'].append(t2 - t1)
                dur['data+dnn'].append(t2 - t0)

            if txt is not None:
                agg['txts'] += helpers.gather_transcripts([txt], [txt_lens],
                                                          symbols)
            agg['preds'] += helpers.gather_predictions([preds], symbols)
            agg['logits'].append(log_probs)

            if it + 1 == steps:
                break

            sync()
            t0 = time.perf_counter()

        # communicate the results
        if args.transcribe_wav:
            for idx, p in enumerate(agg['preds']):
                print_once(f'Prediction {idx+1: >3}: {p}')

        elif args.transcribe_filelist:
            pass

        elif not multi_gpu or distrib.get_rank() == 0:
            wer, _ = process_evaluation_epoch(agg)

            dllogger.log(step=(), data={'eval_wer': 100 * wer})

        if args.save_predictions:
            with open(args.save_predictions, 'w') as f:
                f.write('\n'.join(agg['preds']))

        if args.save_logits:
            logits = torch.cat(agg['logits'], dim=0).cpu()
            torch.save(logits, args.save_logits)

    # report timings
    if len(dur['data']) >= 20:
        ratios = [0.9, 0.95, 0.99]
        for stage in dur:
            lat = durs_to_percentiles(dur[stage], ratios)
            for k in [0.99, 0.95, 0.9, 0.5]:
                # key must match the metadata registered at startup
                dllogger.log(step=(), data={f'{stage.lower()}_latency_{k}': lat[k]})

    else:
        print_once('Not enough samples to measure latencies.')
Example #5
def evaluate(conf, onnx_path, val_loader, val_feat_proc, inference_session,
             inference_anchors, inference_inputs, inference_transcription_out,
             inference_transcription_out_lens, pytorch_rnnt_model,
             greedy_decoder, detokenize):

    start_time = time.time()

    logger.info(
        "Getting trained weights for inference from {}".format(onnx_path))
    inference_session.resetHostWeights(
        onnx_path, ignoreWeightsInModelWithoutCorrespondingHostWeight=True)
    inference_session.weightsFromHost()

    feats_data = []
    feat_lens_data = []
    txt_data = []
    txt_lens_data = []

    samples_per_step_per_instance = conf.samples_per_step // conf.num_instances
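    # every batch is padded up to this fixed per-instance step size; the PopART session expects static IO shapes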

    # No need to compute losses for validation
    agg = {'preds': [], 'txts': [], 'idx': []}
    logger.info("Running transcription network on evaluation dataset")
    for audio, audio_lens, txt, txt_lens in val_loader:
        feats, feat_lens = val_feat_proc([audio, audio_lens])

        feats = feats.numpy()
        feat_lens = feat_lens.numpy()
        # txt is of np.array type as implemented in reference code
        txt_lens = txt_lens.numpy()

        feats = pad(feats, samples_per_step_per_instance)
        feat_lens = pad(feat_lens, samples_per_step_per_instance)
        txt = pad(txt, samples_per_step_per_instance)
        txt_lens = pad(txt_lens, samples_per_step_per_instance)

        stepio = popart.PyStepIO(
            {
                inference_inputs["mel_spec_input"]: feats.astype(
                    conf.precision),
                inference_inputs["input_length"]: feat_lens.astype(np.int32),
            }, inference_anchors)

        inference_session.run(stepio)

        # converting to torch tensor
        feats = torch.tensor(inference_anchors[inference_transcription_out])
        feat_lens = torch.tensor(
            inference_anchors[inference_transcription_out_lens])
        feat_lens = torch.flatten(feat_lens)
        step_size = feat_lens.shape[0]
        feats = torch.reshape(feats,
                              (step_size, feats.shape[-2], feats.shape[-1]))

        txt = torch.tensor(txt)
        txt_lens = torch.tensor(txt_lens)

        feats_data.append(feats)
        feat_lens_data.append(feat_lens)
        txt_data.append(txt)
        txt_lens_data.append(txt_lens)

    num_cpus = os.cpu_count()
    num_workers = max(1, min(16, num_cpus // conf.num_instances))
    logger.info(
        "Creating multiprocessor pool with {} workers and function for decoding"
        .format(num_workers))
    greedy_decoding_processor_pool = multiprocessing.pool.ThreadPool(
        processes=num_workers)
    greedy_decoding_func = partial(greedy_decoder.decode, pytorch_rnnt_model)

    pred_results = []
    ground_truths_detokenized = []
    logger.info("Submitting jobs for greedy decoding")

    feat_iter = zip(feats_data, feat_lens_data, txt_data, txt_lens_data)
    for feats, feat_lens, txt, txt_lens in feat_iter:

        step_size = feat_lens.shape[0]
        batch_size = step_size // conf.batches_per_step

        for bind in range(conf.batches_per_step):
            # slice out the bind-th batch (stride by batch_size, not by 1)
            feats_b = feats[bind * batch_size:(bind + 1) * batch_size]
            feat_lens_b = feat_lens[bind * batch_size:(bind + 1) * batch_size]
            txt_b = txt[bind * batch_size:(bind + 1) * batch_size]
            txt_lens_b = txt_lens[bind * batch_size:(bind + 1) * batch_size]

            pred_results.append(
                greedy_decoding_processor_pool.apply_async(
                    greedy_decoding_func, (feats_b, feat_lens_b)))
            ground_truths_detokenized.append(
                helpers.gather_transcripts([txt_b], [txt_lens_b], detokenize))

    logger.info("Generating predictions and computing Word Error Rate (WER)")
    pred_iter = zip(pred_results, ground_truths_detokenized)

    for pred_result, gts_detokenized in pred_iter:
        preds_detokenized = helpers.gather_predictions([pred_result.get()],
                                                       detokenize)
        # corpus-level WER is computed once below by process_evaluation_epoch
        agg['preds'] += preds_detokenized
        agg['txts'] += gts_detokenized

    wer, scores, num_words, _ = helpers.process_evaluation_epoch(agg)
    logger.info(
        "Total time for Transducer Decoding = {:.1f} secs".format(time.time() -
                                                                  start_time))

    greedy_decoding_processor_pool.close()
    greedy_decoding_processor_pool.join()

    return wer, scores, num_words
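
The pad helper used before building the PyStepIO is not shown. Assuming the PopART session expects fixed-shape inputs, a minimal sketch zero-pads the batch dimension up to the next multiple of the per-instance step size:

import numpy as np

def pad(array, multiple):
    # hypothetical sketch: zero-pad axis 0 up to the next multiple of `multiple`
    remainder = (-array.shape[0]) % multiple
    if remainder == 0:
        return array
    pad_width = [(0, remainder)] + [(0, 0)] * (array.ndim - 1)
    return np.pad(array, pad_width, mode='constant')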