# Imports reconstructed for these excerpts. The stdlib/third-party paths below are
# standard; the remaining NeMo collection imports (Perplexity, get_tokenizer,
# get_lm_model, the loss/classifier modules, and the dataset classes) live under
# nemo.collections.* and their exact module paths vary by NeMo version, so they
# are left out rather than guessed.
import math
import os
import pickle
from typing import Dict, List, Optional, Tuple

import numpy as np
import pytest
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from transformers import AutoModel, BartForConditionalGeneration, EncoderDecoderModel

from nemo.core.classes import ModelPT, typecheck
from nemo.core.neural_types import ChannelType, LossType, MaskType, NeuralType
from nemo.utils import logging


class NeuralMachineTranslationModel(ModelPT):
    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        return {
            "input_ids": NeuralType(('B', 'T'), ChannelType()),
            "attention_mask": NeuralType(('B', 'T'), MaskType(), optional=True),
            "decoder_input_ids": NeuralType(('B', 'T'), ChannelType(), optional=True),
            "labels": NeuralType(('B', 'T'), ChannelType(), optional=True),
        }

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return {
            "loss": NeuralType((), LossType()),
            "decoder_hidden_states": NeuralType(("B", "T", "D"), ChannelType(), optional=True),
            "encoder_hidden_states": NeuralType(("B", "T", "D"), ChannelType(), optional=True),
        }

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Tokenizers must be assigned before calling the superclass __init__.
        if cfg.language_model.pretrained_model_name:
            if cfg.language_model.pretrained_encoder_model_name or cfg.language_model.pretrained_decoder_model_name:
                raise ValueError(
                    "Specify either pretrained_model_name or both pretrained_encoder_model_name "
                    "and pretrained_decoder_model_name, not a mix of the two."
                )
            # Set up a single tokenizer shared by encoder and decoder.
            self.encoder_tokenizer = self.setup_tokenizer(cfg.encoder_tokenizer)
            self.encoder_add_special_tokens = cfg.encoder_tokenizer.add_special_tokens
            self.decoder_tokenizer = self.encoder_tokenizer
            self.decoder_add_special_tokens = self.encoder_add_special_tokens
        else:
            if not (cfg.language_model.pretrained_encoder_model_name and cfg.language_model.pretrained_decoder_model_name):
                raise ValueError("Both encoder and decoder must be specified")
            # Set up separate encoder and decoder tokenizers.
            self.encoder_tokenizer = self.setup_tokenizer(cfg.encoder_tokenizer)
            self.encoder_add_special_tokens = cfg.encoder_tokenizer.add_special_tokens
            self.decoder_tokenizer = self.setup_tokenizer(cfg.decoder_tokenizer)
            self.decoder_add_special_tokens = cfg.decoder_tokenizer.add_special_tokens

        if not self.encoder_tokenizer:
            raise TypeError("encoder_tokenizer failed to initialize")
        if not self.decoder_tokenizer:
            raise TypeError("decoder_tokenizer failed to initialize")

        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        # Modules must be assigned after the superclass __init__.
        if cfg.language_model.pretrained_model_name:
            # Set up an end-to-end model.
            if "bart" in cfg.language_model.pretrained_model_name:
                self.model = BartForConditionalGeneration.from_pretrained(cfg.language_model.pretrained_model_name)
            else:
                self.model = AutoModel.from_pretrained(cfg.language_model.pretrained_model_name)
        else:
            if not (cfg.language_model.pretrained_encoder_model_name and cfg.language_model.pretrained_decoder_model_name):
                raise ValueError("Both encoder and decoder must be specified")
            # Set up a composite encoder/decoder model.
            self.model = EncoderDecoderModel.from_encoder_decoder_pretrained(
                encoder=cfg.language_model.pretrained_encoder_model_name,
                decoder=cfg.language_model.pretrained_decoder_model_name,
            )

        self.validation_perplexity = Perplexity(compute_on_step=False)

        self.setup_optimization(cfg.optim)

    @typecheck()
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        decoder_input_ids: torch.Tensor = None,
        labels: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        No special modification required for Lightning; define it as you normally
        would for an `nn.Module` in vanilla PyTorch.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            labels=labels,
            return_dict=False,
        )
        return outputs

    @typecheck.disable_checks()
    def generate(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Wraps HuggingFace's EncoderDecoder.generate()."""
        outputs = self.model.generate(
            input_ids=input_ids,
            pad_token_id=self.encoder_tokenizer.pad_id,
            bos_token_id=self.encoder_tokenizer.bos_id,
            eos_token_id=self.encoder_tokenizer.eos_id,
            decoder_start_token_id=self.decoder_tokenizer.bos_id,
            **self._cfg.generate,
        )
        return outputs

    def training_step(self, batch: Tuple, batch_idx: int) -> Dict:
        """
        Lightning calls this inside the training loop with the data from the
        training dataloader passed in as `batch`. Loss calculation comes from
        HuggingFace's BartForConditionalGeneration.
        """
        input_ids, input_mask, decoder_input_ids, labels = batch
        loss = self.forward(
            input_ids=input_ids,
            attention_mask=input_mask,
            decoder_input_ids=decoder_input_ids,
            labels=labels,
        )[0]
        tensorboard_logs = {"train_loss": loss, "lr": self._optimizer.param_groups[0]["lr"]}
        return {"loss": loss, "log": tensorboard_logs}

    def validation_step(self, batch: Tuple, batch_idx: int) -> Dict:
        """
        Lightning calls this inside the validation loop with the data from the
        validation dataloader passed in as `batch`. Loss calculation comes from
        HuggingFace's BartForConditionalGeneration.
        """
        input_ids, input_mask, decoder_input_ids, labels = batch
        loss, logits = self.forward(
            input_ids=input_ids,
            attention_mask=input_mask,
            decoder_input_ids=decoder_input_ids,
            labels=labels,
        )[:2]
        self.validation_perplexity(logits=logits)
        tensorboard_logs = {"val_loss": loss}
        return {"val_loss": loss, "log": tensorboard_logs}

    def validation_epoch_end(self, outputs: List[Dict]) -> Dict:
        """
        Called at the end of validation to aggregate outputs.

        :param outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        perplexity = self.validation_perplexity.compute()
        tensorboard_logs = {"val_loss": avg_loss, "perplexity": perplexity}
        logging.info(f"evaluation perplexity {perplexity.item()}")
        return {"val_loss": avg_loss, "log": tensorboard_logs}

    @typecheck.disable_checks()
    def test_step(self, batch: Tuple, batch_idx: int) -> torch.Tensor:
        """Lightning calls this inside the test loop with data from the test dataloader."""
        input_ids, input_mask, decoder_input_ids, labels = batch
        sequences = self.generate(input_ids=input_ids)
        return sequences

    @typecheck.disable_checks()
    def test_epoch_end(self, outputs: List[torch.Tensor]) -> Dict[str, List[str]]:
        """Called at the end of the test loop to aggregate outputs and decode them."""
        texts = [self.encoder_tokenizer.ids_to_text(seq) for batch in outputs for seq in batch]
        return {"texts": texts}

    def setup_tokenizer(self, cfg: DictConfig):
        tokenizer = get_tokenizer(
            tokenizer_name=cfg.tokenizer_name,
            tokenizer_model=cfg.tokenizer_model,
            special_tokens=OmegaConf.to_container(cfg.special_tokens) if cfg.special_tokens else None,
            vocab_file=cfg.vocab_file,
        )
        return tokenizer

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = self.setup_dataloader_from_config(cfg=train_data_config)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = self.setup_dataloader_from_config(cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        self._test_dl = self.setup_dataloader_from_config(cfg=test_data_config)

    def setup_dataloader_from_config(self, cfg: DictConfig):
        dataset = NeuralMachineTranslationDataset(
            filepath=cfg.filepath,
            encoder_tokenizer=self.encoder_tokenizer,
            decoder_tokenizer=self.decoder_tokenizer,
            encoder_add_special_tokens=self.encoder_add_special_tokens,
            decoder_add_special_tokens=self.decoder_add_special_tokens,
            max_seq_length=self._cfg.max_seq_length,
            num_samples=cfg.get("num_samples", -1),
            convert_labels=self._cfg.convert_labels,
        )
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self._cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
            collate_fn=dataset.collate_fn,
        )

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        pass
def _perplexity_class_test(
    rank: int,
    worldsize: int,
    probs: Optional[torch.Tensor],
    logits: Optional[torch.Tensor],
    dist_sync_on_step: bool,
    metric_args: dict = {},
    check_dist_sync_on_step: bool = True,
    check_batch: bool = True,
    atol: float = 1e-8,
):
    """
    Utility function doing the actual comparison between the Lightning class
    metric and the reference metric.

    Args:
        rank: rank of the current process
        worldsize: number of processes
        probs: torch tensor with probabilities
        logits: torch tensor with logits. The function checks that ``probs`` and
            ``logits`` are mutually exclusive for the ``Perplexity`` metric.
        dist_sync_on_step: if True, synchronizes metric state across processes at
            each ``forward()``
        metric_args: dict with additional arguments used for class initialization
        check_dist_sync_on_step: if True, checks that the metric is also correctly
            calculated per batch per device (and not just at the end)
        check_batch: if True, checks that the metric is also correctly calculated
            across devices for each batch (and not just at the end)
    """
    # Instantiate the Lightning metric.
    perplexity = Perplexity(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args)
    if (probs is None) == (logits is None):
        with pytest.raises(ValueError):
            perplexity(probs, logits)
        return

    # Verify that perplexity works after being loaded from pickled state.
    pickled_metric = pickle.dumps(perplexity)
    perplexity = pickle.loads(pickled_metric)

    for i in range(rank, NUM_BATCHES, worldsize):
        batch_result = perplexity(None if probs is None else probs[i], None if logits is None else logits[i])
        if perplexity.dist_sync_on_step:
            if rank == 0:
                if probs is not None:
                    ddp_probs = torch.stack([probs[i + r] for r in range(worldsize)])
                else:
                    ddp_logits = torch.stack([logits[i + r] for r in range(worldsize)])
                    ddp_probs = logits_to_probs(ddp_logits, is_binary=False)
                sk_batch_result = reference_perplexity_func(ddp_probs)
                # Assert for dist_sync_on_step.
                if check_dist_sync_on_step:
                    assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol)
        else:
            if probs is None:
                p = logits_to_probs(logits[i], is_binary=False)
            else:
                p = probs[i]
            sk_batch_result = reference_perplexity_func(p)
            # Assert for the individual batch.
            if check_batch:
                assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol)

    assert (probs is None) != (logits is None)

    # Check on all batches on all ranks.
    result = perplexity.compute()
    assert isinstance(result, torch.Tensor)

    if probs is None:
        probs = logits_to_probs(logits, is_binary=False)
    sk_result = reference_perplexity_func(probs)

    # Assert after aggregation.
    assert np.allclose(result.numpy(), sk_result, atol=atol)
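
# Hedged sketch of what the test compares against. The real test file defines
# `reference_perplexity_func`, `logits_to_probs`, and NUM_BATCHES elsewhere.
# Since the metric receives only a predictive distribution (no targets), a
# plausible reference is the exponential of the distribution's entropy averaged
# over samples; this helper is an assumption, not the verbatim reference.
def _reference_perplexity_sketch(probs: torch.Tensor) -> float:
    entropy = -(probs * probs.log()).sum(dim=-1)  # per-sample entropy over the vocab axis
    return entropy.exp().mean().item()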
class TransformerLMModel(ModelPT):
    """
    Left-to-right Transformer language model.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Get the global rank and total number of GPU workers for IterableDataset partitioning, if applicable.
        self.global_rank = 0
        self.world_size = 1
        if trainer is not None:
            self.global_rank = (trainer.node_rank * trainer.num_gpus) + trainer.local_rank
            self.world_size = trainer.num_nodes * trainer.num_gpus

        # Shared parameters for the dataset and data loaders.
        self.dataset_cfg = cfg.dataset

        self.tokenizer = get_tokenizer(
            tokenizer_name=cfg.language_model.tokenizer,
            vocab_file=cfg.language_model.vocab_file,
            special_tokens=cfg.language_model.special_tokens,
        )

        # Make the vocabulary size divisible by 8 for fast fp16 training.
        vocab_size = 8 * math.ceil(self.tokenizer.vocab_size / 8)

        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        self.embedding_layer = TransformerEmbedding(
            vocab_size=vocab_size,
            hidden_size=cfg.language_model.hidden_size,
            max_sequence_length=cfg.language_model.max_seq_length,
            embedding_dropout=cfg.language_model.get("embedding_dropout", 0.0),
            learn_positional_encodings=False,
        )
        self.encoder = TransformerEncoder(
            num_layers=cfg.language_model.num_layers,
            hidden_size=cfg.language_model.hidden_size,
            mask_future=True,
            num_attention_heads=cfg.language_model.num_attn_heads,
            inner_size=cfg.language_model.inner_size,
            ffn_dropout=cfg.language_model.get("ffn_dropout", 0.0),
            hidden_act=cfg.language_model.get("inner_activation", "relu"),
            attn_score_dropout=cfg.language_model.get("attn_score_dropout", 0.0),
            attn_layer_dropout=cfg.language_model.get("attn_layer_dropout", 0.0),
        )
        self.log_softmax = TokenClassifier(
            hidden_size=cfg.language_model.hidden_size,
            num_classes=vocab_size,
            log_softmax=True,
        )

        std_init_range = 1 / math.sqrt(cfg.language_model.hidden_size)
        self.apply(lambda module: transformer_weights_init(module, std_init_range))

        # Tie the weights of the embedding and softmax matrices.
        self.log_softmax.mlp.layer0.weight = self.embedding_layer.token_embedding.weight

        self.training_loss = SmoothedCrossEntropyLoss(pad_id=self.tokenizer.pad_id)
        self.validation_loss = SmoothedCrossEntropyLoss(
            pad_id=self.tokenizer.pad_id,
            predict_last_k=self.dataset_cfg.get("predict_last_k", 0),
        )

        self.training_perplexity = Perplexity(dist_sync_on_step=True)
        self.validation_perplexity = Perplexity(compute_on_step=False)

        # Optimizer setup needs to happen after all model weights are ready.
        self.setup_optimization(cfg.optim)

    @typecheck()
    def forward(self, input_ids, attention_mask):
        """
        No special modification required for Lightning; define it as you normally
        would for an `nn.Module` in vanilla PyTorch.
        """
        token_embeddings = self.embedding_layer(input_ids)
        hidden_states = self.encoder(token_embeddings, attention_mask)
        log_probs = self.log_softmax(hidden_states=hidden_states)
        return log_probs

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the
        training dataloader passed in as `batch`.
        """
        # forward pass
        input_ids, input_mask, labels = batch
        log_probs = self(input_ids=input_ids, attention_mask=input_mask)

        train_loss = self.training_loss(log_probs=log_probs, labels=labels)
        training_perplexity = self.training_perplexity(logits=log_probs)

        tensorboard_logs = {
            "train_loss": train_loss,
            "lr": self._optimizer.param_groups[0]["lr"],
            "train_ppl": training_perplexity,
        }
        return {"loss": train_loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the
        validation dataloader passed in as `batch`.
        """
        input_ids, input_mask, labels = batch
        log_probs = self(input_ids=input_ids, attention_mask=input_mask)

        val_loss = self.validation_loss(log_probs=log_probs, labels=labels)
        self.validation_perplexity(logits=log_probs)

        tensorboard_logs = {"val_loss": val_loss}
        return {"val_loss": val_loss, "log": tensorboard_logs}

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.

        :param outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        validation_perplexity = self.validation_perplexity.compute()
        tensorboard_logs = {"val_loss": avg_loss, "val_ppl": validation_perplexity}
        return {"val_loss": avg_loss, "log": tensorboard_logs}

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config)

        # When using an IterableDataset, the length of the dataloader is the total
        # number of samples rather than the number of batches, which breaks the
        # tqdm progress bar. Set the number of steps manually to the correct value.
        if 'is_tarred' in train_data_config and train_data_config['is_tarred']:
            # Check whether limit_train_batches is already set. If it is an int, we
            # assume the user has set it to something sane (i.e. <= the number of
            # training batches) and leave it alone. If it is a float (including
            # 1.0), adjust the number of batches accordingly.
            if isinstance(self._trainer.limit_train_batches, float):
                self._trainer.limit_train_batches = int(
                    self._trainer.limit_train_batches
                    * math.ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size'])
                )

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config,
            predict_last_k=self.dataset_cfg.get("predict_last_k", 0),
        )

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config,
            predict_last_k=self.dataset_cfg.get("predict_last_k", 0),
        )

    def _setup_dataloader_from_config(self, cfg: DictConfig, predict_last_k=0):
        if cfg.get('use_cache', False):
            logging.info("Constructing tokenized dataset cache...")

        shuffle = cfg.shuffle

        if cfg.get('is_tarred', False):
            if ('tarred_text_filepaths' in cfg and cfg['tarred_text_filepaths'] is None) or (
                'file_name' in cfg and cfg['file_name'] is None
            ):
                logging.warning(
                    "Could not load dataset as `file_name` was None or "
                    f"`tarred_text_filepaths` is None. Provided config: {cfg}"
                )
                return None

            shuffle_n = cfg.get('shuffle_n', 4 * cfg['batch_size']) if shuffle else 0
            dataset = TarredL2RLanguageModelingDataset(
                text_tar_filepaths=cfg['tarred_text_filepaths'],
                metadata_path=cfg['file_name'],
                tokenizer=self.tokenizer,
                max_seq_length=self.dataset_cfg.max_seq_length,
                batch_step=predict_last_k,
                shuffle_n=shuffle_n,
                shard_strategy=cfg.get("tarred_shard_strategy", "scatter"),
                global_rank=self.global_rank,
                world_size=self.world_size,
            )
            shuffle = False
        else:
            dataset = L2RLanguageModelingDataset(
                tokenizer=self.tokenizer,
                dataset=cfg.file_name,
                max_seq_length=self.dataset_cfg.max_seq_length,
                batch_step=predict_last_k,
                use_cache=cfg.get('use_cache', False),
            )

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            shuffle=shuffle,
            num_workers=self.dataset_cfg.get("num_workers", 2),
            pin_memory=self.dataset_cfg.get("pin_memory", False),
            drop_last=self.dataset_cfg.get("drop_last", False),
        )

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        pass
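
# Illustrative scoring helper (an assumption, not part of the original file):
# uses the forward() above to sum the log-probability of one sentence under the
# left-to-right LM, feeding tokens [0..T-1] and scoring targets [1..T].
def example_sentence_logprob(model: TransformerLMModel, text: str) -> float:
    ids = model.tokenizer.text_to_ids(text)
    input_ids = torch.tensor([ids[:-1]], dtype=torch.long)  # inputs predict the next token
    attention_mask = torch.ones_like(input_ids)
    log_probs = model(input_ids=input_ids, attention_mask=attention_mask)  # (1, T-1, V)
    targets = torch.tensor([ids[1:]], dtype=torch.long)
    # Gather the log-prob assigned to each realized next token and sum.
    return log_probs.gather(-1, targets.unsqueeze(-1)).sum().item()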
class BERTLMModel(ModelPT):
    """
    BERT language model pretraining.
    """

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        output_types_dict = {"mlm_log_probs": self.mlm_classifier.output_types["log_probs"]}
        if not self.only_mlm_loss:
            output_types_dict["nsp_logits"] = self.nsp_classifier.output_types["logits"]
        return output_types_dict

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        if cfg.tokenizer is not None:
            self._setup_tokenizer(cfg.tokenizer)
        else:
            self.tokenizer = None

        super().__init__(cfg=cfg, trainer=trainer)

        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config) if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        self.hidden_size = self.bert_model.config.hidden_size
        self.vocab_size = self.bert_model.config.vocab_size
        self.only_mlm_loss = cfg.only_mlm_loss

        self.mlm_classifier = BertPretrainingTokenClassifier(
            hidden_size=self.hidden_size,
            num_classes=self.vocab_size,
            num_layers=cfg.num_tok_classification_layers,
            activation="gelu",
            log_softmax=True,
            use_transformer_init=True,
        )
        self.mlm_loss = SmoothedCrossEntropyLoss()

        if not self.only_mlm_loss:
            self.nsp_classifier = SequenceClassifier(
                hidden_size=self.hidden_size,
                num_classes=2,
                num_layers=cfg.num_seq_classification_layers,
                log_softmax=False,
                activation="tanh",
                use_transformer_init=True,
            )
            self.nsp_loss = CrossEntropyLoss()
            self.agg_loss = AggregatorLoss(num_inputs=2)

        # Tie the weights of the MLM softmax layer and the embedding layer of the encoder.
        if (
            self.mlm_classifier.mlp.last_linear_layer.weight.shape
            != self.bert_model.embeddings.word_embeddings.weight.shape
        ):
            raise ValueError("Final classification layer does not match embedding layer.")
        self.mlm_classifier.mlp.last_linear_layer.weight = self.bert_model.embeddings.word_embeddings.weight

        # Set up metric tracking.
        self.validation_perplexity = Perplexity(compute_on_step=False)

        self.setup_optimization(cfg.optim)

    @typecheck()
    def forward(self, input_ids, token_type_ids, attention_mask):
        """
        No special modification required for Lightning; define it as you normally
        would for an `nn.Module` in vanilla PyTorch.
        """
        hidden_states = self.bert_model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        mlm_log_probs = self.mlm_classifier(hidden_states=hidden_states)
        if self.only_mlm_loss:
            return (mlm_log_probs,)
        nsp_logits = self.nsp_classifier(hidden_states=hidden_states)
        return mlm_log_probs, nsp_logits

    def _compute_losses(self, mlm_log_probs, nsp_logits, output_ids, output_mask, labels):
        mlm_loss = self.mlm_loss(log_probs=mlm_log_probs, labels=output_ids, output_mask=output_mask)
        if self.only_mlm_loss:
            loss, nsp_loss = mlm_loss, None
        else:
            nsp_loss = self.nsp_loss(logits=nsp_logits, labels=labels)
            loss = self.agg_loss(loss_1=mlm_loss, loss_2=nsp_loss)
        return mlm_loss, nsp_loss, loss

    def _parse_forward_outputs(self, forward_outputs):
        if self.only_mlm_loss:
            return forward_outputs[0], None
        else:
            return forward_outputs

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the
        training dataloader passed in as `batch`.
        """
        # forward pass
        input_ids, input_type_ids, input_mask, output_ids, output_mask, labels = batch
        forward_outputs = self.forward(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
        )
        mlm_log_probs, nsp_logits = self._parse_forward_outputs(forward_outputs)
        _, _, loss = self._compute_losses(mlm_log_probs, nsp_logits, output_ids, output_mask, labels)
        tensorboard_logs = {"train_loss": loss}
        return {"loss": loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the
        validation dataloader passed in as `batch`.
        """
        input_ids, input_type_ids, input_mask, output_ids, output_mask, labels = batch
        forward_outputs = self.forward(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
        )
        mlm_log_probs, nsp_logits = self._parse_forward_outputs(forward_outputs)
        _, _, loss = self._compute_losses(mlm_log_probs, nsp_logits, output_ids, output_mask, labels)
        self.validation_perplexity(logits=mlm_log_probs)
        tensorboard_logs = {'val_loss': loss}
        return {'val_loss': loss, 'log': tensorboard_logs}

    def validation_epoch_end(self, outputs):
        """Called at the end of validation to aggregate outputs.

        Args:
            outputs (list): The individual outputs of each validation step.

        Returns:
            dict: Validation loss and tensorboard logs.
        """
        if outputs:
            avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
            perplexity = self.validation_perplexity.compute()
            tensorboard_logs = {'val_loss': avg_loss, 'perplexity': perplexity}
            logging.info(f"evaluation perplexity {perplexity.cpu().item()}")
            return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = (
            self._setup_preprocessed_dataloader(train_data_config)
            if self.tokenizer is None
            else self._setup_dataloader(train_data_config)
        )

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = (
            self._setup_preprocessed_dataloader(val_data_config)
            if self.tokenizer is None
            else self._setup_dataloader(val_data_config)
        )

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        pass

    def _setup_preprocessed_dataloader(self, cfg: Optional[DictConfig]):
        dataset = cfg.data_file
        max_predictions_per_seq = cfg.max_predictions_per_seq
        batch_size = cfg.batch_size

        if os.path.isdir(dataset):
            files = [
                os.path.join(dataset, f)
                for f in os.listdir(dataset)
                if os.path.isfile(os.path.join(dataset, f))
            ]
        else:
            files = [dataset]
        files.sort()

        dl = BertPretrainingPreprocessedDataloader(
            data_files=files,
            max_predictions_per_seq=max_predictions_per_seq,
            batch_size=batch_size,
        )
        return dl

    def _setup_tokenizer(self, cfg: DictConfig):
        tokenizer = get_tokenizer(
            tokenizer_name=cfg.tokenizer_name,
            tokenizer_model=cfg.tokenizer_model,
            special_tokens=OmegaConf.to_container(cfg.special_tokens) if cfg.special_tokens else None,
            vocab_file=cfg.vocab_file,
        )
        self.tokenizer = tokenizer

    def _setup_dataloader(self, cfg: DictConfig):
        dataset = BertPretrainingDataset(
            tokenizer=self.tokenizer,
            data_file=cfg.data_file,
            max_seq_length=cfg.max_seq_length,
            mask_prob=cfg.mask_prob,
            short_seq_prob=cfg.short_seq_prob,
        )
        dl = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            collate_fn=dataset.collate_fn,
            drop_last=cfg.get("drop_last", False),
            shuffle=cfg.shuffle,
            num_workers=cfg.get("num_workers", 0),
        )
        return dl

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        pass
class BERTLMModel(ModelPT):
    """
    BERT language model pretraining.
    """

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        return self.bert_model.input_types

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        output_types_dict = {"mlm_log_probs": self.mlm_classifier.output_types["log_probs"]}
        if not self.only_mlm_loss:
            output_types_dict["nsp_logits"] = self.nsp_classifier.output_types["logits"]
        return output_types_dict

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        vocab_file = None
        config_dict = None
        config_file = None

        if cfg.tokenizer is not None:
            if cfg.tokenizer.get('tokenizer_name') and cfg.tokenizer.get('tokenizer_model'):
                self._setup_tokenizer(cfg.tokenizer)
            # Note: nested get() is used here; a flat cfg.get('tokenizer.vocab_file')
            # would not traverse the config tree.
            if cfg.tokenizer.get('vocab_file'):
                vocab_file = self.register_artifact('tokenizer.vocab_file', cfg.tokenizer.vocab_file)
        else:
            self.tokenizer = None

        super().__init__(cfg=cfg, trainer=trainer)

        if cfg.language_model.get('config'):
            config_dict = OmegaConf.to_container(cfg.language_model.config)
        if cfg.language_model.get('config_file'):
            config_file = self.register_artifact('language_model.config_file', cfg.language_model.config_file)

        self.bert_model = get_lm_model(
            config_file=config_file,
            config_dict=config_dict,
            vocab_file=vocab_file,
            trainer=trainer,
            cfg=cfg,
        )

        self.hidden_size = self.bert_model.config.hidden_size
        self.vocab_size = self.bert_model.config.vocab_size
        self.only_mlm_loss = cfg.only_mlm_loss

        self.mlm_classifier = BertPretrainingTokenClassifier(
            hidden_size=self.hidden_size,
            num_classes=self.vocab_size,
            num_layers=cfg.num_tok_classification_layers,
            activation="gelu",
            log_softmax=True,
            use_transformer_init=True,
        )
        self.mlm_loss = SmoothedCrossEntropyLoss()

        if not self.only_mlm_loss:
            self.nsp_classifier = SequenceClassifier(
                hidden_size=self.hidden_size,
                num_classes=2,
                num_layers=cfg.num_seq_classification_layers,
                log_softmax=False,
                activation="tanh",
                use_transformer_init=True,
            )
            self.nsp_loss = CrossEntropyLoss()
            self.agg_loss = AggregatorLoss(num_inputs=2)

        # Tie the weights of the MLM softmax layer and the embedding layer of the encoder.
        if (
            self.mlm_classifier.mlp.last_linear_layer.weight.shape
            != self.bert_model.embeddings.word_embeddings.weight.shape
        ):
            raise ValueError("Final classification layer does not match embedding layer.")
        self.mlm_classifier.mlp.last_linear_layer.weight = self.bert_model.embeddings.word_embeddings.weight

        # Set up metric tracking.
        self.validation_perplexity = Perplexity(compute_on_step=False)

        self.setup_optimization(cfg.optim)

    @typecheck()
    def forward(self, input_ids, attention_mask, token_type_ids):
        """
        No special modification required for Lightning; define it as you normally
        would for an `nn.Module` in vanilla PyTorch.
        """
        hidden_states = self.bert_model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        if isinstance(hidden_states, tuple):
            hidden_states = hidden_states[0]
        mlm_log_probs = self.mlm_classifier(hidden_states=hidden_states)
        if self.only_mlm_loss:
            return (mlm_log_probs,)
        nsp_logits = self.nsp_classifier(hidden_states=hidden_states)
        return mlm_log_probs, nsp_logits

    def _compute_losses(self, mlm_log_probs, nsp_logits, output_ids, output_mask, labels):
        mlm_loss = self.mlm_loss(log_probs=mlm_log_probs, labels=output_ids, output_mask=output_mask)
        if self.only_mlm_loss:
            loss, nsp_loss = mlm_loss, None
        else:
            nsp_loss = self.nsp_loss(logits=nsp_logits, labels=labels)
            loss = self.agg_loss(loss_1=mlm_loss, loss_2=nsp_loss)
        return mlm_loss, nsp_loss, loss

    def _parse_forward_outputs(self, forward_outputs):
        if self.only_mlm_loss:
            return forward_outputs[0], None
        else:
            return forward_outputs

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the
        training dataloader passed in as `batch`.
        """
        input_ids, input_type_ids, input_mask, output_ids, output_mask, labels = batch
        forward_outputs = self.forward(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
        )
        mlm_log_probs, nsp_logits = self._parse_forward_outputs(forward_outputs)
        _, _, loss = self._compute_losses(mlm_log_probs, nsp_logits, output_ids, output_mask, labels)
        lr = self._optimizer.param_groups[0]['lr']
        self.log('train_loss', loss)
        self.log('lr', lr, prog_bar=True)
        return {"loss": loss, "lr": lr}

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the
        validation dataloader passed in as `batch`.
        """
        input_ids, input_type_ids, input_mask, output_ids, output_mask, labels = batch
        forward_outputs = self.forward(
            input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask
        )
        mlm_log_probs, nsp_logits = self._parse_forward_outputs(forward_outputs)
        _, _, loss = self._compute_losses(mlm_log_probs, nsp_logits, output_ids, output_mask, labels)
        self.validation_perplexity(logits=mlm_log_probs)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        """Called at the end of validation to aggregate outputs.

        Args:
            outputs (list): The individual outputs of each validation step.
        """
        if outputs:
            avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
            perplexity = self.validation_perplexity.compute()
            logging.info(f"evaluation perplexity {perplexity.cpu().item()}")
            self.log('val_loss', avg_loss)

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = (
            self._setup_preprocessed_dataloader(train_data_config)
            if self.tokenizer is None
            else self._setup_dataloader(train_data_config)
        )

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = (
            self._setup_preprocessed_dataloader(val_data_config)
            if self.tokenizer is None
            else self._setup_dataloader(val_data_config)
        )

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        pass

    def _setup_preprocessed_dataloader(self, cfg: Optional[DictConfig]):
        dataset = cfg.data_file
        max_predictions_per_seq = cfg.max_predictions_per_seq
        batch_size = cfg.batch_size

        if os.path.isdir(dataset):
            files = [
                os.path.join(dataset, f)
                for f in os.listdir(dataset)
                if os.path.isfile(os.path.join(dataset, f))
            ]
        else:
            files = [dataset]
        files.sort()

        dl = BertPretrainingPreprocessedDataloader(
            data_files=files,
            max_predictions_per_seq=max_predictions_per_seq,
            batch_size=batch_size,
        )
        return dl

    def _setup_tokenizer(self, cfg: DictConfig):
        tokenizer = get_tokenizer(
            tokenizer_name=cfg.tokenizer_name,
            tokenizer_model=cfg.tokenizer_model,
            special_tokens=OmegaConf.to_container(cfg.special_tokens) if cfg.special_tokens else None,
            vocab_file=cfg.vocab_file,
        )
        self.tokenizer = tokenizer

    def _setup_dataloader(self, cfg: DictConfig):
        dataset = BertPretrainingDataset(
            tokenizer=self.tokenizer,
            data_file=cfg.data_file,
            max_seq_length=cfg.max_seq_length,
            mask_prob=cfg.mask_prob,
            short_seq_prob=cfg.short_seq_prob,
        )
        dl = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            collate_fn=dataset.collate_fn,
            drop_last=cfg.get("drop_last", False),
            shuffle=cfg.shuffle,
            num_workers=cfg.get("num_workers", 0),
        )
        return dl

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        This method returns a list of pre-trained models which can be instantiated
        directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        result = []
        result.append(
            PretrainedModelInfo(
                pretrained_model_name="bertbaseuncased",
                location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/bertbaseuncased/versions/1.0.0rc1/files/bertbaseuncased.nemo",
                description="The model was trained on EN Wikipedia and BookCorpus with a sequence length of 512.",
            )
        )
        result.append(
            PretrainedModelInfo(
                pretrained_model_name="bertlargeuncased",
                location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/bertlargeuncased/versions/1.0.0rc1/files/bertlargeuncased.nemo",
                description="The model was trained on EN Wikipedia and BookCorpus with a sequence length of 512.",
            )
        )
        return result
class TransformerLMModel(ModelPT):
    """
    Left-to-right Transformer language model.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # shared params for dataset and data loaders
        self.dataset_cfg = cfg.dataset

        self.tokenizer = get_tokenizer(
            tokenizer_name=cfg.language_model.tokenizer,
            vocab_file=cfg.language_model.vocab_file,
            special_tokens=cfg.language_model.special_tokens,
        )

        # make vocabulary size divisible by 8 for fast fp16 training
        vocab_size = 8 * math.ceil(self.tokenizer.vocab_size / 8)

        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        self.embedding_layer = TransformerEmbedding(
            vocab_size=vocab_size,
            hidden_size=cfg.language_model.hidden_size,
            max_sequence_length=cfg.language_model.max_seq_length,
            embedding_dropout=cfg.language_model.get("embedding_dropout", 0.0),
            learn_positional_encodings=False,
        )
        self.encoder = TransformerEncoder(
            num_layers=cfg.language_model.num_layers,
            hidden_size=cfg.language_model.hidden_size,
            mask_future=True,
            num_attention_heads=cfg.language_model.num_attn_heads,
            inner_size=cfg.language_model.inner_size,
            ffn_dropout=cfg.language_model.get("ffn_dropout", 0.0),
            hidden_act=cfg.language_model.get("inner_activation", "relu"),
            attn_score_dropout=cfg.language_model.get("attn_score_dropout", 0.0),
            attn_layer_dropout=cfg.language_model.get("attn_layer_dropout", 0.0),
        )
        self.log_softmax = TokenClassifier(
            hidden_size=cfg.language_model.hidden_size,
            num_classes=vocab_size,
            log_softmax=True,
        )

        std_init_range = 1 / math.sqrt(cfg.language_model.hidden_size)
        self.apply(lambda module: transformer_weights_init(module, std_init_range))

        # tie weights of embedding and softmax matrices
        self.log_softmax.mlp.layer0.weight = self.embedding_layer.token_embedding.weight

        self.training_loss = SmoothedCrossEntropyLoss(pad_id=self.tokenizer.pad_id)
        self.validation_loss = SmoothedCrossEntropyLoss(
            pad_id=self.tokenizer.pad_id,
            predict_last_k=self.dataset_cfg.get("predict_last_k", 0),
        )

        self.training_perplexity = Perplexity(dist_sync_on_step=True)
        self.validation_perplexity = Perplexity(compute_on_step=False)

        # Optimizer setup needs to happen after all model weights are ready
        self.setup_optimization(cfg.optim)

    @typecheck()
    def forward(self, input_ids, attention_mask):
        """
        No special modification required for Lightning, define it as you normally
        would in the `nn.Module` in vanilla PyTorch.
        """
        token_embeddings = self.embedding_layer(input_ids)
        hidden_states = self.encoder(token_embeddings, attention_mask)
        log_probs = self.log_softmax(hidden_states=hidden_states)
        return log_probs

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the
        training dataloader passed in as `batch`.
        """
        # forward pass
        input_ids, input_mask, labels = batch
        log_probs = self(input_ids=input_ids, attention_mask=input_mask)

        train_loss = self.training_loss(log_probs=log_probs, labels=labels)
        training_perplexity = self.training_perplexity(logits=log_probs)

        tensorboard_logs = {
            "train_loss": train_loss,
            "lr": self._optimizer.param_groups[0]["lr"],
            "train_ppl": training_perplexity,
        }
        return {"loss": train_loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the
        validation dataloader passed in as `batch`.
        """
        input_ids, input_mask, labels = batch
        log_probs = self(input_ids=input_ids, attention_mask=input_mask)

        val_loss = self.validation_loss(log_probs=log_probs, labels=labels)
        self.validation_perplexity(logits=log_probs)

        tensorboard_logs = {"val_loss": val_loss}
        return {"val_loss": val_loss, "log": tensorboard_logs}

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.

        :param outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        validation_perplexity = self.validation_perplexity.compute()
        tensorboard_logs = {"val_loss": avg_loss, "val_ppl": validation_perplexity}
        return {"val_loss": avg_loss, "log": tensorboard_logs}

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config,
            predict_last_k=self.dataset_cfg.get("predict_last_k", 0),
        )

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config,
            predict_last_k=self.dataset_cfg.get("predict_last_k", 0),
        )

    def _setup_dataloader_from_config(self, cfg: DictConfig, predict_last_k=0):
        dataset = L2RLanguageModelingDataset(
            tokenizer=self.tokenizer,
            dataset=cfg.file_name,
            max_seq_length=self.dataset_cfg.max_seq_length,
            batch_step=predict_last_k,
        )
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg.batch_size,
            shuffle=cfg.shuffle,
            num_workers=self.dataset_cfg.get("num_workers", 2),
            pin_memory=self.dataset_cfg.get("pin_memory", False),
            drop_last=self.dataset_cfg.get("drop_last", False),
        )

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        pass