Code Example #1
data_layer = nemo_asr.AudioToTextDataLayer(manifest_filepath=train_dataset,
                                           labels=labels,
                                           batch_size=16,
                                           manifest_class=ManifestENRU,
                                           num_workers=8)
data_layer_val = nemo_asr.AudioToTextDataLayer(manifest_filepath=eval_datasets,
                                               labels=labels,
                                               batch_size=1,
                                               shuffle=False,
                                               manifest_class=ManifestENRU,
                                               num_workers=8)

data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor()
spec_augment = nemo_asr.SpectrogramAugmentation(rect_masks=5)

jasper_encoder = nemo_asr.JasperEncoder(
    feat_in=64, **jasper_model_definition['JasperEncoder'])
jasper_encoder.restore_from('./chkp/JasperEncoder-STEP-247400.pt')
jasper_decoder = nemo_asr.JasperDecoderForCTC(feat_in=1024,
                                              num_classes=len(labels))
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(labels))
greedy_decoder = nemo_asr.GreedyCTCDecoder()

# Training DAG (Model)
audio_signal, audio_signal_len, transcript, transcript_len = data_layer()
processed_signal, processed_signal_len = data_preprocessor(
    input_signal=audio_signal, length=audio_signal_len)
aug_signal = spec_augment(input_spec=processed_signal)
encoded, encoded_len = jasper_encoder(audio_signal=aug_signal,
                                      length=processed_signal_len)
log_probs = jasper_decoder(encoder_output=encoded)
predictions = greedy_decoder(log_probs=log_probs)
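
The ctc_loss module above is instantiated but never wired into the DAG, and no training call is shown. A minimal, hypothetical continuation, assuming a NeuralModuleFactory named neural_factory was created beforehand (as in the later examples); the optimizer settings are illustrative only:

loss = ctc_loss(log_probs=log_probs,
                targets=transcript,
                input_length=encoded_len,
                target_length=transcript_len)

# print the loss to the console as training progresses
train_callback = nemo.core.SimpleLossLoggerCallback(
    tensors=[loss],
    print_func=lambda x: print(f'Loss: {x[0].item()}'))

neural_factory.train(tensors_to_optimize=[loss],
                     callbacks=[train_callback],
                     optimizer="novograd",
                     optimization_params={"num_epochs": 50, "lr": 0.02})
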
Code Example #2
def main():
    parser = argparse.ArgumentParser(description='Jasper')
    parser.add_argument("--local_rank", default=None, type=int)
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--model_config", type=str, required=True)
    parser.add_argument("--eval_datasets", type=str, required=True)
    parser.add_argument("--load_dir", type=str, required=True)
    parser.add_argument("--vocab_file", type=str, required=True)
    parser.add_argument("--save_logprob", default=None, type=str)
    parser.add_argument("--lm_path", default=None, type=str)
    parser.add_argument("--beam_width", default=50, type=int)
    parser.add_argument("--alpha", default=2.0, type=float)
    parser.add_argument("--beta", default=1.0, type=float)
    parser.add_argument("--cutoff_prob", default=0.99, type=float)
    parser.add_argument("--cutoff_top_n", default=40, type=int)

    args = parser.parse_args()
    batch_size = args.batch_size
    load_dir = args.load_dir

    if args.local_rank is not None:
        if args.lm_path:
            raise NotImplementedError(
                "Beam search decoder with LM does not currently support "
                "evaluation on multi-gpu.")
        device = nemo.core.DeviceType.AllGpu
    else:
        device = nemo.core.DeviceType.GPU

    # Instantiate Neural Factory with supported backend
    neural_factory = nemo.core.NeuralModuleFactory(
        backend=nemo.core.Backend.PyTorch,
        local_rank=args.local_rank,
        optimization_level=nemo.core.Optimization.mxprO1,
        placement=device)
    logger = neural_factory.logger

    if args.local_rank is not None:
        logger.info('Doing ALL GPU')

    yaml = YAML(typ="safe")
    with open(args.model_config) as f:
        jasper_params = yaml.load(f)

    vocab = load_vocab(args.vocab_file)

    sample_rate = jasper_params['sample_rate']

    eval_datasets = args.eval_datasets

    eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
    eval_dl_params["normalize_transcripts"] = False
    del eval_dl_params["train"]
    del eval_dl_params["eval"]
    data_layer = nemo_asr.AudioToTextDataLayer(manifest_filepath=eval_datasets,
                                               sample_rate=sample_rate,
                                               labels=vocab,
                                               batch_size=batch_size,
                                               **eval_dl_params)

    n = len(data_layer)
    logger.info('Evaluating {0} examples'.format(n))

    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
        sample_rate=sample_rate,
        **jasper_params["AudioToMelSpectrogramPreprocessor"])
    jasper_encoder = nemo_asr.JasperEncoder(
        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
        **jasper_params["JasperEncoder"])
    jasper_decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
        num_classes=len(vocab))
    greedy_decoder = nemo_asr.GreedyCTCDecoder()

    if args.lm_path:
        beam_width = args.beam_width
        alpha = args.alpha
        beta = args.beta
        cutoff_prob = args.cutoff_prob
        cutoff_top_n = args.cutoff_top_n
        beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM(
            vocab=vocab,
            beam_width=beam_width,
            alpha=alpha,
            beta=beta,
            cutoff_prob=cutoff_prob,
            cutoff_top_n=cutoff_top_n,
            lm_path=args.lm_path,
            num_cpus=max(os.cpu_count(), 1))

    logger.info('================================')
    logger.info(
        f"Number of parameters in encoder: {jasper_encoder.num_weights}")
    logger.info(
        f"Number of parameters in decoder: {jasper_decoder.num_weights}")
    logger.info(f"Total number of parameters in decoder: "
                f"{jasper_decoder.num_weights + jasper_encoder.num_weights}")
    logger.info('================================')

    audio_signal_e1, a_sig_length_e1, transcript_e1, transcript_len_e1 = \
        data_layer()
    processed_signal_e1, p_length_e1 = data_preprocessor(
        input_signal=audio_signal_e1, length=a_sig_length_e1)
    encoded_e1, encoded_len_e1 = jasper_encoder(
        audio_signal=processed_signal_e1, length=p_length_e1)
    log_probs_e1 = jasper_decoder(encoder_output=encoded_e1)
    predictions_e1 = greedy_decoder(log_probs=log_probs_e1)

    eval_tensors = [
        log_probs_e1, predictions_e1, transcript_e1, transcript_len_e1,
        encoded_len_e1
    ]

    if args.lm_path:
        beam_predictions_e1 = beam_search_with_lm(
            log_probs=log_probs_e1, log_probs_length=encoded_len_e1)
        eval_tensors.append(beam_predictions_e1)

    evaluated_tensors = neural_factory.infer(
        tensors=eval_tensors,
        checkpoint_dir=load_dir,
    )

    greedy_hypotheses = post_process_predictions(evaluated_tensors[1], vocab)
    references = post_process_transcripts(evaluated_tensors[2],
                                          evaluated_tensors[3], vocab)
    cer = word_error_rate(hypotheses=greedy_hypotheses,
                          references=references,
                          use_cer=True)
    logger.info("Greedy CER {:.2f}%".format(cer * 100))

    if args.lm_path:
        beam_hypotheses = []
        # Over mini-batch
        for i in evaluated_tensors[-1]:
            # Over samples
            for j in i:
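                # each sample holds a list of (score, transcript) beam
                # results; j[0][1] picks the transcript of the top beam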
                beam_hypotheses.append(j[0][1])

        cer = word_error_rate(hypotheses=beam_hypotheses,
                              references=references,
                              use_cer=True)
        logger.info("Beam CER {:.2f}".format(cer * 100))

    if args.save_logprob:
        # Convert logits to list of numpy arrays
        logprob = []
        for i, batch in enumerate(evaluated_tensors[0]):
            for j in range(batch.shape[0]):
                logprob.append(
                    batch[j][:evaluated_tensors[4][i][j], :].cpu().numpy())
        with open(args.save_logprob, 'wb') as f:
            pickle.dump(logprob, f, protocol=pickle.HIGHEST_PROTOCOL)
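
This script takes its labels from --vocab_file through a load_vocab helper that is not shown in the snippet. A plausible minimal sketch, assuming one label (character or token) per line:

def load_vocab(vocab_file):
    """Read one label per line into the list passed as `labels=`."""
    with open(vocab_file, encoding='utf-8') as f:
        return [line.rstrip('\n') for line in f]
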
Code Example #3
def main():
    parser = argparse.ArgumentParser(description='Jasper')
    # model params
    parser.add_argument("--model_config", type=str, required=True)
    parser.add_argument("--eval_datasets", type=str, required=True)
    parser.add_argument("--load_dir", type=str, required=True)
    parser.add_argument("--model_id", type=str, required=True) # mine
    # run params
    parser.add_argument("--local_rank", default=None, type=int)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--amp_opt_level", default="O0", type=str) # mine
    # store results
    parser.add_argument("--save_results", default=None, type=str) # mine

    # lm inference parameters
    parser.add_argument("--lm_path", default=None, type=str)
    parser.add_argument(
        '--alpha', default=2., type=float,
        help='value of LM weight',
        required=False)
    parser.add_argument(
        '--alpha_max', type=float,
        help='maximum value of LM weight (for a grid search in \'eval\' mode)',
        required=False)
    parser.add_argument(
        '--alpha_step', type=float,
        help='step for LM weight\'s tuning in \'eval\' mode',
        required=False, default=0.1)
    parser.add_argument(
        '--beta', default=1.5, type=float,
        help='value of word count weight',
        required=False)
    parser.add_argument(
        '--beta_max', type=float,
        help='maximum value of word count weight '
             '(for a grid search in \'eval\' mode)',
        required=False)
    parser.add_argument(
        '--beta_step', type=float,
        help='step for word count weight\'s tuning in \'eval\' mode',
        required=False, default=0.1)
    parser.add_argument(
        "--beam_width", default=128, type=int)

    args = parser.parse_args()
    batch_size = args.batch_size
    load_dir = args.load_dir

    if args.local_rank is not None:
        if args.lm_path:
            raise NotImplementedError(
                "Beam search decoder with LM does not currently support "
                "evaluation on multi-gpu.")
        device = nemo.core.DeviceType.AllGpu
    else:
        device = nemo.core.DeviceType.GPU

    # Instantiate Neural Factory with supported backend
    neural_factory = nemo.core.NeuralModuleFactory(
        backend=nemo.core.Backend.PyTorch,
        local_rank=args.local_rank,
        optimization_level=args.amp_opt_level,
        placement=device)
    logger = neural_factory.logger

    if args.local_rank is not None:
        logger.info('Doing ALL GPU')

    yaml = YAML(typ="safe")
    with open(args.model_config) as f:
        jasper_params = yaml.load(f)
    vocab = jasper_params['labels']
    sample_rate = jasper_params['sample_rate']

    # single eval dataset
    eval_datasets = args.eval_datasets
    eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
    del eval_dl_params["train"]
    del eval_dl_params["eval"]
    data_layer = nemo_asr.AudioToTextDataLayer(
        manifest_filepath=eval_datasets,
        sample_rate=sample_rate,
        labels=vocab,
        batch_size=batch_size,
        **eval_dl_params)

    N = len(data_layer)
    logger.info('Evaluating {0} examples'.format(N))

    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
        sample_rate=sample_rate,
        **jasper_params["AudioToMelSpectrogramPreprocessor"])
    jasper_encoder = nemo_asr.JasperEncoder(
        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
        **jasper_params["JasperEncoder"])
    jasper_decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
        num_classes=len(vocab))
    greedy_decoder = nemo_asr.GreedyCTCDecoder()

    logger.info('================================')
    logger.info(
        f"Number of parameters in encoder: {jasper_encoder.num_weights}")
    logger.info(
        f"Number of parameters in decoder: {jasper_decoder.num_weights}")
    logger.info(
        f"Total number of parameters: "
        f"{jasper_decoder.num_weights + jasper_encoder.num_weights}")
    logger.info('================================')

    # Define inference DAG
    audio_signal_e1, a_sig_length_e1, transcript_e1, transcript_len_e1 =\
        data_layer()
    processed_signal_e1, p_length_e1 = data_preprocessor(
        input_signal=audio_signal_e1,
        length=a_sig_length_e1)
    encoded_e1, encoded_len_e1 = jasper_encoder(
        audio_signal=processed_signal_e1,
        length=p_length_e1)
    log_probs_e1 = jasper_decoder(encoder_output=encoded_e1)
    predictions_e1 = greedy_decoder(log_probs=log_probs_e1)

    eval_tensors = [log_probs_e1, predictions_e1,
                    transcript_e1, transcript_len_e1, encoded_len_e1]

    # inference
    evaluated_tensors = neural_factory.infer(
        tensors=eval_tensors,
        checkpoint_dir=load_dir,
        cache=True
    )

    greedy_hypotheses = post_process_predictions(evaluated_tensors[1], vocab)
    references = post_process_transcripts(
        evaluated_tensors[2], evaluated_tensors[3], vocab)

    wer = word_error_rate(hypotheses=greedy_hypotheses, references=references)
    logger.info("Greedy WER {:.2f}%".format(wer*100))

    # language model
    if args.lm_path:
        if args.alpha_max is None:
            args.alpha_max = args.alpha
        # include alpha_max in tuning range
        args.alpha_max += args.alpha_step/10.0

        if args.beta_max is None:
            args.beta_max = args.beta
        # include beta_max in tuning range
        args.beta_max += args.beta_step/10.0

        beam_wers = []

        for alpha in np.arange(args.alpha, args.alpha_max, args.alpha_step):
            for beta in np.arange(args.beta, args.beta_max, args.beta_step):
                logger.info('================================')
                logger.info(f'Inferring with (alpha, beta): ({alpha}, {beta})')
                beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM(
                    vocab=vocab,
                    beam_width=args.beam_width,
                    alpha=alpha,
                    beta=beta,
                    lm_path=args.lm_path,
                    num_cpus=max(os.cpu_count(), 1))
                beam_predictions_e1 = beam_search_with_lm(
                    log_probs=log_probs_e1, log_probs_length=encoded_len_e1)

                evaluated_tensors = neural_factory.infer(
                    tensors=[beam_predictions_e1],
                    use_cache=True,
                    verbose=False
                )

                beam_hypotheses = []
                # Over mini-batch
                for i in evaluated_tensors[-1]:
                    # Over samples
                    for j in i:
                        beam_hypotheses.append(j[0][1])
                lm_wer = word_error_rate(
                    hypotheses=beam_hypotheses, references=references)
                logger.info("Beam WER {:.2f}%".format(lm_wer*100))
                beam_wers.append(((alpha, beta), lm_wer*100))

        logger.info('Beam WER for (alpha, beta)')
        logger.info('================================')
        logger.info('\n' + '\n'.join([str(e) for e in beam_wers]))
        logger.info('================================')
        best_beam_wer = min(beam_wers, key=lambda x: x[1])
        logger.info('Best (alpha, beta): '
                    f'{best_beam_wer[0]}, '
                    f'WER: {best_beam_wer[1]:.2f}%')


    # save results
    if args.save_results:
        selected_dataset = args.eval_datasets
        results = {
            "model_id": args.model_id,
            "dataset": selected_dataset,
            "wer": wer,
            "transcript": ' '.join(greedy_hypotheses),
            "gtruth": ' '.join(references)
        }
        if args.lm_path:
            results['alpha-beta'] = best_beam_wer[0]
            results['beam transcript'] = ' '.join(beam_hypotheses)
            results['lm_wer'] = best_beam_wer[1]/100
        else:
            results['lm_wer'] = None

        mkdir_p(args.save_results)

        dataset_name = selected_dataset.split("/")[-1].split(".")[0]
        model_name = args.model_id
        inf_type = "lm" if args.lm_path else "am"
        if args.alpha_step:
            inf_type = inf_type + "_grid"
        filename = os.path.join(args.save_results,
                                "results-" + inf_type + "__"
                                + dataset_name + "__" + model_name + ".json")
        logger.info("Saving inference results to {}".format(filename))
        with open(filename, "w") as out_file:
            json.dump(results, out_file)
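
The save path above is prepared with mkdir_p, which is also not defined in the snippet; a standard sketch of such a helper:

import os

def mkdir_p(path):
    """`mkdir -p` equivalent: create the directory tree if it is missing."""
    os.makedirs(path, exist_ok=True)
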
Code Example #4
File: garnet.py  Project: yonefx/NeMo
def create_dag(args, cfg, logger, num_gpus):

    # Defining nodes
    data = nemo_asr.AudioToTextDataLayer(
        manifest_filepath=args.train_dataset,
        labels=cfg['target']['labels'],
        batch_size=cfg['optimization']['batch_size'],
        eos_id=cfg['target']['eos_id'],
        **cfg['AudioToTextDataLayer']['train']
    )
    data_evals = []
    if args.eval_datasets:
        for val_path in args.eval_datasets:
            data_evals.append(nemo_asr.AudioToTextDataLayer(
                manifest_filepath=val_path,
                labels=cfg['target']['labels'],
                batch_size=cfg['inference']['batch_size'],
                eos_id=cfg['target']['eos_id'],
                **cfg['AudioToTextDataLayer']['eval']
            ))
    else:
        logger.info("There were no val datasets passed")
    data_preprocessor = nemo_asr.AudioPreprocessing(
        **cfg['AudioPreprocessing']
    )
    data_augmentation = nemo_asr.SpectrogramAugmentation(
        **cfg['SpectrogramAugmentation']
    )
    encoder = nemo_asr.JasperEncoder(
        feat_in=cfg["AudioPreprocessing"]["features"],
        **cfg['JasperEncoder']
    )
    if args.encoder_checkpoint is not None \
            and os.path.exists(args.encoder_checkpoint):
        if cfg['JasperEncoder']['load']:
            encoder.restore_from(args.encoder_checkpoint, args.local_rank)
            logger.info(f'Loaded weights for encoder'
                        f' from {args.encoder_checkpoint}')
        if cfg['JasperEncoder']['freeze']:
            encoder.freeze()
            logger.info('Froze encoder weights')
    connector = nemo_asr.JasperRNNConnector(
        in_channels=cfg['JasperEncoder']['jasper'][-1]['filters'],
        out_channels=cfg['DecoderRNN']['hidden_size']
    )
    decoder = nemo.backends.pytorch.DecoderRNN(
        voc_size=len(cfg['target']['labels']),
        bos_id=cfg['target']['bos_id'],
        **cfg['DecoderRNN']
    )
    if args.decoder_checkpoint is not None \
            and os.path.exists(args.decoder_checkpoint):
        if cfg['DecoderRNN']['load']:
            decoder.restore_from(args.decoder_checkpoint, args.local_rank)
            logger.info(f'Loaded weights for decoder'
                        f' from {args.decoder_checkpoint}')
        if cfg['DecoderRNN']['freeze']:
            decoder.freeze()
            logger.info('Froze decoder weights')
            if cfg['decoder']['unfreeze_attn']:
                for name, param in decoder.attention.named_parameters():
                    param.requires_grad = True
                logger.info('Unfroze decoder attn weights')
    num_data = len(data)
    batch_size = cfg['optimization']['batch_size']
    num_epochs = cfg['optimization']['params']['num_epochs']
    steps_per_epoch = int(num_data / (batch_size * num_gpus))
    total_steps = num_epochs * steps_per_epoch
    vsc = ValueSetterCallback
    tf_callback = ValueSetterCallback(
        decoder, 'teacher_forcing',
        policies=[
            vsc.Policy(vsc.Method.Const(1.0), start=0.0, end=1.0)
        ],
        total_steps=total_steps
    )
    seq_loss = nemo.backends.pytorch.SequenceLoss(
        pad_id=cfg['target']['pad_id'],
        smoothing_coef=cfg['optimization']['smoothing_coef'],
        sample_wise=cfg['optimization']['sample_wise']
    )
    se_callback = ValueSetterCallback(
        seq_loss, 'smoothing_coef',
        policies=[
            vsc.Policy(
                vsc.Method.Const(seq_loss.smoothing_coef),
                start=0.0, end=1.0
            ),
        ],
        total_steps=total_steps
    )
    beam_search = nemo.backends.pytorch.BeamSearch(
        decoder=decoder,
        pad_id=cfg['target']['pad_id'],
        bos_id=cfg['target']['bos_id'],
        eos_id=cfg['target']['eos_id'],
        max_len=cfg['target']['max_len'],
        beam_size=cfg['inference']['beam_size']
    )
    uf_callback = UnfreezeCallback(
        [encoder, decoder],
        start_epoch=cfg['optimization']['start_unfreeze']
    )
    saver_callback = nemo.core.ModuleSaverCallback(
        save_modules_list=[encoder, connector, decoder],
        folder=args.checkpoint_dir,
        step_freq=args.eval_freq
    )

    # Creating DAG
    audios, audio_lens, transcripts, _ = data()
    processed_audios, processed_audio_lens = data_preprocessor(
        input_signal=audios,
        length=audio_lens
    )
    augmented_spec = data_augmentation(input_spec=processed_audios)
    encoded, _ = encoder(
        audio_signal=augmented_spec,
        length=processed_audio_lens
    )
    encoded = connector(tensor=encoded)
    log_probs, _ = decoder(
        targets=transcripts,
        encoder_outputs=encoded
    )
    train_loss = seq_loss(
        log_probs=log_probs,
        targets=transcripts
    )
    evals = []
    for i, data_eval in enumerate(data_evals):
        audios, audio_lens, transcripts, _ = data_eval()
        processed_audios, processed_audio_lens = data_preprocessor(
            input_signal=audios,
            length=audio_lens
        )
        encoded, _ = encoder(
            audio_signal=processed_audios,
            length=processed_audio_lens
        )
        encoded = connector(tensor=encoded)
        log_probs, _ = decoder(
            targets=transcripts,
            encoder_outputs=encoded
        )
        loss = seq_loss(
            log_probs=log_probs,
            targets=transcripts
        )
        predictions, aw = beam_search(encoder_outputs=encoded)
        evals.append((args.eval_datasets[i],
                     (loss, log_probs, transcripts, predictions, aw)))

    # Update config
    cfg['num_params'] = {
        'encoder': encoder.num_weights,
        'connector': connector.num_weights,
        'decoder': decoder.num_weights
    }
    cfg['num_params']['total'] = sum(cfg['num_params'].values())
    cfg['input']['train'] = {'num_data': num_data}
    cfg['optimization']['steps_per_epoch'] = steps_per_epoch
    cfg['optimization']['total_steps'] = total_steps

    return (train_loss, evals), cfg, [tf_callback, se_callback,
                                      uf_callback, saver_callback]
Code Example #5
def main():
    parser = argparse.ArgumentParser(description='Jasper')
    parser.add_argument("--local_rank", default=None, type=int)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--model_config", type=str, required=True)
    parser.add_argument("--eval_datasets", type=str, required=True)
    parser.add_argument("--load_dir", type=str, required=True)
    parser.add_argument("--save_logprob", default=None, type=str)
    parser.add_argument("--lm_path", default=None, type=str)
    parser.add_argument('--alpha',
                        default=2.,
                        type=float,
                        help='value of LM weight',
                        required=False)
    parser.add_argument(
        '--alpha_max',
        type=float,
        help='maximum value of LM weight (for a grid search in \'eval\' mode)',
        required=False)
    parser.add_argument('--alpha_step',
                        type=float,
                        help='step for LM weight\'s tuning in \'eval\' mode',
                        required=False,
                        default=0.1)
    parser.add_argument('--beta',
                        default=1.5,
                        type=float,
                        help='value of word count weight',
                        required=False)
    parser.add_argument(
        '--beta_max',
        type=float,
        help='maximum value of word count weight '
             '(for a grid search in \'eval\' mode)',
        required=False)
    parser.add_argument(
        '--beta_step',
        type=float,
        help='step for word count weight\'s tuning in \'eval\' mode',
        required=False,
        default=0.1)
    parser.add_argument("--beam_width", default=128, type=int)

    args = parser.parse_args()
    batch_size = args.batch_size
    load_dir = args.load_dir

    if args.local_rank is not None:
        if args.lm_path:
            raise NotImplementedError(
                "Beam search decoder with LM does not currently support "
                "evaluation on multi-gpu.")
        device = nemo.core.DeviceType.AllGpu
    else:
        device = nemo.core.DeviceType.GPU

    # Instantiate Neural Factory with supported backend
    neural_factory = nemo.core.NeuralModuleFactory(
        backend=nemo.core.Backend.PyTorch,
        local_rank=args.local_rank,
        optimization_level=nemo.core.Optimization.mxprO1,
        placement=device)
    logger = neural_factory.logger

    if args.local_rank is not None:
        logger.info('Doing ALL GPU')

    yaml = YAML(typ="safe")
    with open(args.model_config) as f:
        jasper_params = yaml.load(f)
    vocab = jasper_params['labels']
    sample_rate = jasper_params['sample_rate']

    eval_datasets = args.eval_datasets

    eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
    del eval_dl_params["train"]
    del eval_dl_params["eval"]
    data_layer = nemo_asr.AudioToTextDataLayer(manifest_filepath=eval_datasets,
                                               sample_rate=sample_rate,
                                               labels=vocab,
                                               batch_size=batch_size,
                                               **eval_dl_params)

    N = len(data_layer)
    logger.info('Evaluating {0} examples'.format(N))

    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
        sample_rate=sample_rate,
        **jasper_params["AudioToMelSpectrogramPreprocessor"])
    jasper_encoder = nemo_asr.JasperEncoder(
        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
        **jasper_params["JasperEncoder"])
    jasper_decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
        num_classes=len(vocab))
    greedy_decoder = nemo_asr.GreedyCTCDecoder()

    logger.info('================================')
    logger.info(
        f"Number of parameters in encoder: {jasper_encoder.num_weights}")
    logger.info(
        f"Number of parameters in decoder: {jasper_decoder.num_weights}")
    logger.info(f"Total number of parameters in model: "
                f"{jasper_decoder.num_weights + jasper_encoder.num_weights}")
    logger.info('================================')

    audio_signal_e1, a_sig_length_e1, transcript_e1, transcript_len_e1 =\
        data_layer()
    processed_signal_e1, p_length_e1 = data_preprocessor(
        input_signal=audio_signal_e1, length=a_sig_length_e1)
    encoded_e1, encoded_len_e1 = jasper_encoder(
        audio_signal=processed_signal_e1, length=p_length_e1)
    log_probs_e1 = jasper_decoder(encoder_output=encoded_e1)
    predictions_e1 = greedy_decoder(log_probs=log_probs_e1)

    eval_tensors = [
        log_probs_e1, predictions_e1, transcript_e1, transcript_len_e1,
        encoded_len_e1
    ]

    evaluated_tensors = neural_factory.infer(tensors=eval_tensors,
                                             checkpoint_dir=load_dir,
                                             cache=True)

    greedy_hypotheses = post_process_predictions(evaluated_tensors[1], vocab)
    references = post_process_transcripts(evaluated_tensors[2],
                                          evaluated_tensors[3], vocab)
    wer = word_error_rate(hypotheses=greedy_hypotheses, references=references)
    logger.info("Greedy WER {:.2f}%".format(wer * 100))

    if args.lm_path:
        if args.alpha_max is None:
            args.alpha_max = args.alpha
        # include alpha_max in tuning range
        args.alpha_max += args.alpha_step / 10.0

        if args.beta_max is None:
            args.beta_max = args.beta
        # include beta_max in tuning range
        args.beta_max += args.beta_step / 10.0

        beam_wers = []

        for alpha in np.arange(args.alpha, args.alpha_max, args.alpha_step):
            for beta in np.arange(args.beta, args.beta_max, args.beta_step):
                logger.info('================================')
                logger.info(f'Inferring with (alpha, beta): ({alpha}, {beta})')
                beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM(
                    vocab=vocab,
                    beam_width=args.beam_width,
                    alpha=alpha,
                    beta=beta,
                    lm_path=args.lm_path,
                    num_cpus=max(os.cpu_count(), 1))
                beam_predictions_e1 = beam_search_with_lm(
                    log_probs=log_probs_e1, log_probs_length=encoded_len_e1)

                # keep a separate name so the greedy results in
                # `evaluated_tensors` remain available for --save_logprob below
                beam_tensors = neural_factory.infer(
                    tensors=[beam_predictions_e1],
                    use_cache=True,
                    verbose=False)

                beam_hypotheses = []
                # Over mini-batch
                for i in beam_tensors[-1]:
                    # Over samples
                    for j in i:
                        beam_hypotheses.append(j[0][1])

                wer = word_error_rate(hypotheses=beam_hypotheses,
                                      references=references)
                logger.info("Beam WER {:.2f}%".format(wer * 100))
                beam_wers.append(((alpha, beta), wer * 100))

        logger.info('Beam WER for (alpha, beta)')
        logger.info('================================')
        logger.info('\n' + '\n'.join([str(e) for e in beam_wers]))
        logger.info('================================')
        best_beam_wer = min(beam_wers, key=lambda x: x[1])
        logger.info('Best (alpha, beta): '
                    f'{best_beam_wer[0]}, '
                    f'WER: {best_beam_wer[1]:.2f}%')

    if args.save_logprob:
        # Convert logits to list of numpy arrays
        logprob = []
        for i, batch in enumerate(evaluated_tensors[0]):
            for j in range(batch.shape[0]):
                logprob.append(
                    batch[j][:evaluated_tensors[4][i][j], :].cpu().numpy())
        with open(args.save_logprob, 'wb') as f:
            pickle.dump(logprob, f, protocol=pickle.HIGHEST_PROTOCOL)
Code Example #6
def create_all_dags(args, neural_factory):
    logger = neural_factory.logger
    yaml = YAML(typ="safe")
    with open(args.model_config) as f:
        jasper_params = yaml.load(f)
    vocab = jasper_params['labels']
    sample_rate = jasper_params['sample_rate']

    # Calculate num_workers for dataloader
    total_cpus = os.cpu_count()
    cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)

    # perturb_config = jasper_params.get('perturb', None)
    train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
    del train_dl_params["train"]
    del train_dl_params["eval"]
    # del train_dl_params["normalize_transcripts"]

    data_layer = nemo_asr.AudioToTextDataLayer(
        manifest_filepath=args.train_dataset,
        sample_rate=sample_rate,
        labels=vocab,
        batch_size=args.batch_size,
        num_workers=cpu_per_traindl,
        **train_dl_params,
        # normalize_transcripts=False
    )

    N = len(data_layer)
    steps_per_epoch = math.ceil(
        N / (args.batch_size * args.iter_per_step * args.num_gpus))
    logger.info('Have {0} examples to train on.'.format(N))

    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
        sample_rate=sample_rate,
        **jasper_params["AudioToMelSpectrogramPreprocessor"])

    multiply_batch_config = jasper_params.get('MultiplyBatch', None)
    if multiply_batch_config:
        multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)

    spectr_augment_config = jasper_params.get('SpectrogramAugmentation', None)
    if spectr_augment_config:
        data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
            **spectr_augment_config)

    eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
    eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
    del eval_dl_params["train"]
    del eval_dl_params["eval"]
    data_layers_eval = []

    if args.eval_datasets:
        for eval_datasets in args.eval_datasets:
            data_layer_eval = nemo_asr.AudioToTextDataLayer(
                manifest_filepath=eval_datasets,
                sample_rate=sample_rate,
                labels=vocab,
                batch_size=args.eval_batch_size,
                num_workers=cpu_per_traindl,
                **eval_dl_params,
            )

            data_layers_eval.append(data_layer_eval)
    else:
        neural_factory.logger.info("There were no val datasets passed")

    jasper_encoder = nemo_asr.JasperEncoder(
        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
        **jasper_params["JasperEncoder"])

    jasper_decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
        num_classes=len(vocab),
        factory=neural_factory)

    ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))

    greedy_decoder = nemo_asr.GreedyCTCDecoder()

    logger.info('================================')
    logger.info(
        f"Number of parameters in encoder: {jasper_encoder.num_weights}")
    logger.info(
        f"Number of parameters in decoder: {jasper_decoder.num_weights}")
    logger.info(f"Total number of parameters in model: "
                f"{jasper_decoder.num_weights + jasper_encoder.num_weights}")
    logger.info('================================')

    # Train DAG
    audio_signal_t, a_sig_length_t, \
        transcript_t, transcript_len_t = data_layer()
    processed_signal_t, p_length_t = data_preprocessor(
        input_signal=audio_signal_t, length=a_sig_length_t)

    if multiply_batch_config:
        processed_signal_t, p_length_t, transcript_t, transcript_len_t = \
            multiply_batch(
                in_x=processed_signal_t, in_x_len=p_length_t,
                in_y=transcript_t,
                in_y_len=transcript_len_t)

    if spectr_augment_config:
        processed_signal_t = data_spectr_augmentation(
            input_spec=processed_signal_t)

    encoded_t, encoded_len_t = jasper_encoder(audio_signal=processed_signal_t,
                                              length=p_length_t)
    log_probs_t = jasper_decoder(encoder_output=encoded_t)
    predictions_t = greedy_decoder(log_probs=log_probs_t)
    loss_t = ctc_loss(log_probs=log_probs_t,
                      targets=transcript_t,
                      input_length=encoded_len_t,
                      target_length=transcript_len_t)

    # Callbacks needed to print info to console and Tensorboard
    train_callback = nemo.core.SimpleLossLoggerCallback(
        tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
        print_func=partial(monitor_asr_train_progress,
                           labels=vocab,
                           logger=logger),
        get_tb_values=lambda x: [("loss", x[0])],
        tb_writer=neural_factory.tb_writer,
    )

    chpt_callback = nemo.core.CheckpointCallback(
        folder=neural_factory.checkpoint_dir,
        load_from_folder=args.load_dir,
        step_freq=args.checkpoint_save_freq)

    callbacks = [train_callback, chpt_callback]

    # assemble eval DAGs
    for i, eval_dl in enumerate(data_layers_eval):
        audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e = \
            eval_dl()
        processed_signal_e, p_length_e = data_preprocessor(
            input_signal=audio_signal_e, length=a_sig_length_e)
        encoded_e, encoded_len_e = jasper_encoder(
            audio_signal=processed_signal_e, length=p_length_e)
        log_probs_e = jasper_decoder(encoder_output=encoded_e)
        predictions_e = greedy_decoder(log_probs=log_probs_e)
        loss_e = ctc_loss(log_probs=log_probs_e,
                          targets=transcript_e,
                          input_length=encoded_len_e,
                          target_length=transcript_len_e)

        # create corresponding eval callback
        tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
        eval_callback = nemo.core.EvaluatorCallback(
            eval_tensors=[
                loss_e, predictions_e, transcript_e, transcript_len_e
            ],
            user_iter_callback=partial(process_evaluation_batch, labels=vocab),
            user_epochs_done_callback=partial(process_evaluation_epoch,
                                              tag=tagname,
                                              logger=logger),
            eval_step=args.eval_freq,
            tb_writer=neural_factory.tb_writer)

        callbacks.append(eval_callback)
    return loss_t, callbacks, steps_per_epoch
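
A typical caller of create_all_dags then launches training along these lines (a sketch modeled on the NeMo 0.x Jasper training scripts; argument names such as args.num_epochs, args.lr and args.warmup_steps are assumptions):

from nemo.utils.lr_policies import CosineAnnealing

train_loss, callbacks, steps_per_epoch = create_all_dags(args, neural_factory)

neural_factory.train(
    tensors_to_optimize=[train_loss],
    callbacks=callbacks,
    lr_policy=CosineAnnealing(args.num_epochs * steps_per_epoch,
                              warmup_steps=args.warmup_steps),
    optimizer="novograd",
    optimization_params={
        "num_epochs": args.num_epochs,
        "lr": args.lr,
        "weight_decay": args.weight_decay,
    },
    batches_per_step=args.iter_per_step)
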
Code Example #7
def offline_inference(config, encoder, decoder, audio_file):
  MODEL_YAML = config
  CHECKPOINT_ENCODER = encoder
  CHECKPOINT_DECODER = decoder
  # `wave` is assumed to be scipy.io.wavfile (e.g. imported as
  # `from scipy.io import wavfile as wave`); the stdlib wave module
  # has no read() returning (rate, data)
  sample_rate, signal = wave.read(audio_file)

  # get labels (vocab)
  yaml = YAML(typ="safe")
  with open(MODEL_YAML) as f:
    jasper_model_definition = yaml.load(f)
  labels = jasper_model_definition['labels']

  # build neural factory and neural modules
  neural_factory = nemo.core.NeuralModuleFactory(
    placement=nemo.core.DeviceType.GPU,
    backend=nemo.core.Backend.PyTorch)
  data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
    factory=neural_factory,
    **jasper_model_definition["AudioToMelSpectrogramPreprocessor"])

  jasper_encoder = nemo_asr.JasperEncoder(
    feat_in=jasper_model_definition["AudioToMelSpectrogramPreprocessor"]["features"],
    **jasper_model_definition["JasperEncoder"])

  jasper_decoder = nemo_asr.JasperDecoderForCTC(
    feat_in=jasper_model_definition["JasperEncoder"]["jasper"][-1]["filters"],
    num_classes=len(labels))

  greedy_decoder = nemo_asr.GreedyCTCDecoder()

  # load model
  jasper_encoder.restore_from(CHECKPOINT_ENCODER)
  jasper_decoder.restore_from(CHECKPOINT_DECODER)

  # AudioDataLayer
  class AudioDataLayer(DataLayerNM):
    @staticmethod
    def create_ports():
      input_ports = {}
      output_ports = {
        "audio_signal": NeuralType({0: AxisType(BatchTag),
                                    1: AxisType(TimeTag)}),

        "a_sig_length": NeuralType({0: AxisType(BatchTag)}),
      }
      return input_ports, output_ports

    def __init__(self, **kwargs):
      DataLayerNM.__init__(self, **kwargs)
      self.output_enable = False

    def __iter__(self):
      return self

    def __next__(self):
      if not self.output_enable:
        raise StopIteration
      self.output_enable = False
      return torch.as_tensor(self.signal, dtype=torch.float32), \
            torch.as_tensor(self.signal_shape, dtype=torch.int64)

    def set_signal(self, signal):
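      # scale 16-bit PCM to [-1, 1) floats and add a batch dimension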
      self.signal = np.reshape(signal.astype(np.float32)/32768., [1, -1])
      self.signal_shape = np.expand_dims(self.signal.size, 0).astype(np.int64)
      self.output_enable = True

    def __len__(self):
      return 1

    @property
    def dataset(self):
      return None

    @property
    def data_iterator(self):
      return self

  # Instantiate necessary neural modules
  data_layer = AudioDataLayer()

  # Define inference DAG
  audio_signal, audio_signal_len = data_layer()
  processed_signal, processed_signal_len = data_preprocessor(
    input_signal=audio_signal,
    length=audio_signal_len)
  encoded, encoded_len = jasper_encoder(audio_signal=processed_signal,
                                        length=processed_signal_len)
  log_probs = jasper_decoder(encoder_output=encoded)
  predictions = greedy_decoder(log_probs=log_probs)

  # audio inference
  data_layer.set_signal(signal)

  tensors = neural_factory.infer([
    audio_signal,
    processed_signal,
    encoded,
    log_probs,
    predictions], verbose=False)

  # results
  audio = tensors[0][0][0].cpu().numpy()
  features = tensors[1][0][0].cpu().numpy()
  encoded_features = tensors[2][0][0].cpu().numpy()
  probs = tensors[3][0][0].cpu().numpy()
  preds = tensors[4][0]
  transcript = post_process_predictions([preds], labels)

  return transcript, audio, features, encoded_features, probs, preds
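
A hypothetical call to offline_inference (all file paths are placeholders):

transcript, audio, features, encoded_features, probs, preds = offline_inference(
    config='quartznet15x5.yaml',
    encoder='JasperEncoder-STEP-243800.pt',
    decoder='JasperDecoderForCTC-STEP-243800.pt',
    audio_file='sample.wav')
print(transcript)
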
Code Example #8
File: views.py  Project: aymanehachcham/ASR_TTS
def convert(request):
    """
    ** Create new sound recognation object by convert audio to text .


    ** Use Case Exemple of Post:

                {

                    "audio":"base64 format",

                }



    """
    import json
    from ruamel.yaml import YAML
    import nemo
    import nemo_asr
    import IPython.display as ipd
    MODEL_YAML = "/home/docker/app/ai_models/quartznet15x5.yaml"
    CHECKPOINT_ENCODER = "/home/docker/app/ai_models/JasperEncoder-STEP-243800.pt"
    CHECKPOINT_DECODER = "/home/docker/app/ai_models/JasperDecoderForCTC-STEP-243800.pt"
    ENABLE_NGRAM = False
    yaml = YAML(typ="safe")
    with open(MODEL_YAML) as f:
        jasper_model_definition = yaml.load(f)
    labels = jasper_model_definition['labels']
    neural_factory = nemo.core.NeuralModuleFactory(
        placement=nemo.core.DeviceType.CPU, backend=nemo.core.Backend.PyTorch)
    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
        factory=neural_factory)
    jasper_encoder = nemo_asr.JasperEncoder(
        jasper=jasper_model_definition['JasperEncoder']['jasper'],
        activation=jasper_model_definition['JasperEncoder']['activation'],
        feat_in=jasper_model_definition['AudioToMelSpectrogramPreprocessor']
        ['features'])
    jasper_encoder.restore_from(CHECKPOINT_ENCODER, local_rank=0)
    jasper_decoder = nemo_asr.JasperDecoderForCTC(feat_in=1024,
                                                  num_classes=len(labels))
    jasper_decoder.restore_from(CHECKPOINT_DECODER, local_rank=0)
    greedy_decoder = nemo_asr.GreedyCTCDecoder()

    def wav_to_text(manifest, greedy=True):
        from ruamel.yaml import YAML
        yaml = YAML(typ="safe")
        with open(MODEL_YAML) as f:
            jasper_model_definition = yaml.load(f)
        labels = jasper_model_definition['labels']
        data_layer = nemo_asr.AudioToTextDataLayer(shuffle=False,
                                                   manifest_filepath=manifest,
                                                   labels=labels,
                                                   batch_size=1)
        audio_signal, audio_signal_len, _, _ = data_layer()
        processed_signal, processed_signal_len = data_preprocessor(
            input_signal=audio_signal, length=audio_signal_len)
        encoded, encoded_len = jasper_encoder(audio_signal=processed_signal,
                                              length=processed_signal_len)
        log_probs = jasper_decoder(encoder_output=encoded)
        predictions = greedy_decoder(log_probs=log_probs)

        if ENABLE_NGRAM:
            print('Running with beam search')
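            # NOTE: beam_search_with_lm is never defined in this view; it
            # would need to be built (e.g. with
            # nemo_asr.BeamSearchDecoderWithLM) before ENABLE_NGRAM is
            # switched on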
            beam_predictions = beam_search_with_lm(
                log_probs=log_probs, log_probs_length=encoded_len)
            eval_tensors = [beam_predictions]

        if greedy:
            eval_tensors = [predictions]

        tensors = neural_factory.infer(tensors=eval_tensors)
        if greedy:
            from nemo_asr.helpers import post_process_predictions
            prediction = post_process_predictions(tensors[0], labels)
        else:
            prediction = tensors[0][0][0][0][1]
        return prediction

    def create_manifest(file_path):
        # create manifest
        manifest = dict()
        manifest['audio_filepath'] = file_path
        manifest['duration'] = 18000
        manifest['text'] = 'todo'
        with open(file_path + ".json", 'w') as fout:
            fout.write(json.dumps(manifest))
        return file_path + ".json"

    data = request.FILES['audio']
    path = "media/" + data.name
    audio = Song.objects.create(audio_file=data)
    transcription = wav_to_text(create_manifest(audio.audio_file.path))

    return Response({'Output': transcription})
Code Example #9
def main(config_file,
         nn_encoder,
         nn_decoder,
         nn_onnx_encoder,
         nn_onnx_decoder,
         pre_v09_model=False,
         batch_size=1,
         time_steps=256):
    yaml = YAML(typ="safe")

    print("Loading config file...")
    with open(config_file) as f:
        jasper_model_definition = yaml.load(f)

    print("Determining model shape...")
    if 'AudioPreprocessing' in jasper_model_definition:
        num_encoder_input_features = \
            jasper_model_definition['AudioPreprocessing']['features']
    elif 'AudioToMelSpectrogramPreprocessor' in jasper_model_definition:
        num_encoder_input_features = \
            jasper_model_definition['AudioToMelSpectrogramPreprocessor'][
                'features']
    else:
        num_encoder_input_features = 64
    num_decoder_input_features = \
        jasper_model_definition['JasperEncoder']['jasper'][-1]['filters']
    print(
        "  Num encoder input features: {}".format(num_encoder_input_features))
    print(
        "  Num decoder input features: {}".format(num_decoder_input_features))

    print("Initializing models...")
    jasper_encoder = nemo_asr.JasperEncoder(
        feat_in=num_encoder_input_features,
        **jasper_model_definition['JasperEncoder'])
    jasper_decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=num_decoder_input_features,
        num_classes=len(jasper_model_definition['labels']))

    # This is necessary if you are using checkpoints trained with NeMo
    # version before 0.9
    print("Loading checkpoints...")
    if pre_v09_model:
        print("  Converting pre v0.9 checkpoint...")
        ckpt = torch.load(nn_encoder)
        new_ckpt = {}
        for k, v in ckpt.items():
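            # remap pre-0.9 parameter names: '.conv.' became '.mconv.',
            # and 3-D conv weights moved one level down under '.conv.'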
            new_k = k.replace('.conv.', '.mconv.')
            if len(v.shape) == 3:
                new_k = new_k.replace('.weight', '.conv.weight')
            new_ckpt[new_k] = v
        jasper_encoder.load_state_dict(new_ckpt)
    else:
        jasper_encoder.restore_from(nn_encoder)
    jasper_decoder.restore_from(nn_decoder)

    nf = nemo.core.NeuralModuleFactory(create_tb_writer=False)
    print("Exporting encoder...")
    nf.deployment_export(
        jasper_encoder, nn_onnx_encoder,
        nemo.core.neural_factory.DeploymentFormat.ONNX,
        torch.zeros(batch_size,
                    num_encoder_input_features,
                    time_steps,
                    dtype=torch.float,
                    device="cuda:0"))
    print("Exporting decoder...")
    nf.deployment_export(jasper_decoder, nn_onnx_decoder,
                         nemo.core.neural_factory.DeploymentFormat.ONNX,
                         (torch.zeros(batch_size,
                                      num_decoder_input_features,
                                      time_steps // 2,
                                      dtype=torch.float,
                                      device="cuda:0")))
    print("Export completed successfully.")
Code Example #10
def create_all_dags(args, neural_factory):
    '''
    Creates the train and eval DAGs as well as their callbacks.
    Returns the train loss tensor, callbacks, and steps per epoch.'''

    # parse the config files
    yaml = YAML(typ="safe")
    with open(args.model_config) as f:
        quartz_params = yaml.load(f)

    vocab = quartz_params['labels']
    sample_rate = quartz_params['sample_rate']

    # Calculate num_workers for dataloader
    total_cpus = os.cpu_count()
    cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)

    # create data layer for training
    train_dl_params = copy.deepcopy(quartz_params["AudioToTextDataLayer"])
    train_dl_params.update(quartz_params["AudioToTextDataLayer"]["train"])
    del train_dl_params["train"]
    del train_dl_params["eval"]
    # del train_dl_params["normalize_transcripts"]

    data_layer_train = nemo_asr.AudioToTextDataLayer(
        manifest_filepath=args.train_dataset,
        sample_rate=sample_rate,
        labels=vocab,
        batch_size=args.batch_size,
        num_workers=cpu_per_traindl,
        **train_dl_params,
        # normalize_transcripts=False
    )

    N = len(data_layer_train)
    steps_per_epoch = int(N / (args.batch_size * args.num_gpus))

    # create separate data layers for eval
    # we need separate eval dags for separate eval datasets
    # but all other modules in these dags will be shared

    eval_dl_params = copy.deepcopy(quartz_params["AudioToTextDataLayer"])
    eval_dl_params.update(quartz_params["AudioToTextDataLayer"]["eval"])
    del eval_dl_params["train"]
    del eval_dl_params["eval"]

    data_layers_eval = []
    if args.eval_datasets:
        for eval_dataset in args.eval_datasets:
            data_layer_eval = nemo_asr.AudioToTextDataLayer(
                manifest_filepath=eval_dataset,
                sample_rate=sample_rate,
                labels=vocab,
                batch_size=args.eval_batch_size,
                num_workers=cpu_per_traindl,
                **eval_dl_params,
            )

            data_layers_eval.append(data_layer_eval)
    else:
        neural_factory.logger.info("There were no val datasets passed")

    # create shared modules

    data_preprocessor = nemo_asr.AudioPreprocessing(
        sample_rate=sample_rate,
        **quartz_params["AudioPreprocessing"])

    # (QuartzNet uses the Jasper baseline encoder and decoder)
    encoder = nemo_asr.JasperEncoder(
        feat_in=quartz_params["AudioPreprocessing"]["features"],
        **quartz_params["JasperEncoder"])

    decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=quartz_params["JasperEncoder"]["jasper"][-1]["filters"],
        num_classes=len(vocab))

    ctc_loss = nemo_asr.CTCLossNM(
        num_classes=len(vocab))

    greedy_decoder = nemo_asr.GreedyCTCDecoder()

    # create augmentation modules (only used for training) if their configs
    # are present

    multiply_batch_config = quartz_params.get('MultiplyBatch', None)
    if multiply_batch_config:
        multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)

    spectr_augment_config = quartz_params.get('SpectrogramAugmentation', None)
    if spectr_augment_config:
        data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
            **spectr_augment_config)

    # assemble train DAG

    audio_signal_t, a_sig_length_t, \
        transcript_t, transcript_len_t = data_layer_train()

    processed_signal_t, p_length_t = data_preprocessor(
        input_signal=audio_signal_t,
        length=a_sig_length_t)

    if multiply_batch_config:
        processed_signal_t, p_length_t, transcript_t, transcript_len_t = \
            multiply_batch(
                in_x=processed_signal_t, in_x_len=p_length_t,
                in_y=transcript_t,
                in_y_len=transcript_len_t)

    if spectr_augment_config:
        processed_signal_t = data_spectr_augmentation(
            input_spec=processed_signal_t)

    encoded_t, encoded_len_t = encoder(
        audio_signal=processed_signal_t,
        length=p_length_t)
    log_probs_t = decoder(encoder_output=encoded_t)
    predictions_t = greedy_decoder(log_probs=log_probs_t)
    loss_t = ctc_loss(
        log_probs=log_probs_t,
        targets=transcript_t,
        input_length=encoded_len_t,
        target_length=transcript_len_t)

    # create train callbacks
    train_callback = nemo.core.SimpleLossLoggerCallback(
        tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
        print_func=partial(
            monitor_asr_train_progress,
            labels=vocab,
            logger=neural_factory.logger),
        get_tb_values=lambda x: [["loss", x[0]]],
        tb_writer=neural_factory.tb_writer)

    callbacks = [train_callback]

    if args.checkpoint_dir or args.load_dir:
        chpt_callback = nemo.core.CheckpointCallback(
            folder=args.checkpoint_dir,
            load_from_folder=args.load_dir,
            step_freq=args.checkpoint_save_freq)

        callbacks.append(chpt_callback)

    # assemble eval DAGs
    for i, eval_dl in enumerate(data_layers_eval):

        audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e = \
            eval_dl()
        processed_signal_e, p_length_e = data_preprocessor(
            input_signal=audio_signal_e,
            length=a_sig_length_e)
        encoded_e, encoded_len_e = encoder(
            audio_signal=processed_signal_e,
            length=p_length_e)
        log_probs_e = decoder(encoder_output=encoded_e)
        predictions_e = greedy_decoder(log_probs=log_probs_e)
        loss_e = ctc_loss(
            log_probs=log_probs_e,
            targets=transcript_e,
            input_length=encoded_len_e,
            target_length=transcript_len_e)

        # create corresponding eval callback
        tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]

        eval_callback = nemo.core.EvaluatorCallback(
            eval_tensors=[loss_e, predictions_e,
                          transcript_e, transcript_len_e],
            user_iter_callback=partial(
                process_evaluation_batch,
                labels=vocab),
            user_epochs_done_callback=partial(
                process_evaluation_epoch,
                tag=tagname,
                logger=neural_factory.logger),
            eval_step=args.eval_freq,
            tb_writer=neural_factory.tb_writer)

        callbacks.append(eval_callback)

    return loss_t, callbacks, steps_per_epoch