Example #1
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        self.setup_tokenizer(cfg.tokenizer)
        super().__init__(cfg=cfg, trainer=trainer)
        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=self.register_artifact('language_model.config_file',
                                               cfg.language_model.config_file),
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
            vocab_file=self.register_artifact('tokenizer.vocab_file',
                                              cfg.tokenizer.vocab_file)
            if cfg.tokenizer is not None else None,
        )

        self.classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=cfg.token_classifier.num_classes,
            num_layers=cfg.token_classifier.num_layers,
            activation=cfg.token_classifier.activation,
            log_softmax=cfg.token_classifier.log_softmax,
            dropout=cfg.token_classifier.dropout,
            use_transformer_init=cfg.token_classifier.use_transformer_init,
        )

        self.loss = SpanningLoss()
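
The constructor above reads everything from a single OmegaConf DictConfig. A minimal sketch of a config carrying the fields accessed in Example #1 (field names taken from the snippet; the values below are placeholders, not an official NeMo configuration):

from omegaconf import OmegaConf

# Hypothetical config mirroring the fields read in Example #1.
cfg = OmegaConf.create({
    'language_model': {
        'pretrained_model_name': 'bert-base-uncased',
        'config_file': None,    # optional path to a model config file
        'config': None,         # optional inline config dict
        'lm_checkpoint': None,  # optional checkpoint to load
    },
    'tokenizer': {
        'vocab_file': None,
    },
    'token_classifier': {
        'num_classes': 2,
        'num_layers': 1,
        'activation': 'relu',
        'log_softmax': False,
        'dropout': 0.1,
        'use_transformer_init': True,
    },
})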
Example #2
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):

        if cfg.tokenizer is not None:
            self._setup_tokenizer(cfg.tokenizer)
        else:
            self.tokenizer = None

        super().__init__(cfg=cfg, trainer=trainer)

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
            vocab_file=cfg.tokenizer.get('vocab_file')
            if cfg.tokenizer is not None else None,
        )

        self.hidden_size = self.bert_model.config.hidden_size
        self.vocab_size = self.bert_model.config.vocab_size
        self.only_mlm_loss = cfg.only_mlm_loss

        self.mlm_classifier = BertPretrainingTokenClassifier(
            hidden_size=self.hidden_size,
            num_classes=self.vocab_size,
            num_layers=cfg.num_tok_classification_layers,
            activation="gelu",
            log_softmax=True,
            use_transformer_init=True,
        )

        self.mlm_loss = SmoothedCrossEntropyLoss()

        if not self.only_mlm_loss:
            self.nsp_classifier = SequenceClassifier(
                hidden_size=self.hidden_size,
                num_classes=2,
                num_layers=cfg.num_seq_classification_layers,
                log_softmax=False,
                activation="tanh",
                use_transformer_init=True,
            )

            self.nsp_loss = CrossEntropyLoss()
            self.agg_loss = AggregatorLoss(num_inputs=2)

        # tie weights of MLM softmax layer and embedding layer of the encoder
        if (self.mlm_classifier.mlp.last_linear_layer.weight.shape !=
                self.bert_model.embeddings.word_embeddings.weight.shape):
            raise ValueError(
                "Final classification layer does not match embedding layer.")
        self.mlm_classifier.mlp.last_linear_layer.weight = self.bert_model.embeddings.word_embeddings.weight

        # setup to track metrics
        self.validation_perplexity = Perplexity(compute_on_step=False)

        self.setup_optimization(cfg.optim)
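
The shape check and assignment at the end of Example #2 are standard weight tying between the MLM output projection and the input word embeddings. A minimal sketch of the same idea in plain PyTorch with toy sizes (not the NeMo classes):

import torch.nn as nn

hidden_size, vocab_size = 16, 100
embedding = nn.Embedding(vocab_size, hidden_size)  # weight shape: (vocab, hidden)
output_proj = nn.Linear(hidden_size, vocab_size)   # weight shape: (vocab, hidden)

# shapes must match before tying, as in the check above
assert output_proj.weight.shape == embedding.weight.shape

# tie: both modules now share a single parameter tensor
output_proj.weight = embedding.weight
assert output_proj.weight.data_ptr() == embedding.weight.data_ptr()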
Example #3
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initializes a BERT-based model for GLUE tasks.
        """
        self.data_dir = cfg.dataset.data_dir
        if not os.path.exists(self.data_dir):
            raise FileNotFoundError(
                "GLUE datasets not found. For more details on how to get the data, see: "
                "https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e"
            )

        if cfg.task_name not in cfg.supported_tasks:
            raise ValueError(
                f'{cfg.task_name} is not a supported task. Choose from {cfg.supported_tasks}'
            )
        self.task_name = cfg.task_name

        # MNLI task has two separate dev sets: matched and mismatched
        cfg.train_ds.file_name = os.path.join(self.data_dir,
                                              cfg.train_ds.file_name)
        if self.task_name == "mnli":
            cfg.validation_ds.file_name = [
                os.path.join(self.data_dir, 'dev_matched.tsv'),
                os.path.join(self.data_dir, 'dev_mismatched.tsv'),
            ]
        else:
            cfg.validation_ds.file_name = os.path.join(
                self.data_dir, cfg.validation_ds.file_name)
        logging.info(
            f'Using {cfg.validation_ds.file_name} for model evaluation.')
        self._setup_tokenizer(cfg.tokenizer)

        super().__init__(cfg=cfg, trainer=trainer)

        num_labels = GLUE_TASKS_NUM_LABELS[self.task_name]

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        # uses [CLS] token for classification (the first token)
        if self.task_name == "sts-b":
            self.pooler = SequenceRegression(
                hidden_size=self.bert_model.config.hidden_size)
            self.loss = MSELoss()
        else:
            self.pooler = SequenceClassifier(
                hidden_size=self.bert_model.config.hidden_size,
                num_classes=num_labels,
                log_softmax=False)
            self.loss = CrossEntropyLoss()

        # Optimizer setup needs to happen after all model weights are ready
        self.setup_optimization(cfg.optim)
Example #4
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """ Initializes BERT Joint Intent and Slot model.
        """

        self.data_dir = cfg.data_dir
        self.max_seq_length = cfg.language_model.max_seq_length

        self.data_desc = IntentSlotDataDesc(
            data_dir=cfg.data_dir,
            modes=[cfg.train_ds.prefix, cfg.validation_ds.prefix])

        self._setup_tokenizer(cfg.tokenizer)
        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        # initialize Bert model

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        self.classifier = SequenceTokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_intents=self.data_desc.num_intents,
            num_slots=self.data_desc.num_slots,
            dropout=cfg.head.fc_dropout,
            num_layers=cfg.head.num_output_layers,
            log_softmax=False,
        )

        # define losses
        if cfg.class_balancing == 'weighted_loss':
            # You may need to increase the number of epochs for convergence when using weighted_loss
            self.intent_loss = CrossEntropyLoss(
                logits_ndim=2, weight=self.data_desc.intent_weights)
            self.slot_loss = CrossEntropyLoss(
                logits_ndim=3, weight=self.data_desc.slot_weights)
        else:
            self.intent_loss = CrossEntropyLoss(logits_ndim=2)
            self.slot_loss = CrossEntropyLoss(logits_ndim=3)

        self.total_loss = AggregatorLoss(
            num_inputs=2,
            weights=[cfg.intent_loss_weight, 1.0 - cfg.intent_loss_weight])

        # setup to track metrics
        self.intent_classification_report = ClassificationReport(
            self.data_desc.num_intents, self.data_desc.intents_label_ids)
        self.slot_classification_report = ClassificationReport(
            self.data_desc.num_slots, self.data_desc.slots_label_ids)

        # Optimizer setup needs to happen after all model weights are ready
        self.setup_optimization(cfg.optim)
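
The AggregatorLoss(num_inputs=2, weights=[w, 1 - w]) above reduces the intent and slot losses to one scalar as a weighted sum. A minimal sketch of that combination with placeholder loss values (plain PyTorch, not the NeMo AggregatorLoss implementation):

import torch

intent_loss_weight = 0.6
intent_loss = torch.tensor(0.9)  # placeholder scalar losses
slot_loss = torch.tensor(1.4)

total_loss = intent_loss_weight * intent_loss + (1.0 - intent_loss_weight) * slot_loss
# 0.6 * 0.9 + 0.4 * 1.4 = 1.10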
Example #5
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initializes BERT Punctuation and Capitalization model.
        """
        self.setup_tokenizer(cfg.tokenizer)

        super().__init__(cfg=cfg, trainer=trainer)

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=self.register_artifact('language_model.config_file',
                                               cfg.language_model.config_file),
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
            vocab_file=self.register_artifact('tokenizer.vocab_file',
                                              cfg.tokenizer.vocab_file),
        )

        self.punct_classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.punct_label_ids),
            activation=cfg.punct_head.activation,
            log_softmax=False,
            dropout=cfg.punct_head.fc_dropout,
            num_layers=cfg.punct_head.punct_num_fc_layers,
            use_transformer_init=cfg.punct_head.use_transformer_init,
        )

        self.capit_classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.capit_label_ids),
            activation=cfg.capit_head.activation,
            log_softmax=False,
            dropout=cfg.capit_head.fc_dropout,
            num_layers=cfg.capit_head.capit_num_fc_layers,
            use_transformer_init=cfg.capit_head.use_transformer_init,
        )

        self.loss = CrossEntropyLoss(logits_ndim=3)
        self.agg_loss = AggregatorLoss(num_inputs=2)

        # setup to track metrics
        self.punct_class_report = ClassificationReport(
            num_classes=len(self._cfg.punct_label_ids),
            label_ids=self._cfg.punct_label_ids,
            mode='macro',
            dist_sync_on_step=True,
        )
        self.capit_class_report = ClassificationReport(
            num_classes=len(self._cfg.capit_label_ids),
            label_ids=self._cfg.capit_label_ids,
            mode='macro',
            dist_sync_on_step=True,
        )
Example #6
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Initializes the BERTTextClassifier model."""

        # shared params for dataset and data loaders
        self.dataset_cfg = cfg.dataset
        # the tokenizer needs to be initialized before super().__init__()
        # as the dataloaders and datasets need it to process the data
        self.setup_tokenizer(cfg.tokenizer)

        super().__init__(cfg=cfg, trainer=trainer)

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=cfg.language_model.config,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        self.classifier = SequenceClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=cfg.dataset.num_classes,
            num_layers=cfg.classifier_head.num_output_layers,
            activation='relu',
            log_softmax=False,
            dropout=cfg.classifier_head.fc_dropout,
            use_transformer_init=True,
            idx_conditioned_on=0,
        )

        class_weights = None
        if cfg.dataset.class_balancing == 'weighted_loss':
            if cfg.train_ds.file_path:
                class_weights = calc_class_weights(cfg.train_ds.file_path,
                                                   cfg.dataset.num_classes)
            else:
                logging.info(
                    'Class balancing is enabled but no train file was provided; skipping class weight calculation.'
                )

        if class_weights:
            # You may need to increase the number of epochs for convergence when using weighted_loss
            self.loss = CrossEntropyLoss(weight=class_weights)
        else:
            self.loss = CrossEntropyLoss()

        # setup to track metrics
        self.classification_report = ClassificationReport(
            num_classes=cfg.dataset.num_classes,
            mode='micro',
            dist_sync_on_step=True)

        # register the file containing the labels as an artifact so it is stored in the '.nemo' file later
        if 'class_labels' in cfg and 'class_labels_file' in cfg.class_labels and cfg.class_labels.class_labels_file:
            self.register_artifact('class_labels',
                                   cfg.class_labels.class_labels_file)
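
When class_balancing is set to 'weighted_loss', the per-class weights computed from the training file are handed to the cross-entropy loss. A minimal sketch of inverse-frequency weighting in plain PyTorch (an assumed weighting scheme shown for illustration; NeMo's calc_class_weights helper may compute its weights differently):

from collections import Counter
import torch

labels = [0, 0, 0, 0, 1, 1, 2]  # hypothetical training labels
num_classes = 3
counts = Counter(labels)

# weights proportional to inverse class frequency ("balanced" weighting)
freqs = torch.tensor([counts[c] for c in range(num_classes)], dtype=torch.float)
weights = len(labels) / (num_classes * freqs)  # tensor([0.5833, 1.1667, 2.3333])

loss_fn = torch.nn.CrossEntropyLoss(weight=weights)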
Example #7
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Initializes the BERTTextClassifier model."""

        # shared params for dataset and data loaders
        self.dataset_cfg = cfg.dataset
        # the tokenizer needs to be initialized before super().__init__()
        # as the dataloaders and datasets need it to process the data
        self.setup_tokenizer(cfg.tokenizer)

        self.class_weights = None

        super().__init__(cfg=cfg, trainer=trainer)

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=self.register_artifact('language_model.config_file',
                                               cfg.language_model.config_file),
            config_dict=cfg.language_model.config,
            checkpoint_file=cfg.language_model.lm_checkpoint,
            nemo_file=self.register_artifact(
                'language_model.nemo_file',
                cfg.language_model.get('nemo_file', None)),
            vocab_file=self.register_artifact('tokenizer.vocab_file',
                                              cfg.tokenizer.vocab_file),
            trainer=trainer,
        )

        if cfg.language_model.get('nemo_file', None) is not None:
            hidden_size = self.bert_model.cfg.hidden_size
        else:
            hidden_size = self.bert_model.config.hidden_size

        self.classifier = SequenceClassifier(
            hidden_size=hidden_size,
            num_classes=cfg.dataset.num_classes,
            num_layers=cfg.classifier_head.num_output_layers,
            activation='relu',
            log_softmax=False,
            dropout=cfg.classifier_head.fc_dropout,
            use_transformer_init=True,
            idx_conditioned_on=0,
        )

        self.create_loss_module()

        # setup to track metrics
        self.classification_report = ClassificationReport(
            num_classes=cfg.dataset.num_classes,
            mode='micro',
            dist_sync_on_step=True)

        # register the file containing the labels as an artifact so it is stored in the '.nemo' file later
        if 'class_labels' in cfg and 'class_labels_file' in cfg.class_labels and cfg.class_labels.class_labels_file:
            self.register_artifact('class_labels.class_labels_file',
                                   cfg.class_labels.class_labels_file)
Example #8
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Initializes a BERT-based model for GLUE tasks.
        """

        if cfg.task_name not in cfg.supported_tasks:
            raise ValueError(
                f'{cfg.task_name} is not a supported task. Choose from {cfg.supported_tasks}'
            )
        self.task_name = cfg.task_name

        # needed to setup validation on multiple datasets
        # MNLI task has two separate dev sets: matched and mismatched
        if not self._is_model_being_restored():
            if self.task_name == "mnli":
                cfg.validation_ds.ds_item = [
                    os.path.join(cfg.dataset.data_dir, 'dev_matched.tsv'),
                    os.path.join(cfg.dataset.data_dir, 'dev_mismatched.tsv'),
                ]
            else:
                cfg.validation_ds.ds_item = os.path.join(
                    cfg.dataset.data_dir, cfg.validation_ds.ds_item)
            cfg.train_ds.ds_item = os.path.join(cfg.dataset.data_dir,
                                                cfg.train_ds.ds_item)
            logging.info(
                f'Using {cfg.validation_ds.ds_item} for model evaluation.')

        self.setup_tokenizer(cfg.tokenizer)
        super().__init__(cfg=cfg, trainer=trainer)

        num_labels = GLUE_TASKS_NUM_LABELS[self.task_name]

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=self.register_artifact('language_model.config_file',
                                               cfg.language_model.config_file),
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
            vocab_file=self.register_artifact('tokenizer.vocab_file',
                                              cfg.tokenizer.vocab_file),
        )

        # uses [CLS] token for classification (the first token)
        if self.task_name == "sts-b":
            self.pooler = SequenceRegression(
                hidden_size=self.bert_model.config.hidden_size)
            self.loss = MSELoss()
        else:
            self.pooler = SequenceClassifier(
                hidden_size=self.bert_model.config.hidden_size,
                num_classes=num_labels,
                log_softmax=False)
            self.loss = CrossEntropyLoss()
Example #9
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):

        self.data_prepared = False
        self.setup_tokenizer(cfg.tokenizer)
        super().__init__(cfg=cfg, trainer=trainer)
        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config) if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        self.encoder = SGDEncoder(hidden_size=self.bert_model.config.hidden_size, dropout=self._cfg.encoder.dropout)
        self.decoder = SGDDecoder(embedding_dim=self.bert_model.config.hidden_size)
        self.loss = SGDDialogueStateLoss(reduction="mean")
Example #10
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Initializes Token Classification Model."""
        # extract the str-to-int label mapping if a mapping file is provided
        if isinstance(cfg.label_ids, str):
            if os.path.exists(cfg.label_ids):
                logging.info(
                    f'Reusing label_ids file found at {cfg.label_ids}.')
                label_ids = get_labels_to_labels_id_mapping(cfg.label_ids)
                # update the config to store name to id mapping
                cfg.label_ids = OmegaConf.create(label_ids)
            else:
                raise ValueError(f'{cfg.label_ids} not found.')

        self.setup_tokenizer(cfg.tokenizer)
        self.class_weights = None
        super().__init__(cfg=cfg, trainer=trainer)

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=self.register_artifact('language_model.config_file',
                                               cfg.language_model.config_file),
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
            vocab_file=self.register_artifact('tokenizer.vocab_file',
                                              cfg.tokenizer.vocab_file),
        )

        self.classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.label_ids),
            num_layers=self._cfg.head.num_fc_layers,
            activation=self._cfg.head.activation,
            log_softmax=False,
            dropout=self._cfg.head.fc_dropout,
            use_transformer_init=self._cfg.head.use_transformer_init,
        )

        self.loss = self.setup_loss(
            class_balancing=self._cfg.dataset.class_balancing)

        # setup to track metrics
        self.classification_report = ClassificationReport(
            len(self._cfg.label_ids),
            label_ids=self._cfg.label_ids,
            dist_sync_on_step=True)
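
In Example #10, get_labels_to_labels_id_mapping turns the labels file referenced by cfg.label_ids into a name-to-id dict before it is written back into the config. A minimal sketch of such a mapping, assuming one label per line with the line index as the id (a hypothetical reader; the actual NeMo file format may differ):

def labels_file_to_id_mapping(path: str) -> dict:
    """Hypothetical reader: one label per line, id = line index."""
    with open(path, encoding='utf-8') as f:
        labels = [line.strip() for line in f if line.strip()]
    return {label: idx for idx, label in enumerate(labels)}

# labels_file_to_id_mapping('label_ids.csv')  -> e.g. {'O': 0, 'B-PER': 1, 'I-PER': 2}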
Example #11
def get_lm_model_with_padded_embedding(cfg: DictConfig):
    """
    Function which ensures that the vocabulary size is divisible by 8
    for faster mixed precision training.
    """
    model = get_lm_model(
        pretrained_model_name=cfg.language_model.pretrained_model_name,
        config_file=cfg.language_model.config_file,
        config_dict=OmegaConf.to_container(cfg.language_model.config) if cfg.language_model.config else None,
        checkpoint_file=cfg.language_model.lm_checkpoint,
        vocab_file=cfg.tokenizer.vocab_file,
    )
    vocab_size, hidden_size = model.config.vocab_size, model.config.hidden_size
    tokens_to_add = 8 * math.ceil(vocab_size / 8) - vocab_size
    zeros = torch.zeros((tokens_to_add, hidden_size))
    model.embeddings.word_embeddings.weight.data = torch.cat((model.embeddings.word_embeddings.weight.data, zeros))
    return model
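
The helper above rounds the vocabulary up to the next multiple of 8 and appends zero rows to the embedding matrix. A quick standalone check of the padding arithmetic with typical bert-base sizes (sketch only, independent of the NeMo model object):

import math
import torch

vocab_size, hidden_size = 30522, 768  # bert-base-uncased sizes
tokens_to_add = 8 * math.ceil(vocab_size / 8) - vocab_size
print(tokens_to_add)  # 6, i.e. a padded vocabulary of 30528

weight = torch.zeros((vocab_size, hidden_size))
padded = torch.cat((weight, torch.zeros((tokens_to_add, hidden_size))))
print(padded.shape)  # torch.Size([30528, 768])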
Example #12
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Initializes the SAP-BERT model for entity linking."""

        # tokenizer needed before super().__init__() so dataset and loader can process data
        self._setup_tokenizer(cfg.tokenizer)

        super().__init__(cfg=cfg, trainer=trainer)

        self.model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=cfg.language_model.config,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        # Token to use for the self-alignment loss, typically the first token, [CLS]
        self._idx_conditioned_on = 0
        self.loss = MultiSimilarityLoss()
Example #13
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """ Initializes BERT Joint Intent and Slot model.
        """
        self.max_seq_length = cfg.language_model.max_seq_length

        # Setup tokenizer.
        self.setup_tokenizer(cfg.tokenizer)
        self.cfg = cfg
        # Check the presence of data_dir.
        if not cfg.data_dir or not os.path.exists(cfg.data_dir):
            # Disable setup methods.
            IntentSlotClassificationModel._set_model_restore_state(
                is_being_restored=True)
            # Set default values of data_desc.
            self._set_defaults_data_desc(cfg)
        else:
            self.data_dir = cfg.data_dir
            # Update configuration of data_desc.
            self._set_data_desc_to_cfg(cfg, cfg.data_dir, cfg.train_ds,
                                       cfg.validation_ds)

        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        # Initialize Bert model
        self.bert_model = get_lm_model(
            pretrained_model_name=self.cfg.language_model.pretrained_model_name,
            config_file=self.register_artifact('language_model.config_file',
                                               cfg.language_model.config_file),
            config_dict=OmegaConf.to_container(self.cfg.language_model.config)
            if self.cfg.language_model.config else None,
            checkpoint_file=self.cfg.language_model.lm_checkpoint,
            vocab_file=self.register_artifact('tokenizer.vocab_file',
                                              cfg.tokenizer.vocab_file),
        )

        # Enable setup methods.
        IntentSlotClassificationModel._set_model_restore_state(
            is_being_restored=False)

        # Initialize Classifier.
        self._reconfigure_classifier()
Example #14
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Initializes Token Classification Model."""

        self._setup_tokenizer(cfg.tokenizer)

        self._cfg = cfg
        self.data_desc = None
        self.update_data_dir(cfg.dataset.data_dir)
        self.setup_loss(class_balancing=self._cfg.dataset.class_balancing)

        super().__init__(cfg=cfg, trainer=trainer)
        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        self.classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.label_ids),
            num_layers=self._cfg.head.num_fc_layers,
            activation=self._cfg.head.activation,
            log_softmax=self._cfg.head.log_softmax,
            dropout=self._cfg.head.fc_dropout,
            use_transformer_init=self._cfg.head.use_transformer_init,
        )

        self.loss = self.setup_loss(
            class_balancing=self._cfg.dataset.class_balancing)
        # setup to track metrics
        self.classification_report = ClassificationReport(
            len(self._cfg.label_ids),
            label_ids=self._cfg.label_ids,
            dist_sync_on_step=True)
Example #15
    def __init__(self, cfg: DictConfig, trainer: Trainer = None, no_lm_init=False):

        self.hidden_size = None
        self.bert_model = None
        vocab_file = None
        nemo_file = None
        config_dict = None
        config_file = None

        # the tokenizer needs to be initialized before super().__init__()
        # as the dataloaders and datasets need it to process the data
        pretrain_model_name = ''
        if cfg.get('language_model') and cfg.language_model.get('pretrained_model_name', ''):
            pretrain_model_name = cfg.language_model.get('pretrained_model_name', '')
        all_pretrained_megatron_bert_models = get_megatron_pretrained_bert_models()

        if cfg.get('tokenizer'):
            # Some models have their own tokenizer setup
            if (
                not hasattr(self, 'tokenizer')
                and cfg.tokenizer.get('tokenizer_name')
                and pretrain_model_name not in all_pretrained_megatron_bert_models
            ):
                self.setup_tokenizer(cfg.tokenizer)
            elif pretrain_model_name in all_pretrained_megatron_bert_models:
                copy_cfg = copy.deepcopy(cfg)
                bert_model = get_lm_model(
                    config_file=config_file,
                    config_dict=config_dict,
                    vocab_file=vocab_file,
                    trainer=trainer,
                    cfg=copy_cfg,
                )
                # set the tokenizer if it is not initialized explicitly
                if (
                    (hasattr(self, 'tokenizer') and self.tokenizer is None) or not hasattr(self, 'tokenizer')
                ) and hasattr(bert_model, 'tokenizer'):
                    self.tokenizer = bert_model.tokenizer
            if (
                cfg.get('tokenizer')
                and hasattr(cfg.get('tokenizer'), 'vocab_file')
                and cfg.get('tokenizer').get('vocab_file')
            ):
                vocab_file = self.register_artifact('tokenizer.vocab_file', cfg.tokenizer.vocab_file)
        super().__init__(cfg, trainer)

        # handles model parallel save and restore logic
        self._save_restore_connector = NLPSaveRestoreConnector()

        if cfg.get('language_model') and not no_lm_init:
            if cfg.get('language_model').get('nemo_file'):
                nemo_file = self.register_artifact('language_model.nemo_file', cfg.language_model.nemo_file)
            if cfg.get('language_model').get('config'):
                config_dict = OmegaConf.to_container(cfg.language_model.config)
            if cfg.get('language_model').get('config_file'):
                config_file = self.register_artifact('language_model.config_file', cfg.language_model.config_file)
            bert_model = get_lm_model(
                config_file=config_file, config_dict=config_dict, vocab_file=vocab_file, trainer=trainer, cfg=cfg,
            )
            # set the tokenizer if it is not initialized explicitly
            if ((hasattr(self, 'tokenizer') and self.tokenizer is None) or not hasattr(self, 'tokenizer')) and hasattr(
                bert_model, 'tokenizer'
            ):
                self.tokenizer = bert_model.tokenizer

            # Required to pull up the config for MegatronBert models
            self.pretrained_model_name = cfg.language_model.pretrained_model_name

            # register encoder config
            self.register_bert_model()

            if (
                cfg.tokenizer is not None
                and cfg.tokenizer.get("tokenizer_name", "") is not None
                and "megatron" in cfg.tokenizer.get("tokenizer_name", "")
            ) or pretrain_model_name in all_pretrained_megatron_bert_models:
                self.hidden_size = bert_model.cfg.hidden_size
            else:
                self.hidden_size = bert_model.config.hidden_size

        if cfg.get('language_model') and not no_lm_init:
            self.bert_model = bert_model