Ejemplo n.º 1
0
def create_all_dags(args, neural_factory):
    """Build the Jasper train DAG, the eval DAGs and all training callbacks.

    The model description is read from the YAML file at ``args.model_config``.
    The keys used here are ``labels``, ``sample_rate``,
    ``AudioToTextDataLayer`` (with ``train`` and ``eval`` sub-sections),
    ``AudioToMelSpectrogramPreprocessor`` (including ``features``),
    ``JasperEncoder`` (including the ``jasper`` block list) and the optional
    ``MultiplyBatch`` / ``SpectrogramAugmentation`` sections.

    Args:
        args: Parsed options. Attributes read: ``model_config``,
            ``train_dataset``, ``batch_size``, ``iter_per_step``,
            ``num_gpus``, ``eval_datasets``, ``eval_batch_size``,
            ``checkpoint_save_freq`` and ``eval_freq``.
        neural_factory: Factory object. Attributes read: ``logger``,
            ``world_size``, ``checkpoint_dir`` and ``tb_writer``; it is also
            handed to the CTC decoder module.

    Returns:
        A ``(loss_t, callbacks, steps_per_epoch)`` tuple: the training loss
        tensor, the list of callbacks (train logger, checkpointing, then one
        evaluator per eval dataset) and the number of optimizer steps per
        epoch.
    """
    logger = neural_factory.logger

    yaml = YAML(typ="safe")
    with open(args.model_config) as f:
        jasper_params = yaml.load(f)
    vocab = jasper_params['labels']
    sample_rate = jasper_params['sample_rate']

    # Number of dataloader workers per process. os.cpu_count() may return
    # None when the count cannot be determined; fall back to a single CPU
    # instead of failing on the division below.
    total_cpus = os.cpu_count() or 1
    cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)

    # Training data layer: the shared section overlaid with the "train"
    # overrides; the nested "train"/"eval" sections are not constructor
    # arguments, so they are dropped.
    dl_config = jasper_params["AudioToTextDataLayer"]
    train_dl_params = copy.deepcopy(dl_config)
    train_dl_params.update(dl_config["train"])
    del train_dl_params["train"]
    del train_dl_params["eval"]

    data_layer = nemo_asr.AudioToTextDataLayer(
        manifest_filepath=args.train_dataset,
        sample_rate=sample_rate,
        labels=vocab,
        batch_size=args.batch_size,
        num_workers=cpu_per_traindl,
        **train_dl_params,
    )

    N = len(data_layer)
    steps_per_epoch = int(
        N / (args.batch_size * args.iter_per_step * args.num_gpus))
    logger.info('Have {0} examples to train on.'.format(N))

    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
        sample_rate=sample_rate,
        **jasper_params["AudioToMelSpectrogramPreprocessor"])

    # Optional augmentation modules; they are applied to the train DAG only
    # and stay None when their config section is absent or empty.
    multiply_batch = None
    multiply_batch_config = jasper_params.get('MultiplyBatch', None)
    if multiply_batch_config:
        multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)

    data_spectr_augmentation = None
    spectr_augment_config = jasper_params.get('SpectrogramAugmentation', None)
    if spectr_augment_config:
        data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
            **spectr_augment_config)

    # Evaluation data layers: same construction as above, with the "eval"
    # overrides instead of the "train" ones.
    eval_dl_params = copy.deepcopy(dl_config)
    eval_dl_params.update(dl_config["eval"])
    del eval_dl_params["train"]
    del eval_dl_params["eval"]

    # Each entry pairs the manifest path with its data layer so the eval
    # loop below does not have to index back into args.eval_datasets.
    data_layers_eval = []
    if args.eval_datasets:
        for eval_dataset in args.eval_datasets:
            data_layer_eval = nemo_asr.AudioToTextDataLayer(
                manifest_filepath=eval_dataset,
                sample_rate=sample_rate,
                labels=vocab,
                batch_size=args.eval_batch_size,
                num_workers=cpu_per_traindl,
                **eval_dl_params,
            )
            data_layers_eval.append((eval_dataset, data_layer_eval))
    else:
        logger.info("There were no val datasets passed")

    jasper_encoder = nemo_asr.JasperEncoder(
        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
        **jasper_params["JasperEncoder"])

    # The decoder input width is the filter count of the last encoder block.
    jasper_decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
        num_classes=len(vocab),
        factory=neural_factory)

    ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))

    greedy_decoder = nemo_asr.GreedyCTCDecoder()

    logger.info('================================')
    logger.info(
        f"Number of parameters in encoder: {jasper_encoder.num_weights}")
    logger.info(
        f"Number of parameters in decoder: {jasper_decoder.num_weights}")
    # This line reports encoder + decoder, so it is labelled as the model
    # total rather than the decoder.
    logger.info(
        f"Total number of parameters in model: "
        f"{jasper_decoder.num_weights + jasper_encoder.num_weights}")
    logger.info('================================')

    # Train DAG
    audio_signal_t, a_sig_length_t, \
        transcript_t, transcript_len_t = data_layer()
    processed_signal_t, p_length_t = data_preprocessor(
        input_signal=audio_signal_t,
        length=a_sig_length_t)

    if multiply_batch is not None:
        processed_signal_t, p_length_t, transcript_t, transcript_len_t = \
            multiply_batch(
                in_x=processed_signal_t, in_x_len=p_length_t,
                in_y=transcript_t,
                in_y_len=transcript_len_t)

    if data_spectr_augmentation is not None:
        processed_signal_t = data_spectr_augmentation(
            input_spec=processed_signal_t)

    encoded_t, encoded_len_t = jasper_encoder(
        audio_signal=processed_signal_t,
        length=p_length_t)
    log_probs_t = jasper_decoder(encoder_output=encoded_t)
    predictions_t = greedy_decoder(log_probs=log_probs_t)
    loss_t = ctc_loss(
        log_probs=log_probs_t,
        targets=transcript_t,
        input_length=encoded_len_t,
        target_length=transcript_len_t)

    # Callbacks needed to print info to console and Tensorboard
    train_callback = nemo.core.SimpleLossLoggerCallback(
        tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
        print_func=partial(
            monitor_asr_train_progress,
            labels=vocab,
            logger=logger),
        get_tb_values=lambda x: [("loss", x[0])],
        tb_writer=neural_factory.tb_writer,
    )

    chpt_callback = nemo.core.CheckpointCallback(
        folder=neural_factory.checkpoint_dir,
        step_freq=args.checkpoint_save_freq)

    callbacks = [train_callback, chpt_callback]

    # Assemble one eval DAG per dataset; every module except the data layer
    # is shared with the train DAG.
    for eval_dataset, eval_dl in data_layers_eval:
        audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e = \
            eval_dl()
        processed_signal_e, p_length_e = data_preprocessor(
            input_signal=audio_signal_e,
            length=a_sig_length_e)
        encoded_e, encoded_len_e = jasper_encoder(
            audio_signal=processed_signal_e,
            length=p_length_e)
        log_probs_e = jasper_decoder(encoder_output=encoded_e)
        predictions_e = greedy_decoder(log_probs=log_probs_e)
        loss_e = ctc_loss(
            log_probs=log_probs_e,
            targets=transcript_e,
            input_length=encoded_len_e,
            target_length=transcript_len_e)

        # Tag the evaluator with the manifest file name up to its first dot.
        tagname = os.path.basename(eval_dataset).split(".")[0]
        eval_callback = nemo.core.EvaluatorCallback(
            eval_tensors=[loss_e, predictions_e,
                          transcript_e, transcript_len_e],
            user_iter_callback=partial(
                process_evaluation_batch,
                labels=vocab),
            user_epochs_done_callback=partial(
                process_evaluation_epoch,
                tag=tagname,
                logger=logger),
            eval_step=args.eval_freq,
            tb_writer=neural_factory.tb_writer)

        callbacks.append(eval_callback)

    return loss_t, callbacks, steps_per_epoch
Ejemplo n.º 2
0
def create_all_dags(args, neural_factory):
    """Create the train and eval DAGs together with their callbacks.

    The model description is read from the YAML file at ``args.model_config``.
    The keys used here are ``labels``, ``sample_rate``,
    ``AudioToTextDataLayer`` (with ``train`` and ``eval`` sub-sections),
    ``AudioToMelSpectrogramPreprocessor`` (including ``features``),
    ``JasperEncoder`` (including the ``jasper`` block list) and the optional
    ``MultiplyBatch`` / ``SpectrogramAugmentation`` sections.

    Args:
        args: Parsed options. Attributes read: ``model_config``,
            ``train_dataset``, ``batch_size``, ``iter_per_step``,
            ``num_gpus``, ``eval_datasets``, ``eval_batch_size``,
            ``checkpoint_dir``, ``load_dir``, ``checkpoint_save_freq`` and
            ``eval_freq``.
        neural_factory: Factory object. Attributes read: ``world_size`` and
            ``tb_writer``.

    Returns:
        A ``(loss_t, callbacks, steps_per_epoch)`` tuple: the training loss
        tensor, the list of callbacks (train logger, an optional checkpoint
        callback, then one evaluator per eval dataset) and the number of
        optimizer steps per epoch.
    """
    # parse the config files
    yaml = YAML(typ="safe")
    with open(args.model_config) as f:
        quartz_params = yaml.load(f)

    vocab = quartz_params['labels']
    sample_rate = quartz_params['sample_rate']

    # Number of dataloader workers per process. os.cpu_count() may return
    # None when the count cannot be determined; fall back to a single CPU
    # instead of failing on the division below.
    total_cpus = os.cpu_count() or 1
    cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)

    # create data layer for training: the shared section overlaid with the
    # "train" overrides; the nested "train"/"eval" sections are not
    # constructor arguments, so they are dropped.
    dl_config = quartz_params["AudioToTextDataLayer"]
    train_dl_params = copy.deepcopy(dl_config)
    train_dl_params.update(dl_config["train"])
    del train_dl_params["train"]
    del train_dl_params["eval"]

    data_layer_train = nemo_asr.AudioToTextDataLayer(
        manifest_filepath=args.train_dataset,
        sample_rate=sample_rate,
        labels=vocab,
        batch_size=args.batch_size,
        num_workers=cpu_per_traindl,
        **train_dl_params,
    )

    N = len(data_layer_train)
    steps_per_epoch = int(
        N / (args.batch_size * args.iter_per_step * args.num_gpus))

    # create separate data layers for eval
    # we need separate eval dags for separate eval datasets
    # but all other modules in these dags will be shared
    eval_dl_params = copy.deepcopy(dl_config)
    eval_dl_params.update(dl_config["eval"])
    del eval_dl_params["train"]
    del eval_dl_params["eval"]

    # Each entry pairs the manifest path with its data layer so the eval
    # loop below does not have to index back into args.eval_datasets.
    data_layers_eval = []
    if args.eval_datasets:
        for eval_dataset in args.eval_datasets:
            data_layer_eval = nemo_asr.AudioToTextDataLayer(
                manifest_filepath=eval_dataset,
                sample_rate=sample_rate,
                labels=vocab,
                batch_size=args.eval_batch_size,
                num_workers=cpu_per_traindl,
                **eval_dl_params,
            )
            data_layers_eval.append((eval_dataset, data_layer_eval))
    else:
        nemo.logging.warning("There were no val datasets passed")

    # create shared modules
    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
        sample_rate=sample_rate,
        **quartz_params["AudioToMelSpectrogramPreprocessor"])

    # (QuartzNet uses the Jasper baseline encoder and decoder)
    encoder = nemo_asr.JasperEncoder(
        feat_in=quartz_params["AudioToMelSpectrogramPreprocessor"]["features"],
        **quartz_params["JasperEncoder"])

    # The decoder input width is the filter count of the last encoder block.
    decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=quartz_params["JasperEncoder"]["jasper"][-1]["filters"],
        num_classes=len(vocab))

    ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))

    greedy_decoder = nemo_asr.GreedyCTCDecoder()

    # create augmentation modules (only used for training) if their configs
    # are present; they stay None otherwise
    multiply_batch = None
    multiply_batch_config = quartz_params.get('MultiplyBatch', None)
    if multiply_batch_config:
        multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)

    data_spectr_augmentation = None
    spectr_augment_config = quartz_params.get('SpectrogramAugmentation', None)
    if spectr_augment_config:
        data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
            **spectr_augment_config)

    # assemble train DAG
    audio_signal_t, a_sig_length_t, \
        transcript_t, transcript_len_t = data_layer_train()

    processed_signal_t, p_length_t = data_preprocessor(
        input_signal=audio_signal_t, length=a_sig_length_t)

    if multiply_batch is not None:
        processed_signal_t, p_length_t, transcript_t, transcript_len_t = \
            multiply_batch(
                in_x=processed_signal_t, in_x_len=p_length_t,
                in_y=transcript_t,
                in_y_len=transcript_len_t)

    if data_spectr_augmentation is not None:
        processed_signal_t = data_spectr_augmentation(
            input_spec=processed_signal_t)

    encoded_t, encoded_len_t = encoder(audio_signal=processed_signal_t,
                                       length=p_length_t)
    log_probs_t = decoder(encoder_output=encoded_t)
    predictions_t = greedy_decoder(log_probs=log_probs_t)
    loss_t = ctc_loss(log_probs=log_probs_t,
                      targets=transcript_t,
                      input_length=encoded_len_t,
                      target_length=transcript_len_t)

    # create train callbacks
    train_callback = nemo.core.SimpleLossLoggerCallback(
        tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
        print_func=partial(monitor_asr_train_progress, labels=vocab),
        get_tb_values=lambda x: [["loss", x[0]]],
        tb_writer=neural_factory.tb_writer)

    callbacks = [train_callback]

    # A checkpoint callback is only added when there is a folder to save
    # to or to load from.
    if args.checkpoint_dir or args.load_dir:
        chpt_callback = nemo.core.CheckpointCallback(
            folder=args.checkpoint_dir,
            load_from_folder=args.load_dir,
            step_freq=args.checkpoint_save_freq)

        callbacks.append(chpt_callback)

    # assemble eval DAGs
    for eval_dataset, eval_dl in data_layers_eval:
        audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e = \
            eval_dl()
        processed_signal_e, p_length_e = data_preprocessor(
            input_signal=audio_signal_e, length=a_sig_length_e)
        encoded_e, encoded_len_e = encoder(audio_signal=processed_signal_e,
                                           length=p_length_e)
        log_probs_e = decoder(encoder_output=encoded_e)
        predictions_e = greedy_decoder(log_probs=log_probs_e)
        loss_e = ctc_loss(log_probs=log_probs_e,
                          targets=transcript_e,
                          input_length=encoded_len_e,
                          target_length=transcript_len_e)

        # create corresponding eval callback, tagged with the manifest file
        # name up to its first dot
        tagname = os.path.basename(eval_dataset).split(".")[0]

        eval_callback = nemo.core.EvaluatorCallback(
            eval_tensors=[
                loss_e, predictions_e, transcript_e, transcript_len_e
            ],
            user_iter_callback=partial(process_evaluation_batch, labels=vocab),
            user_epochs_done_callback=partial(process_evaluation_epoch,
                                              tag=tagname),
            eval_step=args.eval_freq,
            tb_writer=neural_factory.tb_writer)

        callbacks.append(eval_callback)

    return loss_t, callbacks, steps_per_epoch