Ejemplo n.º 1
0
def create_dag(args, cfg, logger, num_gpus):
    """Assemble the training/evaluation graphs, callbacks and updated config.

    Args:
        args: Parsed options. This function reads ``train_dataset``,
            ``eval_datasets``, ``encoder_checkpoint``, ``decoder_checkpoint``,
            ``local_rank``, ``checkpoint_dir`` and ``eval_freq``.
        cfg: Nested configuration mapping. It is mutated in place
            (``num_params``, ``input.train`` and two ``optimization`` entries
            are written) and the same object is returned.
        logger: Object exposing ``info``; used for progress messages.
        num_gpus: Number of GPUs; only used to derive ``steps_per_epoch``.

    Returns:
        A 3-tuple ``((train_loss, evals), cfg, callbacks)``. ``evals`` holds
        one ``(dataset_path, (loss, log_probs, transcripts, predictions, aw))``
        pair per evaluation dataset, and ``callbacks`` is
        ``[tf_callback, se_callback, uf_callback, saver_callback]``.
    """
    target_cfg = cfg['target']
    optim_cfg = cfg['optimization']
    layer_cfg = cfg['AudioToTextDataLayer']

    # ---- Data layers -----------------------------------------------------
    train_layer = nemo_asr.AudioToTextDataLayer(
        manifest_filepath=args.train_dataset,
        labels=target_cfg['labels'],
        batch_size=optim_cfg['batch_size'],
        eos_id=target_cfg['eos_id'],
        **layer_cfg['train'],
    )
    # Each evaluation layer is kept next to the manifest path it was built
    # from, so the path can label the layer's outputs further down.
    if args.eval_datasets:
        eval_layers = [
            (
                eval_path,
                nemo_asr.AudioToTextDataLayer(
                    manifest_filepath=eval_path,
                    labels=target_cfg['labels'],
                    batch_size=cfg['inference']['batch_size'],
                    eos_id=target_cfg['eos_id'],
                    **layer_cfg['eval'],
                ),
            )
            for eval_path in args.eval_datasets
        ]
    else:
        eval_layers = []
        logger.info("There were no val datasets passed")

    # ---- Feature extraction and encoder ----------------------------------
    preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
        **cfg['AudioToMelSpectrogramPreprocessor'])
    augmenter = nemo_asr.SpectrogramAugmentation(
        **cfg['SpectrogramAugmentation'])
    encoder = nemo_asr.JasperEncoder(
        feat_in=cfg["AudioToMelSpectrogramPreprocessor"]["features"],
        **cfg['JasperEncoder'],
    )
    encoder_cfg = cfg['JasperEncoder']
    encoder_ckpt = args.encoder_checkpoint
    # The 'load' / 'freeze' switches are honoured only when a checkpoint path
    # was given and exists on disk.
    if encoder_ckpt is not None and os.path.exists(encoder_ckpt):
        if encoder_cfg['load']:
            encoder.restore_from(encoder_ckpt, args.local_rank)
            logger.info(f'Loaded weights for encoder from {encoder_ckpt}')
        if encoder_cfg['freeze']:
            encoder.freeze()
            logger.info('Freeze encoder weights')

    # ---- Connector and decoder -------------------------------------------
    decoder_cfg = cfg['DecoderRNN']
    connector = nemo_asr.JasperRNNConnector(
        # Input width is the filter count of the last configured Jasper block.
        in_channels=encoder_cfg['jasper'][-1]['filters'],
        out_channels=decoder_cfg['hidden_size'],
    )
    decoder = nemo.backends.pytorch.DecoderRNN(
        voc_size=len(target_cfg['labels']),
        bos_id=target_cfg['bos_id'],
        **decoder_cfg,
    )
    decoder_ckpt = args.decoder_checkpoint
    if decoder_ckpt is not None and os.path.exists(decoder_ckpt):
        if decoder_cfg['load']:
            decoder.restore_from(decoder_ckpt, args.local_rank)
            logger.info(f'Loaded weights for decoder from {decoder_ckpt}')
        if decoder_cfg['freeze']:
            decoder.freeze()
            logger.info('Freeze decoder weights')
            # NOTE(review): this flag is read from cfg['decoder'] while every
            # other decoder option comes from cfg['DecoderRNN'] - confirm the
            # config really carries both sections.
            if cfg['decoder']['unfreeze_attn']:
                for _, attn_param in decoder.attention.named_parameters():
                    attn_param.requires_grad = True
                logger.info('Unfreeze decoder attn weights')

    # ---- Step bookkeeping ------------------------------------------------
    num_data = len(train_layer)
    batch_size = optim_cfg['batch_size']
    num_epochs = optim_cfg['params']['num_epochs']
    # int() truncates, so a trailing partial batch is not counted as a step.
    steps_per_epoch = int(num_data / (batch_size * num_gpus))
    total_steps = num_epochs * steps_per_epoch

    # ---- Loss, search and callbacks --------------------------------------
    # Single constant-value policy (1.0) spanning start=0.0 .. end=1.0.
    teacher_forcing_policy = ValueSetterCallback.Policy(
        ValueSetterCallback.Method.Const(1.0), start=0.0, end=1.0)
    tf_callback = ValueSetterCallback(
        decoder,
        'teacher_forcing',
        policies=[teacher_forcing_policy],
        total_steps=total_steps,
    )
    seq_loss = nemo.backends.pytorch.SequenceLoss(
        pad_id=target_cfg['pad_id'],
        smoothing_coef=optim_cfg['smoothing_coef'],
        sample_wise=optim_cfg['sample_wise'],
    )
    # Constant policy pinned to the coefficient the loss was created with.
    smoothing_policy = ValueSetterCallback.Policy(
        ValueSetterCallback.Method.Const(seq_loss.smoothing_coef),
        start=0.0,
        end=1.0,
    )
    se_callback = ValueSetterCallback(
        seq_loss,
        'smoothing_coef',
        policies=[smoothing_policy],
        total_steps=total_steps,
    )
    beam_search = nemo.backends.pytorch.BeamSearch(
        decoder=decoder,
        pad_id=target_cfg['pad_id'],
        bos_id=target_cfg['bos_id'],
        eos_id=target_cfg['eos_id'],
        max_len=target_cfg['max_len'],
        beam_size=cfg['inference']['beam_size'],
    )
    uf_callback = UnfreezeCallback(
        [encoder, decoder], start_epoch=optim_cfg['start_unfreeze'])
    saver_callback = nemo.core.ModuleSaverCallback(
        save_modules_list=[encoder, connector, decoder],
        folder=args.checkpoint_dir,
        step_freq=args.eval_freq,
    )

    # ---- Training graph (the only path that applies augmentation) --------
    signal, signal_len, targets, _ = train_layer()
    features, feature_len = preprocessor(
        input_signal=signal, length=signal_len)
    augmented = augmenter(input_spec=features)
    hidden, _ = encoder(audio_signal=augmented, length=feature_len)
    hidden = connector(tensor=hidden)
    train_log_probs, _ = decoder(targets=targets, encoder_outputs=hidden)
    train_loss = seq_loss(log_probs=train_log_probs, targets=targets)

    # ---- Evaluation graphs (no augmentation, plus beam search) -----------
    evals = []
    for eval_path, eval_layer in eval_layers:
        signal, signal_len, targets, _ = eval_layer()
        features, feature_len = preprocessor(
            input_signal=signal, length=signal_len)
        hidden, _ = encoder(audio_signal=features, length=feature_len)
        hidden = connector(tensor=hidden)
        eval_log_probs, _ = decoder(targets=targets, encoder_outputs=hidden)
        eval_loss = seq_loss(log_probs=eval_log_probs, targets=targets)
        predictions, aw = beam_search(encoder_outputs=hidden)
        eval_tensors = (eval_loss, eval_log_probs, targets, predictions, aw)
        evals.append((eval_path, eval_tensors))

    # ---- Record derived statistics in the config -------------------------
    cfg['num_params'] = {
        'encoder': encoder.num_weights,
        'connector': connector.num_weights,
        'decoder': decoder.num_weights,
    }
    # Summed before 'total' itself is inserted, so it covers the three
    # per-module counts only.
    cfg['num_params']['total'] = sum(cfg['num_params'].values())
    cfg['input']['train'] = {'num_data': num_data}
    cfg['optimization']['steps_per_epoch'] = steps_per_epoch
    cfg['optimization']['total_steps'] = total_steps

    callbacks = [tf_callback, se_callback, uf_callback, saver_callback]
    return (train_loss, evals), cfg, callbacks
Ejemplo n.º 2
0
    def test_clas(self):
        """Smoke-train a Jasper encoder / RNN decoder graph.

        Reads ``garnet_an4.yaml``, chains data layer -> mel-spectrogram
        preprocessor -> ``JasperEncoder`` -> ``JasperRNNConnector`` ->
        ``DecoderRNN`` -> ``SequenceLoss`` and runs 10 epochs of SGD with the
        trainer obtained from ``self.nf``. There are no explicit assertions;
        the test passes as long as nothing raises.
        """
        config_path = 'examples/asr/experimental/configs/garnet_an4.yaml'
        with open(config_path) as config_file:
            cfg = self.yaml.load(config_file)

        # ---- Modules -----------------------------------------------------
        data_layer = nemo_asr.AudioToTextDataLayer(
            manifest_filepath=self.manifest_filepath,
            labels=self.labels,
            batch_size=4,
        )
        mel_params = dict(
            frame_splicing=1,
            features=64,
            window_size=0.02,
            n_fft=512,
            dither=1e-05,
            window='hann',
            sample_rate=16000,
            normalize='per_feature',
            window_stride=0.01,
            stft_conv=True,
        )
        preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
            **mel_params)

        encoder_cfg = cfg['encoder']
        encoder = nemo_asr.JasperEncoder(
            jasper=encoder_cfg['jasper'],
            activation=encoder_cfg['activation'],
            feat_in=cfg['input']['train']['features'],
        )
        decoder_cfg = cfg['decoder']
        connector = nemo_asr.JasperRNNConnector(
            # Input width is the filter count of the last Jasper block.
            in_channels=encoder_cfg['jasper'][-1]['filters'],
            out_channels=decoder_cfg['hidden_size'],
        )
        # Only these config entries are forwarded to the decoder, each under
        # the same name it has in the YAML.
        forwarded_keys = (
            'hidden_size',
            'attention_method',
            'attention_type',
            'in_dropout',
            'gru_dropout',
            'attn_dropout',
            'teacher_forcing',
            'curriculum_learning',
            'rnn_type',
            'n_layers',
            'tie_emb_out_weights',
        )
        decoder = nemo.backends.pytorch.common.DecoderRNN(
            voc_size=len(self.labels),
            bos_id=0,
            **{key: decoder_cfg[key] for key in forwarded_keys},
        )
        seq_loss = nemo.backends.pytorch.common.SequenceLoss()

        # ---- DAG ---------------------------------------------------------
        signal, signal_len, targets, _ = data_layer()
        features, feature_len = preprocessor(input_signal=signal,
                                             length=signal_len)
        hidden, _ = encoder(audio_signal=features, length=feature_len)
        hidden = connector(tensor=hidden)
        log_probs, _ = decoder(targets=targets, encoder_outputs=hidden)
        train_loss = seq_loss(log_probs=log_probs, targets=targets)

        # ---- Train -------------------------------------------------------
        def log_first_tensor(x):
            # Logs the scalar value of the first tensor handed to the callback.
            return logging.info(str(x[0].item()))

        loss_logger = nemo.core.SimpleLossLoggerCallback(
            tensors=[train_loss],
            print_func=log_first_tensor)
        # The trainer comes from the factory stored on the test instance.
        trainer = self.nf.get_trainer()
        trainer.train(
            [train_loss],
            callbacks=[loss_logger],
            optimizer="sgd",
            optimization_params={"num_epochs": 10, "lr": 0.0003},
        )
Ejemplo n.º 3
0
    def test_clas(self):
        """Smoke-train a Jasper encoder / RNN decoder graph.

        Reads ``garnet_an4.yaml``, chains data layer -> mel-spectrogram
        preprocessor -> ``JasperEncoder`` -> ``JasperRNNConnector`` ->
        ``DecoderRNN`` -> ``SequenceLoss`` and runs 10 epochs of SGD using a
        trainer from a freshly created PyTorch ``NeuralModuleFactory``. There
        are no explicit assertions; the test passes as long as nothing raises.
        """
        config_path = 'examples/asr/experimental/configs/garnet_an4.yaml'
        with open(config_path) as config_file:
            cfg = self.yaml.load(config_file)

        # ---- Modules -----------------------------------------------------
        data_layer = nemo_asr.AudioToTextDataLayer(
            featurizer_config=self.featurizer_config,
            manifest_filepath=self.manifest_filepath,
            labels=self.labels,
            batch_size=4,
        )
        mel_params = dict(
            int_values=False,
            frame_splicing=1,
            features=64,
            window_size=0.02,
            n_fft=512,
            dither=1e-05,
            window='hann',
            sample_rate=16000,
            normalize='per_feature',
            window_stride=0.01,
            stft_conv=True,
        )
        preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
            **mel_params)

        encoder_cfg = cfg['encoder']
        encoder = nemo_asr.JasperEncoder(
            jasper=encoder_cfg['jasper'],
            activation=encoder_cfg['activation'],
            feat_in=cfg['input']['train']['features'],
        )
        connector = nemo_asr.JasperRNNConnector(
            # Input width is the filter count of the last Jasper block.
            in_channels=encoder_cfg['jasper'][-1]['filters'],
            out_channels=cfg['decoder']['hidden_size'],
        )
        # NOTE(review): the whole cfg['decoder'] section is expanded into
        # keyword arguments (the original comment labelled this "fictive");
        # confirm every key in that section is accepted by DecoderRNN.
        decoder = nemo.backends.pytorch.common.DecoderRNN(
            voc_size=len(self.labels),
            bos_id=0,
            **cfg['decoder']
        )
        seq_loss = nemo.backends.pytorch.common.SequenceLoss()

        # ---- DAG ---------------------------------------------------------
        signal, signal_len, targets, _ = data_layer()
        features, feature_len = preprocessor(input_signal=signal,
                                             length=signal_len)
        hidden, _ = encoder(audio_signal=features, length=feature_len)
        hidden = connector(tensor=hidden)
        log_probs, _ = decoder(targets=targets, encoder_outputs=hidden)
        train_loss = seq_loss(log_probs=log_probs, targets=targets)

        # ---- Train -------------------------------------------------------
        def print_first_tensor(x):
            # Prints the scalar value of the first tensor handed to the
            # callback.
            return print(str(x[0].item()))

        loss_logger = nemo.core.SimpleLossLoggerCallback(
            tensors=[train_loss], print_func=print_first_tensor)
        # A dedicated factory supplies the trainer for this test.
        factory = nemo.core.NeuralModuleFactory(
            backend=nemo.core.Backend.PyTorch,
            local_rank=None,
            create_tb_writer=False,
        )
        trainer = factory.get_trainer()
        trainer.train(
            [train_loss],
            callbacks=[loss_logger],
            optimizer="sgd",
            optimization_params={"num_epochs": 10, "lr": 0.0003},
        )