class NeuralMachineTranslationModel(ModelPT):
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        return {
            "input_ids":
            NeuralType(('B', 'T'), ChannelType()),
            "attention_mask":
            NeuralType(('B', 'T'), MaskType(), optional=True),
            "decoder_input_ids":
            NeuralType(('B', 'T'), ChannelType(), optional=True),
            "labels":
            NeuralType(('B', 'T'), ChannelType(), optional=True),
        }

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return {
            "loss":
            NeuralType((), LossType()),
            "decoder_hidden_states":
            NeuralType(("B", "T", "D"), ChannelType(), optional=True),
            "encoder_hidden_states":
            NeuralType(("B", "T", "D"), ChannelType(), optional=True),
        }

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):

        # must assign tokenizers before init
        if cfg.language_model.pretrained_model_name:
            if cfg.language_model.pretrained_encoder_model_name or cfg.language_model.pretrained_decoder_model_name:
                raise ValueError(
                    "Must have either pretrained_model_name or both pretrained_encoder_model name and "
                    "pretrained_decoder_model_name.")
            # setup tokenizer
            self.encoder_tokenizer = self.setup_tokenizer(
                cfg.encoder_tokenizer)
            self.encoder_add_special_tokens = cfg.encoder_tokenizer.add_special_tokens

            # set decoder to encoder
            self.decoder_tokenizer = self.encoder_tokenizer
            self.decoder_add_special_tokens = self.encoder_add_special_tokens
        else:
            if not (cfg.language_model.pretrained_encoder_model_name
                    and cfg.language_model.pretrained_decoder_model_name):
                raise ValueError("Both encoder and decoder must be specified")

            # setup tokenizers
            self.encoder_tokenizer = self.setup_tokenizer(
                cfg.encoder_tokenizer)
            self.encoder_add_special_tokens = cfg.encoder_tokenizer.add_special_tokens

            self.decoder_tokenizer = self.setup_tokenizer(
                cfg.decoder_tokenizer)
            self.decoder_add_special_tokens = cfg.decoder_tokenizer.add_special_tokens

        if not self.encoder_tokenizer:
            raise TypeError("encoder_tokenizer failed to initialize")
        if not self.decoder_tokenizer:
            raise TypeError("decoder_tokenizer failed to initialize")

        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        # must assign modules after init
        if cfg.language_model.pretrained_model_name:
            # Setup end-to-end model
            if "bart" in cfg.language_model.pretrained_model_name:
                self.model = BartForConditionalGeneration.from_pretrained(
                    cfg.language_model.pretrained_model_name)
            else:
                self.model = AutoModel.from_pretrained(
                    cfg.language_model.pretrained_model_name)
        else:
            if not (cfg.language_model.pretrained_encoder_model_name
                    and cfg.language_model.pretrained_decoder_model_name):
                raise ValueError("Both encoder and decoder must be specified")

            # Setup encoder/decoder model
            self.model = EncoderDecoderModel.from_encoder_decoder_pretrained(
                encoder=cfg.language_model.pretrained_encoder_model_name,
                decoder=cfg.language_model.pretrained_decoder_model_name,
            )

        self.validation_perplexity = Perplexity(compute_on_step=False)

        self.setup_optimization(cfg.optim)

    @typecheck()
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        decoder_input_ids: torch.Tensor = None,
        labels: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        No special modification required for Lightning, define it as you normally would
        in the `nn.Module` in vanilla PyTorch.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            labels=labels,
            return_dict=False,
        )
        return outputs

    @typecheck.disable_checks()
    def generate(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Wraps huggingface EncoderDecoder.generate()."""
        outputs = self.model.generate(
            input_ids=input_ids,
            pad_token_id=self.encoder_tokenizer.pad_id,
            bos_token_id=self.encoder_tokenizer.bos_id,
            eos_token_id=self.encoder_tokenizer.eos_id,
            decoder_start_token_id=self.decoder_tokenizer.bos_id,
            **self._cfg.generate,
        )
        return outputs

    def training_step(self, batch: Tuple, batch_idx: int) -> Dict:
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`. Loss calculation from HuggingFace's BartForConditionalGeneration.
        """
        input_ids, input_mask, decoder_input_ids, labels = batch
        loss = self.forward(
            input_ids=input_ids,
            attention_mask=input_mask,
            decoder_input_ids=decoder_input_ids,
            labels=labels,
        )[0]

        tensorboard_logs = {
            "train_loss": loss,
            "lr": self._optimizer.param_groups[0]["lr"]
        }

        return {"loss": loss, "log": tensorboard_logs}

    def validation_step(self, batch: Tuple, batch_idx: int) -> Dict:
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`. Loss calculation from HuggingFace's BartForConditionalGeneration.
        """
        input_ids, input_mask, decoder_input_ids, labels = batch
        loss, logits = self.forward(
            input_ids=input_ids,
            attention_mask=input_mask,
            decoder_input_ids=decoder_input_ids,
            labels=labels,
        )[:2]

        self.validation_perplexity(logits=logits)

        tensorboard_logs = {"val_loss": loss}

        return {"val_loss": loss, "log": tensorboard_logs}

    def validation_epoch_end(self, outputs: List[Dict]) -> Dict:
        """
        Called at the end of validation to aggregate outputs.
        :param outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        perplexity = self.validation_perplexity.compute()
        tensorboard_logs = {"val_loss": avg_loss, "perplexity": perplexity}
        logging.info(f"evaluation perplexity {perplexity.item()}")
        return {"val_loss": avg_loss, "log": tensorboard_logs}

    @typecheck.disable_checks()
    def test_step(self, batch: Tuple, batch_idx: int) -> torch.Tensor:
        """Lightning calls this inside the test loop with data from the test dataloader."""
        input_ids, input_mask, decoder_input_ids, labels = batch
        sequences = self.generate(input_ids=input_ids)
        return sequences

    @typecheck.disable_checks()
    def test_epoch_end(self,
                       outputs: List[torch.Tensor]) -> Dict[str, List[str]]:
        """Called at the end of test to aggregate outputs and decode them."""
        texts = [
            self.encoder_tokenizer.ids_to_text(seq) for batch in outputs
            for seq in batch
        ]
        return {"texts": texts}

    def setup_tokenizer(self, cfg: DictConfig):
        tokenizer = get_tokenizer(
            tokenizer_name=cfg.tokenizer_name,
            tokenizer_model=cfg.tokenizer_model,
            special_tokens=OmegaConf.to_container(cfg.special_tokens)
            if cfg.special_tokens else None,
            vocab_file=cfg.vocab_file,
        )
        return tokenizer

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = self.setup_dataloader_from_config(
            cfg=train_data_config)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = self.setup_dataloader_from_config(
            cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        self._test_dl = self.setup_dataloader_from_config(cfg=test_data_config)

    def setup_dataloader_from_config(self, cfg: DictConfig):
        dataset = NeuralMachineTranslationDataset(
            filepath=cfg.filepath,
            encoder_tokenizer=self.encoder_tokenizer,
            decoder_tokenizer=self.decoder_tokenizer,
            encoder_add_special_tokens=self.encoder_add_special_tokens,
            decoder_add_special_tokens=self.decoder_add_special_tokens,
            max_seq_length=self._cfg.max_seq_length,
            num_samples=cfg.get("num_samples", -1),
            convert_labels=self._cfg.convert_labels,
        )

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self._cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
            collate_fn=dataset.collate_fn,
        )

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        pass
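For orientation, a minimal config sketch follows, listing the fields this class reads in `__init__`, `generate`, and `setup_dataloader_from_config`. Every value below is an illustrative assumption rather than a tested recipe; the real schema comes from the accompanying YAML.

# Illustrative sketch only: the keys mirror the attributes accessed above,
# and every value is an assumption, not a recommended setting.
from omegaconf import OmegaConf

example_nmt_cfg = OmegaConf.create({
    "language_model": {
        "pretrained_model_name": "facebook/bart-base",  # or None, with the encoder/decoder pair set instead
        "pretrained_encoder_model_name": None,
        "pretrained_decoder_model_name": None,
    },
    "encoder_tokenizer": {"tokenizer_name": "facebook/bart-base", "tokenizer_model": None,
                          "special_tokens": None, "vocab_file": None, "add_special_tokens": False},
    "decoder_tokenizer": {"tokenizer_name": "facebook/bart-base", "tokenizer_model": None,
                          "special_tokens": None, "vocab_file": None, "add_special_tokens": False},
    "max_seq_length": 128,
    "batch_size": 8,
    "convert_labels": False,
    "generate": {"max_length": 128, "num_beams": 4},  # forwarded verbatim to HuggingFace generate()
    "optim": {"name": "adam", "lr": 3e-5},
})
# model = NeuralMachineTranslationModel(cfg=example_nmt_cfg, trainer=Trainer(max_epochs=1))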
Example #2
def _perplexity_class_test(
    rank: int,
    worldsize: int,
    probs: Optional[torch.Tensor],
    logits: Optional[torch.Tensor],
    dist_sync_on_step: bool,
    metric_args: Optional[dict] = None,
    check_dist_sync_on_step: bool = True,
    check_batch: bool = True,
    atol: float = 1e-8,
):
    """ Utility function doing the actual comparison between lightning class metric
        and reference metric.
        Args:
            rank: rank of current process
            worldsize: number of processes
            probs: torch tensor with probabilities
            logits: torch tensor with logits. The function checks that ``probs`` and ``logits``
                are mutually exclusive for the ``Perplexity`` metric.
            dist_sync_on_step: bool, if true will synchronize metric state across
                processes at each ``forward()``
            metric_args: dict with additional arguments used for class initialization
            check_dist_sync_on_step: bool, if true will check if the metric is also correctly
                calculated per batch per device (and not just at the end)
            check_batch: bool, if true will check if the metric is also correctly
                calculated across devices for each batch (and not just at the end)
    """
    # Instantiate lightning metric
    perplexity = Perplexity(compute_on_step=True,
                            dist_sync_on_step=dist_sync_on_step,
                            **(metric_args or {}))
    if (probs is None) == (logits is None):
        with pytest.raises(ValueError):
            perplexity(probs, logits)
        return

    # verify perplexity works after being loaded from pickled state
    pickled_metric = pickle.dumps(perplexity)
    perplexity = pickle.loads(pickled_metric)

    for i in range(rank, NUM_BATCHES, worldsize):
        batch_result = perplexity(None if probs is None else probs[i],
                                  None if logits is None else logits[i])

        if perplexity.dist_sync_on_step:
            if rank == 0:
                if probs is not None:
                    ddp_probs = torch.stack(
                        [probs[i + r] for r in range(worldsize)])
                else:
                    ddp_logits = torch.stack(
                        [logits[i + r] for r in range(worldsize)])
                    ddp_probs = logits_to_probs(ddp_logits, is_binary=False)
                sk_batch_result = reference_perplexity_func(ddp_probs)
                # assert for dist_sync_on_step
                if check_dist_sync_on_step:
                    assert np.allclose(batch_result.numpy(),
                                       sk_batch_result,
                                       atol=atol)
        else:
            if probs is None:
                p = logits_to_probs(logits[i], is_binary=False)
            else:
                p = probs[i]
            sk_batch_result = reference_perplexity_func(p)
            # assert for batch
            if check_batch:
                assert np.allclose(batch_result.numpy(),
                                   sk_batch_result,
                                   atol=atol)

    assert (probs is None) != (logits is None)
    # check on all batches on all ranks
    result = perplexity.compute()
    assert isinstance(result, torch.Tensor)

    if probs is None:
        probs = logits_to_probs(logits, is_binary=False)
    sk_result = reference_perplexity_func(probs)

    # assert after aggregation
    assert np.allclose(result.numpy(), sk_result, atol=atol)
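A single-process sanity run of the utility above can look like the sketch below; it assumes the surrounding test module (with NUM_BATCHES, reference_perplexity_func, logits_to_probs, and the Perplexity metric) is importable, and the tensor shapes are placeholders. Multi-rank checks would instead spawn one process per rank (e.g., with torch.multiprocessing.spawn) and initialize a torch.distributed process group before calling it.

# Hypothetical invocation; NUM_BATCHES/BATCH_SIZE/NUM_CLASSES are assumptions and
# must match the constants the test module uses for _perplexity_class_test.
import torch

NUM_BATCHES, BATCH_SIZE, NUM_CLASSES = 10, 16, 5
logits = torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)

# rank 0 of a world of size 1: no process group is needed when dist_sync_on_step=False
_perplexity_class_test(rank=0, worldsize=1, probs=None, logits=logits,
                       dist_sync_on_step=False)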
Example #3
class TransformerLMModel(ModelPT):
    """
    Left-to-right Transformer language model.
    """
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):

        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        self.global_rank = 0
        self.world_size = 1
        if trainer is not None:
            self.global_rank = (trainer.node_rank *
                                trainer.num_gpus) + trainer.local_rank
            self.world_size = trainer.num_nodes * trainer.num_gpus

        # shared params for dataset and data loaders
        self.dataset_cfg = cfg.dataset
        self.tokenizer = get_tokenizer(
            tokenizer_name=cfg.language_model.tokenizer,
            vocab_file=cfg.language_model.vocab_file,
            special_tokens=cfg.language_model.special_tokens,
        )

        # make vocabulary size divisible by 8 for fast fp16 training
        vocab_size = 8 * math.ceil(self.tokenizer.vocab_size / 8)

        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        self.embedding_layer = TransformerEmbedding(
            vocab_size=vocab_size,
            hidden_size=cfg.language_model.hidden_size,
            max_sequence_length=cfg.language_model.max_seq_length,
            embedding_dropout=cfg.language_model.get("embedding_dropout", 0.0),
            learn_positional_encodings=False,
        )
        self.encoder = TransformerEncoder(
            num_layers=cfg.language_model.num_layers,
            hidden_size=cfg.language_model.hidden_size,
            mask_future=True,
            num_attention_heads=cfg.language_model.num_attn_heads,
            inner_size=cfg.language_model.inner_size,
            ffn_dropout=cfg.language_model.get("ffn_dropout", 0.0),
            hidden_act=cfg.language_model.get("inner_activation", "relu"),
            attn_score_dropout=cfg.language_model.get("attn_score_dropout",
                                                      0.0),
            attn_layer_dropout=cfg.language_model.get("attn_layer_dropout",
                                                      0.0),
        )
        self.log_softmax = TokenClassifier(
            hidden_size=cfg.language_model.hidden_size,
            num_classes=vocab_size,
            log_softmax=True,
        )

        std_init_range = 1 / math.sqrt(cfg.language_model.hidden_size)
        self.apply(
            lambda module: transformer_weights_init(module, std_init_range))

        # tie weights of embedding and softmax matrices
        self.log_softmax.mlp.layer0.weight = self.embedding_layer.token_embedding.weight

        self.training_loss = SmoothedCrossEntropyLoss(
            pad_id=self.tokenizer.pad_id)
        self.validation_loss = SmoothedCrossEntropyLoss(
            pad_id=self.tokenizer.pad_id,
            predict_last_k=self.dataset_cfg.get("predict_last_k", 0),
        )

        self.training_perplexity = Perplexity(dist_sync_on_step=True)
        self.validation_perplexity = Perplexity(compute_on_step=False)

        # Optimizer setup needs to happen after all model weights are ready
        self.setup_optimization(cfg.optim)

    @typecheck()
    def forward(self, input_ids, attention_mask):
        """
        No special modification required for Lightning, define it as you normally would
        in the `nn.Module` in vanilla PyTorch.
        """
        token_embeddings = self.embedding_layer(input_ids)
        hidden_states = self.encoder(token_embeddings, attention_mask)
        log_probs = self.log_softmax(hidden_states=hidden_states)

        return log_probs

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        # forward pass
        input_ids, input_mask, labels = batch
        log_probs = self(input_ids=input_ids, attention_mask=input_mask)

        train_loss = self.training_loss(log_probs=log_probs, labels=labels)
        training_perplexity = self.training_perplexity(logits=log_probs)

        tensorboard_logs = {
            "train_loss": train_loss,
            "lr": self._optimizer.param_groups[0]["lr"],
            "train_ppl": training_perplexity,
        }
        return {"loss": train_loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        input_ids, input_mask, labels = batch
        log_probs = self(input_ids=input_ids, attention_mask=input_mask)

        val_loss = self.validation_loss(log_probs=log_probs, labels=labels)
        self.validation_perplexity(logits=log_probs)

        tensorboard_logs = {"val_loss": val_loss}

        return {"val_loss": val_loss, "log": tensorboard_logs}

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.
        :param outputs: list of individual outputs of each validation step.
        """

        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        validation_perplexity = self.validation_perplexity.compute()
        tensorboard_logs = {
            "val_loss": avg_loss,
            "val_ppl": validation_perplexity
        }
        return {"val_loss": avg_loss, "log": tensorboard_logs}

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config)

        # Need to set this because if using an IterableDataset, the length of the dataloader is the total number
        # of samples rather than the number of batches, and this messes up the tqdm progress bar.
        # So we set the number of steps manually (to the correct number) to fix this.
        if 'is_tarred' in train_data_config and train_data_config['is_tarred']:
            # We also need to check if limit_train_batches is already set.
            # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches,
            # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0).
            if isinstance(self._trainer.limit_train_batches, float):
                self._trainer.limit_train_batches = int(
                    self._trainer.limit_train_batches * math.ceil(
                        (len(self._train_dl.dataset) / self.world_size) /
                        train_data_config['batch_size']))

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config,
            predict_last_k=self.dataset_cfg.get("predict_last_k", 0),
        )

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config,
            predict_last_k=self.dataset_cfg.get("predict_last_k", 0),
        )

    def _setup_dataloader_from_config(self, cfg: DictConfig, predict_last_k=0):
        if cfg.get('use_cache', False):
            logging.info("Constructing tokenized dataset cache...")

        shuffle = cfg.shuffle

        if cfg.get('is_tarred', False):
            if ('tarred_text_filepaths' in cfg
                    and cfg['tarred_text_filepaths'] is None) or (
                        'file_name' in cfg and cfg['file_name'] is None):
                logging.warning(
                    "Could not load dataset as `file_name` or "
                    f"`tarred_text_filepaths` was None. Provided config: {cfg}"
                )
                return None

            shuffle_n = cfg.get('shuffle_n', 4 *
                                cfg['batch_size']) if shuffle else 0
            dataset = TarredL2RLanguageModelingDataset(
                text_tar_filepaths=cfg['tarred_text_filepaths'],
                metadata_path=cfg['file_name'],
                tokenizer=self.tokenizer,
                max_seq_length=self.dataset_cfg.max_seq_length,
                batch_step=predict_last_k,
                shuffle_n=shuffle_n,
                shard_strategy=cfg.get("tarred_shard_strategy", "scatter"),
                global_rank=self.global_rank,
                world_size=self.world_size,
            )

            shuffle = False
        else:

            dataset = L2RLanguageModelingDataset(
                tokenizer=self.tokenizer,
                dataset=cfg.file_name,
                max_seq_length=self.dataset_cfg.max_seq_length,
                batch_step=predict_last_k,
                use_cache=cfg.get('use_cache', False),
            )

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            shuffle=shuffle,
            num_workers=self.dataset_cfg.get("num_workers", 2),
            pin_memory=self.dataset_cfg.get("pin_memory", False),
            drop_last=self.dataset_cfg.get("drop_last", False),
        )

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        pass
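For reference, a sketch of the config sections this model reads (cfg.language_model, cfg.dataset, cfg.optim, plus a per-split data config); every value is a placeholder assumption, not a tuned recipe.

# Illustrative only; all values are assumptions.
from omegaconf import OmegaConf

example_lm_cfg = OmegaConf.create({
    "language_model": {
        "tokenizer": "word",          # tokenizer_name handed to get_tokenizer()
        "vocab_file": "vocab.txt",
        "special_tokens": None,
        "hidden_size": 512,
        "num_layers": 6,
        "num_attn_heads": 8,
        "inner_size": 2048,
        "max_seq_length": 256,
        # optional keys read via .get(): embedding_dropout, ffn_dropout,
        # inner_activation, attn_score_dropout, attn_layer_dropout
    },
    "dataset": {"max_seq_length": 256, "predict_last_k": 0,
                "num_workers": 2, "pin_memory": False, "drop_last": False},
    "optim": {"name": "adam", "lr": 1e-3},
})
# A per-split data config then supplies file_name / batch_size / shuffle,
# plus is_tarred / tarred_text_filepaths / shuffle_n for the tarred path.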
Example #4
class BERTLMModel(ModelPT):
    """
    BERT language model pretraining.
    """
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        output_types_dict = {
            "mlm_log_probs": self.mlm_classifier.output_types["log_probs"]
        }
        if not self.only_mlm_loss:
            output_types_dict["nsp_logits"] = self.nsp_classifier.output_types[
                "logits"]
        return output_types_dict

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):

        if cfg.tokenizer is not None:
            self._setup_tokenizer(cfg.tokenizer)
        else:
            self.tokenizer = None

        super().__init__(cfg=cfg, trainer=trainer)

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config)
            if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        self.hidden_size = self.bert_model.config.hidden_size
        self.vocab_size = self.bert_model.config.vocab_size
        self.only_mlm_loss = cfg.only_mlm_loss

        self.mlm_classifier = BertPretrainingTokenClassifier(
            hidden_size=self.hidden_size,
            num_classes=self.vocab_size,
            num_layers=cfg.num_tok_classification_layers,
            activation="gelu",
            log_softmax=True,
            use_transformer_init=True,
        )

        self.mlm_loss = SmoothedCrossEntropyLoss()

        if not self.only_mlm_loss:
            self.nsp_classifier = SequenceClassifier(
                hidden_size=self.hidden_size,
                num_classes=2,
                num_layers=cfg.num_seq_classification_layers,
                log_softmax=False,
                activation="tanh",
                use_transformer_init=True,
            )

            self.nsp_loss = CrossEntropyLoss()
            self.agg_loss = AggregatorLoss(num_inputs=2)

        # tie weights of MLM softmax layer and embedding layer of the encoder
        if (self.mlm_classifier.mlp.last_linear_layer.weight.shape !=
                self.bert_model.embeddings.word_embeddings.weight.shape):
            raise ValueError(
                "Final classification layer does not match embedding layer.")
        self.mlm_classifier.mlp.last_linear_layer.weight = self.bert_model.embeddings.word_embeddings.weight
        # create extra bias

        # setup to track metrics
        self.validation_perplexity = Perplexity(compute_on_step=False)

        self.setup_optimization(cfg.optim)

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        """
        No special modification required for Lightning, define it as you normally would
        in the `nn.Module` in vanilla PyTorch.
        """
        hidden_states = self.bert_model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        mlm_log_probs = self.mlm_classifier(hidden_states=hidden_states)
        if self.only_mlm_loss:
            return (mlm_log_probs, )
        nsp_logits = self.nsp_classifier(hidden_states=hidden_states)
        return mlm_log_probs, nsp_logits

    def _compute_losses(self, mlm_log_probs, nsp_logits, output_ids,
                        output_mask, labels):
        mlm_loss = self.mlm_loss(log_probs=mlm_log_probs,
                                 labels=output_ids,
                                 output_mask=output_mask)
        if self.only_mlm_loss:
            loss, nsp_loss = mlm_loss, None
        else:
            nsp_loss = self.nsp_loss(logits=nsp_logits, labels=labels)
            loss = self.agg_loss(loss_1=mlm_loss, loss_2=nsp_loss)
        return mlm_loss, nsp_loss, loss

    def _parse_forward_outputs(self, forward_outputs):
        if self.only_mlm_loss:
            return forward_outputs[0], None
        else:
            return forward_outputs

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        # forward pass
        input_ids, input_type_ids, input_mask, output_ids, output_mask, labels = batch
        forward_outputs = self.forward(input_ids=input_ids,
                                       token_type_ids=input_type_ids,
                                       attention_mask=input_mask)
        mlm_log_probs, nsp_logits = self._parse_forward_outputs(
            forward_outputs)
        _, _, loss = self._compute_losses(mlm_log_probs, nsp_logits,
                                          output_ids, output_mask, labels)
        tensorboard_logs = {"train_loss": loss}
        return {"loss": loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        input_ids, input_type_ids, input_mask, output_ids, output_mask, labels = batch
        forward_outputs = self.forward(input_ids=input_ids,
                                       token_type_ids=input_type_ids,
                                       attention_mask=input_mask)
        mlm_log_probs, nsp_logits = self._parse_forward_outputs(
            forward_outputs)
        _, _, loss = self._compute_losses(mlm_log_probs, nsp_logits,
                                          output_ids, output_mask, labels)
        self.validation_perplexity(logits=mlm_log_probs)
        tensorboard_logs = {'val_loss': loss}
        return {'val_loss': loss, 'log': tensorboard_logs}

    def validation_epoch_end(self, outputs):
        """Called at the end of validation to aggregate outputs.

        Args:
            outputs (list): The individual outputs of each validation step.

        Returns:
            dict: Validation loss and tensorboard logs.
        """
        if outputs:
            avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
            perplexity = self.validation_perplexity.compute()
            tensorboard_logs = {'val_loss': avg_loss, 'perplexity': perplexity}
            logging.info(f"evaluation perplexity {perplexity.cpu().item()}")
            return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = (
            self._setup_preprocessed_dataloader(train_data_config)
            if self.tokenizer is None else
            self._setup_dataloader(train_data_config))

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = (
            self._setup_preprocessed_dataloader(val_data_config)
            if self.tokenizer is None else
            self._setup_dataloader(val_data_config))

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        pass

    def _setup_preprocessed_dataloader(self, cfg: Optional[DictConfig]):
        dataset = cfg.data_file
        max_predictions_per_seq = cfg.max_predictions_per_seq
        batch_size = cfg.batch_size

        if os.path.isdir(dataset):
            files = [
                os.path.join(dataset, f) for f in os.listdir(dataset)
                if os.path.isfile(os.path.join(dataset, f))
            ]
        else:
            files = [dataset]
        files.sort()
        dl = BertPretrainingPreprocessedDataloader(
            data_files=files,
            max_predictions_per_seq=max_predictions_per_seq,
            batch_size=batch_size,
        )
        return dl

    def _setup_tokenizer(self, cfg: DictConfig):
        tokenizer = get_tokenizer(
            tokenizer_name=cfg.tokenizer_name,
            tokenizer_model=cfg.tokenizer_model,
            special_tokens=OmegaConf.to_container(cfg.special_tokens)
            if cfg.special_tokens else None,
            vocab_file=cfg.vocab_file,
        )
        self.tokenizer = tokenizer

    def _setup_dataloader(self, cfg: DictConfig):
        dataset = BertPretrainingDataset(
            tokenizer=self.tokenizer,
            data_file=cfg.data_file,
            max_seq_length=cfg.max_seq_length,
            mask_prob=cfg.mask_prob,
            short_seq_prob=cfg.short_seq_prob,
        )
        dl = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            collate_fn=dataset.collate_fn,
            drop_last=cfg.get("drop_last", False),
            shuffle=cfg.shuffle,
            num_workers=cfg.get("num_workers", 0),
        )
        return dl

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        pass
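The dataloader path above is chosen by whether a tokenizer was configured: with one, raw text is tokenized and masked on the fly; with cfg.tokenizer set to None, already preprocessed shards are read directly. A hedged sketch of what the two data configs might look like (all field values are assumptions):

# Illustrative only; field values are assumptions.
from omegaconf import OmegaConf

# Text path (_setup_dataloader): tokenization and masking happen in the dataset.
text_ds_cfg = OmegaConf.create({
    "data_file": "train.txt", "max_seq_length": 128, "mask_prob": 0.15,
    "short_seq_prob": 0.1, "batch_size": 32, "shuffle": True, "num_workers": 2,
})

# Preprocessed path (_setup_preprocessed_dataloader, used when cfg.tokenizer is None):
# a file or directory of already tokenized and masked shards.
preprocessed_ds_cfg = OmegaConf.create({
    "data_file": "/path/to/preprocessed_shards",
    "max_predictions_per_seq": 20,
    "batch_size": 32,
})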
Example #5
class BERTLMModel(ModelPT):
    """
    BERT language model pretraining.
    """
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        output_types_dict = {
            "mlm_log_probs": self.mlm_classifier.output_types["log_probs"]
        }
        if not self.only_mlm_loss:
            output_types_dict["nsp_logits"] = self.nsp_classifier.output_types[
                "logits"]
        return output_types_dict

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):

        vocab_file = None
        config_dict = None
        config_file = None

        if cfg.tokenizer is not None:
            if cfg.tokenizer.get('tokenizer_name') and cfg.tokenizer.get(
                    'tokenizer_model'):
                self._setup_tokenizer(cfg.tokenizer)
            if cfg.tokenizer.get('vocab_file'):
                vocab_file = self.register_artifact('tokenizer.vocab_file',
                                                    cfg.tokenizer.vocab_file)
        else:
            self.tokenizer = None

        super().__init__(cfg=cfg, trainer=trainer)

        if cfg.language_model.get('config'):
            config_dict = OmegaConf.to_container(cfg.language_model.config)
        if cfg.language_model.get('config_file'):
            config_file = self.register_artifact(
                'language_model.config_file', cfg.language_model.config_file)

        self.bert_model = get_lm_model(
            config_file=config_file,
            config_dict=config_dict,
            vocab_file=vocab_file,
            trainer=trainer,
            cfg=cfg,
        )

        self.hidden_size = self.bert_model.config.hidden_size
        self.vocab_size = self.bert_model.config.vocab_size
        self.only_mlm_loss = cfg.only_mlm_loss

        self.mlm_classifier = BertPretrainingTokenClassifier(
            hidden_size=self.hidden_size,
            num_classes=self.vocab_size,
            num_layers=cfg.num_tok_classification_layers,
            activation="gelu",
            log_softmax=True,
            use_transformer_init=True,
        )

        self.mlm_loss = SmoothedCrossEntropyLoss()

        if not self.only_mlm_loss:
            self.nsp_classifier = SequenceClassifier(
                hidden_size=self.hidden_size,
                num_classes=2,
                num_layers=cfg.num_seq_classification_layers,
                log_softmax=False,
                activation="tanh",
                use_transformer_init=True,
            )

            self.nsp_loss = CrossEntropyLoss()
            self.agg_loss = AggregatorLoss(num_inputs=2)

        # tie weights of MLM softmax layer and embedding layer of the encoder
        if (self.mlm_classifier.mlp.last_linear_layer.weight.shape !=
                self.bert_model.embeddings.word_embeddings.weight.shape):
            raise ValueError(
                "Final classification layer does not match embedding layer.")
        self.mlm_classifier.mlp.last_linear_layer.weight = self.bert_model.embeddings.word_embeddings.weight
        # create extra bias

        # setup to track metrics
        self.validation_perplexity = Perplexity(compute_on_step=False)

        self.setup_optimization(cfg.optim)

    @typecheck()
    def forward(self, input_ids, attention_mask, token_type_ids):
        """
        No special modification required for Lightning, define it as you normally would
        in the `nn.Module` in vanilla PyTorch.
        """
        hidden_states = self.bert_model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        if isinstance(hidden_states, tuple):
            hidden_states = hidden_states[0]

        mlm_log_probs = self.mlm_classifier(hidden_states=hidden_states)
        if self.only_mlm_loss:
            return (mlm_log_probs, )
        nsp_logits = self.nsp_classifier(hidden_states=hidden_states)
        return mlm_log_probs, nsp_logits

    def _compute_losses(self, mlm_log_probs, nsp_logits, output_ids,
                        output_mask, labels):
        mlm_loss = self.mlm_loss(log_probs=mlm_log_probs,
                                 labels=output_ids,
                                 output_mask=output_mask)
        if self.only_mlm_loss:
            loss, nsp_loss = mlm_loss, None
        else:
            nsp_loss = self.nsp_loss(logits=nsp_logits, labels=labels)
            loss = self.agg_loss(loss_1=mlm_loss, loss_2=nsp_loss)
        return mlm_loss, nsp_loss, loss

    def _parse_forward_outputs(self, forward_outputs):
        if self.only_mlm_loss:
            return forward_outputs[0], None
        else:
            return forward_outputs

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        input_ids, input_type_ids, input_mask, output_ids, output_mask, labels = batch
        forward_outputs = self.forward(input_ids=input_ids,
                                       token_type_ids=input_type_ids,
                                       attention_mask=input_mask)
        mlm_log_probs, nsp_logits = self._parse_forward_outputs(
            forward_outputs)
        _, _, loss = self._compute_losses(mlm_log_probs, nsp_logits,
                                          output_ids, output_mask, labels)
        lr = self._optimizer.param_groups[0]['lr']
        self.log('train_loss', loss)
        self.log('lr', lr, prog_bar=True)
        return {"loss": loss, "lr": lr}

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        input_ids, input_type_ids, input_mask, output_ids, output_mask, labels = batch
        forward_outputs = self.forward(input_ids=input_ids,
                                       token_type_ids=input_type_ids,
                                       attention_mask=input_mask)
        mlm_log_probs, nsp_logits = self._parse_forward_outputs(
            forward_outputs)
        _, _, loss = self._compute_losses(mlm_log_probs, nsp_logits,
                                          output_ids, output_mask, labels)
        self.validation_perplexity(logits=mlm_log_probs)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        """Called at the end of validation to aggregate outputs.

        Args:
            outputs (list): The individual outputs of each validation step.

        Returns:
            dict: Validation loss and tensorboard logs.
        """
        if outputs:
            avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
            perplexity = self.validation_perplexity.compute()
            logging.info(f"evaluation perplexity {perplexity.cpu().item()}")
            self.log('val_loss', avg_loss)

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = (
            self._setup_preprocessed_dataloader(train_data_config)
            if self.tokenizer is None else
            self._setup_dataloader(train_data_config))

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = (
            self._setup_preprocessed_dataloader(val_data_config)
            if self.tokenizer is None else
            self._setup_dataloader(val_data_config))

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        pass

    def _setup_preprocessed_dataloader(self, cfg: Optional[DictConfig]):
        dataset = cfg.data_file
        max_predictions_per_seq = cfg.max_predictions_per_seq
        batch_size = cfg.batch_size

        if os.path.isdir(dataset):
            files = [
                os.path.join(dataset, f) for f in os.listdir(dataset)
                if os.path.isfile(os.path.join(dataset, f))
            ]
        else:
            files = [dataset]
        files.sort()
        dl = BertPretrainingPreprocessedDataloader(
            data_files=files,
            max_predictions_per_seq=max_predictions_per_seq,
            batch_size=batch_size,
        )
        return dl

    def _setup_tokenizer(self, cfg: DictConfig):
        tokenizer = get_tokenizer(
            tokenizer_name=cfg.tokenizer_name,
            tokenizer_model=cfg.tokenizer_model,
            special_tokens=OmegaConf.to_container(cfg.special_tokens)
            if cfg.special_tokens else None,
            vocab_file=cfg.vocab_file,
        )
        self.tokenizer = tokenizer

    def _setup_dataloader(self, cfg: DictConfig):
        dataset = BertPretrainingDataset(
            tokenizer=self.tokenizer,
            data_file=cfg.data_file,
            max_seq_length=cfg.max_seq_length,
            mask_prob=cfg.mask_prob,
            short_seq_prob=cfg.short_seq_prob,
        )
        dl = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            collate_fn=dataset.collate_fn,
            drop_last=cfg.get("drop_last", False),
            shuffle=cfg.shuffle,
            num_workers=cfg.get("num_workers", 0),
        )
        return dl

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        result = []

        result.append(
            PretrainedModelInfo(
                pretrained_model_name="bertbaseuncased",
                location=
                "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/bertbaseuncased/versions/1.0.0rc1/files/bertbaseuncased.nemo",
                description=
                "The model was trained EN Wikipedia and BookCorpus on a sequence length of 512.",
            ))

        result.append(
            PretrainedModelInfo(
                pretrained_model_name="bertlargeuncased",
                location=
                "https://api.ngc.nvidia.com/v2/models/nvidia/nemo/bertlargeuncased/versions/1.0.0rc1/files/bertlargeuncased.nemo",
                description=
                "The model was trained EN Wikipedia and BookCorpus on a sequence length of 512.",
            ))
        return result
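Given the NGC entries above, a checkpoint can typically be pulled through the standard ModelPT restore API; a minimal sketch, assuming the usual from_pretrained signature and network access to NGC:

# Sketch only; downloads the .nemo file referenced by list_available_models().
if __name__ == "__main__":
    for info in BERTLMModel.list_available_models():
        print(info.pretrained_model_name, info.location)
    model = BERTLMModel.from_pretrained(model_name="bertbaseuncased")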
Example #6
class TransformerLMModel(ModelPT):
    """
    Left-to-right Transformer language model.
    """
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):

        # shared params for dataset and data loaders
        self.dataset_cfg = cfg.dataset
        self.tokenizer = get_tokenizer(
            tokenizer_name=cfg.language_model.tokenizer,
            vocab_file=cfg.language_model.vocab_file,
            special_tokens=cfg.language_model.special_tokens,
        )

        # make vocabulary size divisible by 8 for fast fp16 training
        vocab_size = 8 * math.ceil(self.tokenizer.vocab_size / 8)

        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        self.embedding_layer = TransformerEmbedding(
            vocab_size=vocab_size,
            hidden_size=cfg.language_model.hidden_size,
            max_sequence_length=cfg.language_model.max_seq_length,
            embedding_dropout=cfg.language_model.get("embedding_dropout", 0.0),
            learn_positional_encodings=False,
        )
        self.encoder = TransformerEncoder(
            num_layers=cfg.language_model.num_layers,
            hidden_size=cfg.language_model.hidden_size,
            mask_future=True,
            num_attention_heads=cfg.language_model.num_attn_heads,
            inner_size=cfg.language_model.inner_size,
            ffn_dropout=cfg.language_model.get("ffn_dropout", 0.0),
            hidden_act=cfg.language_model.get("inner_activation", "relu"),
            attn_score_dropout=cfg.language_model.get("attn_score_dropout",
                                                      0.0),
            attn_layer_dropout=cfg.language_model.get("attn_layer_dropout",
                                                      0.0),
        )
        self.log_softmax = TokenClassifier(
            hidden_size=cfg.language_model.hidden_size,
            num_classes=vocab_size,
            log_softmax=True,
        )

        std_init_range = 1 / math.sqrt(cfg.language_model.hidden_size)
        self.apply(
            lambda module: transformer_weights_init(module, std_init_range))

        # tie weights of embedding and softmax matrices
        self.log_softmax.mlp.layer0.weight = self.embedding_layer.token_embedding.weight

        self.training_loss = SmoothedCrossEntropyLoss(
            pad_id=self.tokenizer.pad_id)
        self.validation_loss = SmoothedCrossEntropyLoss(
            pad_id=self.tokenizer.pad_id,
            predict_last_k=self.dataset_cfg.get("predict_last_k", 0),
        )

        self.training_perplexity = Perplexity(dist_sync_on_step=True)
        self.validation_perplexity = Perplexity(compute_on_step=False)

        # Optimizer setup needs to happen after all model weights are ready
        self.setup_optimization(cfg.optim)

    @typecheck()
    def forward(self, input_ids, attention_mask):
        """
        No special modification required for Lightning, define it as you normally would
        in the `nn.Module` in vanilla PyTorch.
        """
        token_embeddings = self.embedding_layer(input_ids)
        hidden_states = self.encoder(token_embeddings, attention_mask)
        log_probs = self.log_softmax(hidden_states=hidden_states)

        return log_probs

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        # forward pass
        input_ids, input_mask, labels = batch
        log_probs = self(input_ids=input_ids, attention_mask=input_mask)

        train_loss = self.training_loss(log_probs=log_probs, labels=labels)
        training_perplexity = self.training_perplexity(logits=log_probs)

        tensorboard_logs = {
            "train_loss": train_loss,
            "lr": self._optimizer.param_groups[0]["lr"],
            "train_ppl": training_perplexity,
        }
        return {"loss": train_loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        input_ids, input_mask, labels = batch
        log_probs = self(input_ids=input_ids, attention_mask=input_mask)

        val_loss = self.validation_loss(log_probs=log_probs, labels=labels)
        self.validation_perplexity(logits=log_probs)

        tensorboard_logs = {"val_loss": val_loss}

        return {"val_loss": val_loss, "log": tensorboard_logs}

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.
        :param outputs: list of individual outputs of each validation step.
        """

        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        validation_perplexity = self.validation_perplexity.compute()
        tensorboard_logs = {
            "val_loss": avg_loss,
            "val_ppl": validation_perplexity
        }
        return {"val_loss": avg_loss, "log": tensorboard_logs}

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config,
            predict_last_k=self.dataset_cfg.get("predict_last_k", 0),
        )

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config,
            predict_last_k=self.dataset_cfg.get("predict_last_k", 0),
        )

    def _setup_dataloader_from_config(self, cfg: DictConfig, predict_last_k=0):
        dataset = L2RLanguageModelingDataset(
            tokenizer=self.tokenizer,
            dataset=cfg.file_name,
            max_seq_length=self.dataset_cfg.max_seq_length,
            batch_step=predict_last_k,
        )
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=self.dataset_cfg.get("num_workers", 2),
            pin_memory=self.dataset_cfg.get("pin_memory", False),
            drop_last=self.dataset_cfg.get("drop_last", False),
        )

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        pass