class TransformerLMModel(ModelPT):
    """
    Left-to-right Transformer language model.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_gpus

        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)

        # Instantiates the tokenizer and registers it to be saved with the NeMo model archive.
        # After this call, there will be a self.tokenizer which can convert between tokens and token_ids.
        self.setup_tokenizer(
            tokenizer_name=cfg.tokenizer.get("tokenizer_name", "yttm"),
            tokenizer_model=cfg.tokenizer.get("tokenizer_model", None),
            vocab_file=cfg.tokenizer.get("vocab_file", None),
            bpe_dropout=cfg.tokenizer.get("bpe_dropout", 0.0),
            special_tokens=cfg.tokenizer.get("special_tokens", {}),
        )

        # init superclass
        super().__init__(cfg=cfg, trainer=trainer)

        # make vocabulary size divisible by 8 for fast fp16 training
        vocab_size = 8 * math.ceil(self.tokenizer.vocab_size / 8)

        # encoder from NeMo, Megatron-LM, or HuggingFace
        encoder_cfg_dict = OmegaConf.to_container(cfg.get('encoder'))
        encoder_cfg_dict['vocab_size'] = vocab_size
        library = encoder_cfg_dict.pop('library', 'nemo')
        model_name = encoder_cfg_dict.pop('model_name', None)
        pretrained = encoder_cfg_dict.pop('pretrained', False)
        self.encoder = get_transformer(
            library=library,
            model_name=model_name,
            pretrained=pretrained,
            config_dict=encoder_cfg_dict,
            encoder=True,
            pre_ln_final_layer_norm=encoder_cfg_dict.get(
                'pre_ln_final_layer_norm', encoder_cfg_dict.get('pre_ln', True)
            ),
        )

        self.log_softmax = TokenClassifier(
            hidden_size=self.encoder.hidden_size,
            num_classes=vocab_size,
            activation=cfg.head.activation,
            log_softmax=cfg.head.log_softmax,
            dropout=cfg.head.dropout,
            use_transformer_init=cfg.head.use_transformer_init,
        )

        # tie weights of embedding and softmax matrices
        self.log_softmax.mlp.layer0.weight = self.encoder.embedding.token_embedding.weight

        std_init_range = 1 / self.encoder.hidden_size ** 0.5

        # initialize weights if not using a pretrained encoder
        if not self._cfg.encoder.get('pretrained', False):
            self.encoder.apply(lambda module: transformer_weights_init(module, std_init_range))
        self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range))

        self.loss_fn = SmoothedCrossEntropyLoss(
            pad_id=self.tokenizer.pad_id, label_smoothing=cfg.label_smoothing
        )
        self.eval_loss_fn = SmoothedCrossEntropyLoss(pad_id=self.tokenizer.pad_id)

        self.eval_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True)
        self.eval_ppl = SequencePerplexity()

    @typecheck()
    def forward(self, input_ids, attention_mask):
        """
        No special modification required for Lightning, define it as you normally would
        in the `nn.Module` in vanilla PyTorch.
        """
        hidden_states = self.encoder(input_ids=input_ids, encoder_mask=attention_mask)
        log_probs = self.log_softmax(hidden_states=hidden_states)
        return log_probs

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        # forward pass
        for i in range(len(batch)):
            if batch[i].ndim == 3:
                # Dataset returns already batched data and the first dimension of size 1
                # added by DataLoader is excess.
                batch[i] = batch[i].squeeze(dim=0)
        ids, mask = batch
        input_ids, labels = ids[:, :-1], ids[:, 1:]
        input_mask = mask[:, :-1]
        log_probs = self(input_ids=input_ids, attention_mask=input_mask)
        train_loss = self.loss_fn(log_probs=log_probs, labels=labels)
        tensorboard_logs = {
            "train_loss": train_loss,
            "lr": self._optimizer.param_groups[0]["lr"],
        }
        return {"loss": train_loss, "log": tensorboard_logs}

    def eval_step(self, batch, batch_idx):
        for i in range(len(batch)):
            if batch[i].ndim == 3:
                # Dataset returns already batched data and the first dimension of size 1
                # added by DataLoader is excess.
                batch[i] = batch[i].squeeze(dim=0)
        ids, mask = batch
        input_ids, labels = ids[:, :-1], ids[:, 1:]
        input_mask, output_mask = mask[:, :-1], mask[:, 1:]
        log_probs = self(input_ids=input_ids, attention_mask=input_mask)
        eval_loss = self.eval_loss_fn(log_probs=log_probs, labels=labels)
        self.eval_loss(loss=eval_loss, num_measurements=log_probs.shape[0] * log_probs.shape[1])
        self.eval_ppl(log_probs=log_probs, labels=labels, mask=output_mask)
        return {}

    def test_step(self, batch, batch_idx):
        return self.eval_step(batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        return self.eval_step(batch, batch_idx)

    def eval_epoch_end(self, outputs, mode):
        eval_loss = self.eval_loss.compute()
        eval_ppl = self.eval_ppl.compute()
        self.log(f"{mode}_loss", eval_loss, sync_dist=True)
        self.log(f"{mode}_PPL", eval_ppl, sync_dist=True)
        dataset_name = "Validation" if mode == 'val' else "Test"
        logging.info(f"\n\n\n\n{dataset_name} PPL: {np.round(eval_ppl.item(), 2)}")

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.
        :param outputs: list of individual outputs of each validation step.
        """
        self.eval_epoch_end(outputs, 'val')
        self.eval_loss.reset()
        self.eval_ppl.reset()

    def test_epoch_end(self, outputs):
        self.eval_epoch_end(outputs, 'test')

    def setup_tokenizer(
        self, tokenizer_name=None, tokenizer_model=None, vocab_file=None, bpe_dropout=0.0, special_tokens=None
    ):
        supported_tokenizers = ['yttm', 'huggingface', 'sentencepiece', 'word']
        if tokenizer_name not in supported_tokenizers:
            raise NotImplementedError(f"Currently we only support tokenizers in {supported_tokenizers}.")
        self.tokenizer = get_tokenizer(
            tokenizer_name=tokenizer_name,
            tokenizer_model=self.register_artifact("cfg.tokenizer.tokenizer_model", tokenizer_model),
            vocab_file=vocab_file,
            bpe_dropout=bpe_dropout,
            special_tokens=dict(special_tokens or {}),
            use_fast=False,
        )

    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config)

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config)

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config)

    def _setup_dataloader_from_config(self, cfg: DictConfig, predict_last_k=0):
        if cfg.get("use_tarred_dataset", False):
            if cfg.get("metadata_file") is None:
                raise FileNotFoundError(
                    "Trying to use a tarred dataset but could not find the metadata path in the config."
                )
            else:
                metadata_file = cfg.get('metadata_file')
                with open(metadata_file) as metadata_reader:
                    metadata = json.load(metadata_reader)
                if cfg.get('tar_files') is None:
                    tar_files = metadata.get('tar_files')
                    if tar_files is not None:
                        logging.info(f'Loading from tarred dataset {tar_files}')
                    else:
                        raise FileNotFoundError("Could not find tarred dataset in config or metadata.")
                else:
                    tar_files = cfg.get('tar_files')
                    if metadata.get('tar_files') is not None:
                        raise ValueError(
                            'Tar files specified in config and in metadata file. Tar files should only be specified once.'
                        )
            dataset = TarredSentenceDataset(
                text_tar_filepaths=tar_files,
                metadata_path=metadata_file,
                tokenizer=self.tokenizer,
                shuffle_n=cfg.get("tar_shuffle_n", 100),
                shard_strategy=cfg.get("shard_strategy", "scatter"),
                global_rank=self.global_rank,
                world_size=self.world_size,
            )
            return torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=1,
                num_workers=cfg.get("num_workers", 2),
                pin_memory=cfg.get("pin_memory", False),
                drop_last=cfg.get("drop_last", False),
            )
        else:
            dataset = SentenceDataset(
                tokenizer=self.tokenizer,
                dataset=cfg.file_name,
                tokens_in_batch=cfg.tokens_in_batch,
                clean=cfg.get("clean", False),
                max_seq_length=cfg.get("max_seq_length", 512),
                min_seq_length=cfg.get("min_seq_length", 1),
                cache_ids=cfg.get("cache_ids", False),
            )
            if cfg.shuffle:
                sampler = pt_data.RandomSampler(dataset)
            else:
                sampler = pt_data.SequentialSampler(dataset)
            return torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=1,
                sampler=sampler,
                num_workers=cfg.get("num_workers", 2),
                pin_memory=cfg.get("pin_memory", False),
                drop_last=cfg.get("drop_last", False),
            )

    @classmethod
    def list_available_models(cls) -> Optional[Dict[str, str]]:
        pass
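# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original source): a minimal illustration
# of how `training_step` above shifts token ids for next-token prediction.
# The tensor values below are made up for the example only.
# ---------------------------------------------------------------------------
def _example_lm_shift():
    import torch

    # Pretend batch from SentenceDataset: token ids and padding mask (batch x seq_len).
    ids = torch.tensor([[2, 17, 54, 9, 3]])  # <bos> w1 w2 w3 <eos>
    mask = torch.ones_like(ids)

    # As in training_step: the model reads every token except the last one ...
    input_ids = ids[:, :-1]  # <bos> w1 w2 w3
    # ... and is trained to predict the same sequence shifted left by one position.
    labels = ids[:, 1:]      # w1 w2 w3 <eos>
    input_mask = mask[:, :-1]
    return input_ids, labels, input_mask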
def _loss_class_test(
    rank: int,
    worldsize: int,
    loss_sum_or_avg: Optional[torch.Tensor],
    num_measurements: Optional[torch.Tensor],
    dist_sync_on_step: bool,
    take_avg_loss: bool,
    check_dist_sync_on_step: bool = True,
    check_batch: bool = True,
    atol: float = 1e-8,
):
    """
    Utility function doing the actual comparison between the Lightning metric class and the reference metric.

    Args:
        rank: rank of the current process
        worldsize: number of processes
        loss_sum_or_avg: a one dimensional float torch tensor with loss sums or means.
        num_measurements: a one dimensional integer torch tensor with the number of values on which the sums
            or means from ``loss_sum_or_avg`` were computed.
        dist_sync_on_step: bool, if true will synchronize metric state across processes at each call of the
            method :meth:`forward()`
        take_avg_loss: bool, if true the values in ``loss_sum_or_avg`` are treated as per-batch means and are
            weighted by ``num_measurements`` during aggregation; if false they are treated as per-batch sums.
        check_dist_sync_on_step: bool, if true will check if the metric is also correctly calculated per batch
            per device (and not just at the end)
        check_batch: bool, if true will check if the metric is also correctly calculated across devices for
            each batch (and not just at the end)
        atol: absolute tolerance used when comparing the metric value against the reference value.
    """
    # Instantiate lightning metric
    loss_metric = GlobalAverageLossMetric(
        compute_on_step=True, dist_sync_on_step=dist_sync_on_step, take_avg_loss=take_avg_loss
    )

    # verify loss works after being loaded from pickled state
    pickled_metric = pickle.dumps(loss_metric)
    loss_metric = pickle.loads(pickled_metric)

    for i in range(rank, NUM_BATCHES, worldsize):
        batch_result = loss_metric(loss_sum_or_avg[i], num_measurements[i])
        if loss_metric.dist_sync_on_step:
            if rank == 0:
                ddp_loss_sum_or_avg = torch.stack([loss_sum_or_avg[i + r] for r in range(worldsize)])
                ddp_num_measurements = torch.stack([num_measurements[i + r] for r in range(worldsize)])
                sk_batch_result = reference_loss_func(ddp_loss_sum_or_avg, ddp_num_measurements, take_avg_loss)
                # assert for dist_sync_on_step
                if check_dist_sync_on_step:
                    if sk_batch_result.isnan():
                        assert batch_result.isnan()
                    else:
                        assert np.allclose(
                            batch_result.numpy(), sk_batch_result, atol=atol
                        ), f"batch_result = {batch_result.numpy()}, sk_batch_result = {sk_batch_result}, i = {i}"
        else:
            ls = loss_sum_or_avg[i:i + 1]
            nm = num_measurements[i:i + 1]
            sk_batch_result = reference_loss_func(ls, nm, take_avg_loss)
            # assert for batch
            if check_batch:
                if sk_batch_result.isnan():
                    assert batch_result.isnan()
                else:
                    assert np.allclose(
                        batch_result.numpy(), sk_batch_result, atol=atol
                    ), f"batch_result = {batch_result.numpy()}, sk_batch_result = {sk_batch_result}, i = {i}"

    # check on all batches on all ranks
    result = loss_metric.compute()
    assert isinstance(result, torch.Tensor)

    sk_result = reference_loss_func(loss_sum_or_avg, num_measurements, take_avg_loss)

    # assert after aggregation
    if sk_result.isnan():
        assert result.isnan()
    else:
        assert np.allclose(
            result.numpy(), sk_result, atol=atol
        ), f"result = {result.numpy()}, sk_result = {sk_result}"
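# ---------------------------------------------------------------------------
# Hedged sketch (assumption, not the original helper): `reference_loss_func`
# used above is defined elsewhere in the test module. A plausible
# implementation, consistent with how the metric is exercised here, is a
# measurement-weighted average over all batches:
# ---------------------------------------------------------------------------
def _reference_loss_sketch(loss_sum_or_avg, num_measurements, take_avg_loss):
    import torch

    loss = loss_sum_or_avg.detach().clone()
    if take_avg_loss:
        # Per-batch means are turned back into per-batch sums before aggregating.
        loss = loss * num_measurements
    total = num_measurements.sum()
    if total == 0:
        # No measurements at all: the average is undefined, mirrored here as NaN.
        return torch.tensor(float("nan"))
    return loss.sum() / total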
class MTEncDecModel(EncDecNLPModel): """ Encoder-decoder machine translation model. """ def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None): cfg = model_utils.convert_model_config_to_dict_config(cfg) # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0 self.world_size = 1 if trainer is not None: self.world_size = trainer.num_nodes * trainer.num_gpus cfg = model_utils.maybe_update_config_version(cfg) self.src_language: str = cfg.get("src_language", None) self.tgt_language: str = cfg.get("tgt_language", None) # Instantiates tokenizers and register to be saved with NeMo Model archive # After this call, ther will be self.encoder_tokenizer and self.decoder_tokenizer # Which can convert between tokens and token_ids for SRC and TGT languages correspondingly. self.setup_enc_dec_tokenizers( encoder_tokenizer_name=cfg.encoder_tokenizer.tokenizer_name, encoder_tokenizer_model=cfg.encoder_tokenizer.tokenizer_model, encoder_bpe_dropout=cfg.encoder_tokenizer.get('bpe_dropout', 0.0), decoder_tokenizer_name=cfg.decoder_tokenizer.tokenizer_name, decoder_tokenizer_model=cfg.decoder_tokenizer.tokenizer_model, decoder_bpe_dropout=cfg.decoder_tokenizer.get('bpe_dropout', 0.0), ) # After this call, the model will have self.source_processor and self.target_processor objects self.setup_pre_and_post_processing_utils(source_lang=self.src_language, target_lang=self.tgt_language) # TODO: Why is this base constructor call so late in the game? super().__init__(cfg=cfg, trainer=trainer) # TODO: use get_encoder function with support for HF and Megatron self.encoder = TransformerEncoderNM( vocab_size=self.encoder_vocab_size, hidden_size=cfg.encoder.hidden_size, num_layers=cfg.encoder.num_layers, inner_size=cfg.encoder.inner_size, max_sequence_length=cfg.encoder.max_sequence_length if hasattr( cfg.encoder, 'max_sequence_length') else 512, embedding_dropout=cfg.encoder.embedding_dropout if hasattr( cfg.encoder, 'embedding_dropout') else 0.0, learn_positional_encodings=cfg.encoder.learn_positional_encodings if hasattr(cfg.encoder, 'learn_positional_encodings') else False, num_attention_heads=cfg.encoder.num_attention_heads, ffn_dropout=cfg.encoder.ffn_dropout, attn_score_dropout=cfg.encoder.attn_score_dropout, attn_layer_dropout=cfg.encoder.attn_layer_dropout, hidden_act=cfg.encoder.hidden_act, mask_future=cfg.encoder.mask_future, pre_ln=cfg.encoder.pre_ln, ) # TODO: user get_decoder function with support for HF and Megatron self.decoder = TransformerDecoderNM( vocab_size=self.decoder_vocab_size, hidden_size=cfg.decoder.hidden_size, num_layers=cfg.decoder.num_layers, inner_size=cfg.decoder.inner_size, max_sequence_length=cfg.decoder.max_sequence_length if hasattr( cfg.decoder, 'max_sequence_length') else 512, embedding_dropout=cfg.decoder.embedding_dropout if hasattr( cfg.decoder, 'embedding_dropout') else 0.0, learn_positional_encodings=cfg.decoder.learn_positional_encodings if hasattr(cfg.decoder, 'learn_positional_encodings') else False, num_attention_heads=cfg.decoder.num_attention_heads, ffn_dropout=cfg.decoder.ffn_dropout, attn_score_dropout=cfg.decoder.attn_score_dropout, attn_layer_dropout=cfg.decoder.attn_layer_dropout, hidden_act=cfg.decoder.hidden_act, pre_ln=cfg.decoder.pre_ln, ) self.log_softmax = TokenClassifier( hidden_size=self.decoder.hidden_size, num_classes=self.decoder_vocab_size, activation=cfg.head.activation, log_softmax=cfg.head.log_softmax, dropout=cfg.head.dropout, 
use_transformer_init=cfg.head.use_transformer_init, ) self.beam_search = BeamSearchSequenceGenerator( embedding=self.decoder.embedding, decoder=self.decoder.decoder, log_softmax=self.log_softmax, max_sequence_length=self.decoder.max_sequence_length, beam_size=cfg.beam_size, bos=self.decoder_tokenizer.bos_id, pad=self.decoder_tokenizer.pad_id, eos=self.decoder_tokenizer.eos_id, len_pen=cfg.len_pen, max_delta_length=cfg.max_generation_delta, ) # tie weights of embedding and softmax matrices self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight # TODO: encoder and decoder with different hidden size? std_init_range = 1 / self.encoder.hidden_size**0.5 self.apply( lambda module: transformer_weights_init(module, std_init_range)) self.loss_fn = SmoothedCrossEntropyLoss( pad_id=self.decoder_tokenizer.pad_id, label_smoothing=cfg.label_smoothing) self.eval_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True) def filter_predicted_ids(self, ids): ids[ids >= self.decoder_tokenizer.vocab_size] = self.decoder_tokenizer.unk_id return ids @typecheck() def forward(self, src, src_mask, tgt, tgt_mask): src_hiddens = self.encoder(src, src_mask) tgt_hiddens = self.decoder(tgt, tgt_mask, src_hiddens, src_mask) log_probs = self.log_softmax(hidden_states=tgt_hiddens) return log_probs def training_step(self, batch, batch_idx): """ Lightning calls this inside the training loop with the data from the training dataloader passed in as `batch`. """ # forward pass for i in range(len(batch)): if batch[i].ndim == 3: # Dataset returns already batched data and the first dimension of size 1 added by DataLoader # is excess. batch[i] = batch[i].squeeze(dim=0) src_ids, src_mask, tgt_ids, tgt_mask, labels = batch log_probs = self(src_ids, src_mask, tgt_ids, tgt_mask) train_loss = self.loss_fn(log_probs=log_probs, labels=labels) tensorboard_logs = { 'train_loss': train_loss, 'lr': self._optimizer.param_groups[0]['lr'], } return {'loss': train_loss, 'log': tensorboard_logs} def eval_step(self, batch, batch_idx, mode): for i in range(len(batch)): if batch[i].ndim == 3: # Dataset returns already batched data and the first dimension of size 1 added by DataLoader # is excess. 
batch[i] = batch[i].squeeze(dim=0) src_ids, src_mask, tgt_ids, tgt_mask, labels = batch log_probs = self(src_ids, src_mask, tgt_ids, tgt_mask) # this will run encoder twice -- TODO: potentially fix _, translations = self.batch_translate(src=src_ids, src_mask=src_mask) eval_loss = self.loss_fn(log_probs=log_probs, labels=labels) self.eval_loss(loss=eval_loss, num_measurements=log_probs.shape[0] * log_probs.shape[1]) np_tgt = tgt_ids.cpu().numpy() ground_truths = [ self.decoder_tokenizer.ids_to_text(tgt) for tgt in np_tgt ] ground_truths = [ self.target_processor.detokenize(tgt.split(' ')) for tgt in ground_truths ] num_non_pad_tokens = np.not_equal( np_tgt, self.decoder_tokenizer.pad_id).sum().item() return { 'translations': translations, 'ground_truths': ground_truths, 'num_non_pad_tokens': num_non_pad_tokens, } def test_step(self, batch, batch_idx): return self.eval_step(batch, batch_idx, 'test') @rank_zero_only def log_param_stats(self): for name, p in self.named_parameters(): if p.requires_grad: self.trainer.logger.experiment.add_histogram( name + '_hist', p, global_step=self.global_step) self.trainer.logger.experiment.add_scalars( name, { 'mean': p.mean(), 'stddev': p.std(), 'max': p.max(), 'min': p.min() }, global_step=self.global_step, ) def validation_step(self, batch, batch_idx): """ Lightning calls this inside the validation loop with the data from the validation dataloader passed in as `batch`. """ return self.eval_step(batch, batch_idx, 'val') def eval_epoch_end(self, outputs, mode): eval_loss = self.eval_loss.compute() translations = list( itertools.chain(*[x['translations'] for x in outputs])) ground_truths = list( itertools.chain(*[x['ground_truths'] for x in outputs])) assert len(translations) == len(ground_truths) if self.tgt_language in ['ja']: sacre_bleu = corpus_bleu(translations, [ground_truths], tokenize="ja-mecab") elif self.tgt_language in ['zh']: sacre_bleu = corpus_bleu(translations, [ground_truths], tokenize="zh") else: sacre_bleu = corpus_bleu(translations, [ground_truths], tokenize="13a") dataset_name = "Validation" if mode == 'val' else "Test" logging.info(f"\n\n\n\n{dataset_name} set size: {len(translations)}") logging.info(f"{dataset_name} Sacre BLEU = {sacre_bleu.score}") logging.info(f"{dataset_name} TRANSLATION EXAMPLES:".upper()) for i in range(0, 3): ind = random.randint(0, len(translations) - 1) logging.info(" " + '\u0332'.join(f"EXAMPLE {i}:")) logging.info(f" Prediction: {translations[ind]}") logging.info(f" Ground Truth: {ground_truths[ind]}") ans = { f"{mode}_loss": eval_loss, f"{mode}_sacreBLEU": sacre_bleu.score } ans['log'] = dict(ans) return ans def validation_epoch_end(self, outputs): """ Called at the end of validation to aggregate outputs. :param outputs: list of individual outputs of each validation step. 
""" self.log_dict(self.eval_epoch_end(outputs, 'val')) def test_epoch_end(self, outputs): return self.eval_epoch_end(outputs, 'test') def setup_training_data(self, train_data_config: Optional[DictConfig]): self._train_dl = self._setup_dataloader_from_config( cfg=train_data_config) def setup_validation_data(self, val_data_config: Optional[DictConfig]): self._validation_dl = self._setup_dataloader_from_config( cfg=val_data_config) def setup_test_data(self, test_data_config: Optional[DictConfig]): self._test_dl = self._setup_dataloader_from_config( cfg=test_data_config) def _setup_dataloader_from_config(self, cfg: DictConfig): if cfg.get("load_from_cached_dataset", False): logging.info('Loading from cached dataset %s' % (cfg.src_file_name)) if cfg.src_file_name != cfg.tgt_file_name: raise ValueError( "src must be equal to target for cached dataset") dataset = pickle.load(open(cfg.src_file_name, 'rb')) dataset.reverse_lang_direction = cfg.get("reverse_lang_direction", False) elif cfg.get("use_tarred_dataset", False): if cfg.get('tar_files') is None: raise FileNotFoundError("Could not find tarred dataset.") logging.info(f'Loading from tarred dataset {cfg.get("tar_files")}') if cfg.get("metadata_file", None) is None: raise FileNotFoundError( "Could not find metadata path in config") dataset = TarredTranslationDataset( text_tar_filepaths=cfg.tar_files, metadata_path=cfg.metadata_file, encoder_tokenizer=self.encoder_tokenizer, decoder_tokenizer=self.decoder_tokenizer, shuffle_n=cfg.get("tar_shuffle_n", 100), shard_strategy=cfg.get("shard_strategy", "scatter"), global_rank=self.global_rank, world_size=self.world_size, reverse_lang_direction=cfg.get("reverse_lang_direction", False), ) return torch.utils.data.DataLoader( dataset=dataset, batch_size=1, num_workers=cfg.get("num_workers", 2), pin_memory=cfg.get("pin_memory", False), drop_last=cfg.get("drop_last", False), ) else: dataset = TranslationDataset( dataset_src=str(Path(cfg.src_file_name).expanduser()), dataset_tgt=str(Path(cfg.tgt_file_name).expanduser()), tokens_in_batch=cfg.tokens_in_batch, clean=cfg.get("clean", False), max_seq_length=cfg.get("max_seq_length", 512), min_seq_length=cfg.get("min_seq_length", 1), max_seq_length_diff=cfg.get("max_seq_length_diff", 512), max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512), cache_ids=cfg.get("cache_ids", False), cache_data_per_node=cfg.get("cache_data_per_node", False), use_cache=cfg.get("use_cache", False), reverse_lang_direction=cfg.get("reverse_lang_direction", False), ) dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer) if cfg.shuffle: sampler = pt_data.RandomSampler(dataset) else: sampler = pt_data.SequentialSampler(dataset) return torch.utils.data.DataLoader( dataset=dataset, batch_size=1, sampler=sampler, num_workers=cfg.get("num_workers", 2), pin_memory=cfg.get("pin_memory", False), drop_last=cfg.get("drop_last", False), ) def setup_pre_and_post_processing_utils(self, source_lang, target_lang): """ Creates source and target processor objects for input and output pre/post-processing. 
""" self.source_processor, self.target_processor = None, None if (source_lang == 'en' and target_lang == 'ja') or (source_lang == 'ja' and target_lang == 'en'): self.source_processor = EnJaProcessor(source_lang) self.target_processor = EnJaProcessor(target_lang) else: if source_lang == 'zh': self.source_processor = ChineseProcessor() if target_lang == 'zh': self.target_processor = ChineseProcessor() if source_lang is not None and source_lang not in ['ja', 'zh']: self.source_processor = MosesProcessor(source_lang) if target_lang is not None and target_lang not in ['ja', 'zh']: self.target_processor = MosesProcessor(target_lang) @torch.no_grad() def batch_translate( self, src: torch.LongTensor, src_mask: torch.LongTensor, ): """ Translates a minibatch of inputs from source language to target language. Args: src: minibatch of inputs in the src language (batch x seq_len) src_mask: mask tensor indicating elements to be ignored (batch x seq_len) Returns: translations: a list strings containing detokenized translations inputs: a list of string containing detokenized inputs """ mode = self.training try: self.eval() src_hiddens = self.encoder(input_ids=src, encoder_mask=src_mask) beam_results = self.beam_search(encoder_hidden_states=src_hiddens, encoder_input_mask=src_mask) beam_results = self.filter_predicted_ids(beam_results) translations = [ self.decoder_tokenizer.ids_to_text(tr) for tr in beam_results.cpu().numpy() ] inputs = [ self.encoder_tokenizer.ids_to_text(inp) for inp in src.cpu().numpy() ] if self.target_processor is not None: translations = [ self.target_processor.detokenize(translation.split(' ')) for translation in translations ] if self.source_processor is not None: inputs = [ self.source_processor.detokenize(item.split(' ')) for item in inputs ] finally: self.train(mode=mode) return inputs, translations # TODO: We should drop source/target_lang arguments in favor of using self.src/tgt_language @torch.no_grad() def translate(self, text: List[str], source_lang: str = None, target_lang: str = None) -> List[str]: """ Translates list of sentences from source language to target language. Should be regular text, this method performs its own tokenization/de-tokenization Args: text: list of strings to translate source_lang: if not None, corresponding MosesTokenizer and MosesPunctNormalizer will be run target_lang: if not None, corresponding MosesDecokenizer will be run Returns: list of translated strings """ # __TODO__: This will reset both source and target processors even if you want to reset just one. if source_lang is not None or target_lang is not None: self.setup_pre_and_post_processing_utils(source_lang, target_lang) mode = self.training try: self.eval() inputs = [] for txt in text: if self.source_processor is not None: txt = self.source_processor.normalize(txt) txt = self.source_processor.tokenize(txt) ids = self.encoder_tokenizer.text_to_ids(txt) ids = [self.encoder_tokenizer.bos_id ] + ids + [self.encoder_tokenizer.eos_id] inputs.append(ids) max_len = max(len(txt) for txt in inputs) src_ids_ = np.ones( (len(inputs), max_len)) * self.encoder_tokenizer.pad_id for i, txt in enumerate(inputs): src_ids_[i][:len(txt)] = txt src_mask = torch.FloatTensor( (src_ids_ != self.encoder_tokenizer.pad_id)).to(self.device) src = torch.LongTensor(src_ids_).to(self.device) _, translations = self.batch_translate(src, src_mask) finally: self.train(mode=mode) return translations @classmethod def list_available_models(cls) -> Optional[Dict[str, str]]: pass
class MTEncDecModel(EncDecNLPModel): """ Encoder-decoder machine translation model. """ def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None): cfg = model_utils.convert_model_config_to_dict_config(cfg) # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable self.global_rank = 0 self.world_size = 1 if trainer is not None: self.global_rank = (trainer.node_rank * trainer.num_gpus) + trainer.local_rank self.world_size = trainer.num_nodes * trainer.num_gpus cfg = model_utils.maybe_update_config_version(cfg) self.setup_enc_dec_tokenizers(cfg) super().__init__(cfg=cfg, trainer=trainer) self.src_language: str = cfg.get("src_language", None) self.tgt_language: str = cfg.get("tgt_language", None) # TODO: use get_encoder function with support for HF and Megatron self.encoder = TransformerEncoderNM( vocab_size=self.encoder_vocab_size, hidden_size=cfg.encoder.hidden_size, num_layers=cfg.encoder.num_layers, inner_size=cfg.encoder.inner_size, max_sequence_length=cfg.encoder.max_sequence_length if hasattr(cfg.encoder, 'max_sequence_length') else 512, embedding_dropout=cfg.encoder.embedding_dropout if hasattr(cfg.encoder, 'embedding_dropout') else 0.0, learn_positional_encodings=cfg.encoder.learn_positional_encodings if hasattr(cfg.encoder, 'learn_positional_encodings') else False, num_attention_heads=cfg.encoder.num_attention_heads, ffn_dropout=cfg.encoder.ffn_dropout, attn_score_dropout=cfg.encoder.attn_score_dropout, attn_layer_dropout=cfg.encoder.attn_layer_dropout, hidden_act=cfg.encoder.hidden_act, mask_future=cfg.encoder.mask_future, pre_ln=cfg.encoder.pre_ln, ) # TODO: user get_decoder function with support for HF and Megatron self.decoder = TransformerDecoderNM( vocab_size=self.decoder_vocab_size, hidden_size=cfg.decoder.hidden_size, num_layers=cfg.decoder.num_layers, inner_size=cfg.decoder.inner_size, max_sequence_length=cfg.decoder.max_sequence_length if hasattr(cfg.decoder, 'max_sequence_length') else 512, embedding_dropout=cfg.decoder.embedding_dropout if hasattr(cfg.decoder, 'embedding_dropout') else 0.0, learn_positional_encodings=cfg.decoder.learn_positional_encodings if hasattr(cfg.decoder, 'learn_positional_encodings') else False, num_attention_heads=cfg.decoder.num_attention_heads, ffn_dropout=cfg.decoder.ffn_dropout, attn_score_dropout=cfg.decoder.attn_score_dropout, attn_layer_dropout=cfg.decoder.attn_layer_dropout, hidden_act=cfg.decoder.hidden_act, pre_ln=cfg.decoder.pre_ln, ) self.log_softmax = TokenClassifier( hidden_size=self.decoder.hidden_size, num_classes=self.decoder_vocab_size, activation=cfg.head.activation, log_softmax=cfg.head.log_softmax, dropout=cfg.head.dropout, use_transformer_init=cfg.head.use_transformer_init, ) self.beam_search = BeamSearchSequenceGenerator( embedding=self.decoder.embedding, decoder=self.decoder.decoder, log_softmax=self.log_softmax, max_sequence_length=self.decoder.max_sequence_length, beam_size=cfg.beam_size, bos=self.decoder_tokenizer.bos_id, pad=self.decoder_tokenizer.pad_id, eos=self.decoder_tokenizer.eos_id, len_pen=cfg.len_pen, max_delta_length=cfg.max_generation_delta, ) # tie weights of embedding and softmax matrices self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight # TODO: encoder and decoder with different hidden size? 
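        # NOTE: the transformer-style init below (std = 1/sqrt(hidden_size)) is applied after weight tying,
        # so the shared embedding/projection matrix is initialized once, with the same scale as the rest of the model.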
std_init_range = 1 / self.encoder.hidden_size ** 0.5 self.apply(lambda module: transformer_weights_init(module, std_init_range)) self.loss_fn = SmoothedCrossEntropyLoss( pad_id=self.decoder_tokenizer.pad_id, label_smoothing=cfg.label_smoothing ) self.eval_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True) def filter_predicted_ids(self, ids): ids[ids >= self.decoder_tokenizer.vocab_size] = self.decoder_tokenizer.unk_id return ids @typecheck() def forward(self, src, src_mask, tgt, tgt_mask): src_hiddens = self.encoder(src, src_mask) tgt_hiddens = self.decoder(tgt, tgt_mask, src_hiddens, src_mask) log_probs = self.log_softmax(hidden_states=tgt_hiddens) return log_probs def training_step(self, batch, batch_idx): """ Lightning calls this inside the training loop with the data from the training dataloader passed in as `batch`. """ # forward pass for i in range(len(batch)): if batch[i].ndim == 3: # Dataset returns already batched data and the first dimension of size 1 added by DataLoader # is excess. batch[i] = batch[i].squeeze(dim=0) src_ids, src_mask, tgt_ids, tgt_mask, labels = batch log_probs = self(src_ids, src_mask, tgt_ids, tgt_mask) train_loss = self.loss_fn(log_probs=log_probs, labels=labels) tensorboard_logs = { 'train_loss': train_loss, 'lr': self._optimizer.param_groups[0]['lr'], } return {'loss': train_loss, 'log': tensorboard_logs} def eval_step(self, batch, batch_idx, mode): for i in range(len(batch)): if batch[i].ndim == 3: # Dataset returns already batched data and the first dimension of size 1 added by DataLoader # is excess. batch[i] = batch[i].squeeze(dim=0) src_ids, src_mask, tgt_ids, tgt_mask, labels = batch log_probs = self(src_ids, src_mask, tgt_ids, tgt_mask) src_hiddens = self.encoder(src_ids, src_mask) beam_results = self.beam_search(encoder_hidden_states=src_hiddens, encoder_input_mask=src_mask) beam_results = self.filter_predicted_ids(beam_results) eval_loss = self.loss_fn(log_probs=log_probs, labels=labels) self.eval_loss(loss=eval_loss, num_measurements=log_probs.shape[0] * log_probs.shape[1]) translations = [self.decoder_tokenizer.ids_to_text(tr) for tr in beam_results.cpu().numpy()] np_tgt = tgt_ids.cpu().numpy() ground_truths = [self.decoder_tokenizer.ids_to_text(tgt) for tgt in np_tgt] num_non_pad_tokens = np.not_equal(np_tgt, self.decoder_tokenizer.pad_id).sum().item() return { 'translations': translations, 'ground_truths': ground_truths, 'num_non_pad_tokens': num_non_pad_tokens, } def test_step(self, batch, batch_idx): return self.eval_step(batch, batch_idx, 'test') @rank_zero_only def log_param_stats(self): for name, p in self.named_parameters(): if p.requires_grad: self.trainer.logger.experiment.add_histogram(name + '_hist', p, global_step=self.global_step) self.trainer.logger.experiment.add_scalars( name, {'mean': p.mean(), 'stddev': p.std(), 'max': p.max(), 'min': p.min()}, global_step=self.global_step, ) def validation_step(self, batch, batch_idx): """ Lightning calls this inside the validation loop with the data from the validation dataloader passed in as `batch`. """ return self.eval_step(batch, batch_idx, 'val') def eval_epoch_end(self, outputs, mode): eval_loss = self.eval_loss.compute() translations = list(itertools.chain(*[x['translations'] for x in outputs])) ground_truths = list(itertools.chain(*[x['ground_truths'] for x in outputs])) # TODO: add target language so detokenizer can be lang specific. 
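        # Detokenize with Moses using the model's target language; Japanese output is additionally
        # passed through a SentencePiece detokenizer below.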
detokenizer = MosesDetokenizer(lang=self.tgt_language) translations = [detokenizer.detokenize(sent.split()) for sent in translations] ground_truths = [detokenizer.detokenize(sent.split()) for sent in ground_truths] if self.tgt_language in ['ja']: sp_detokenizer = SentencePieceDetokenizer() translations = [sp_detokenizer.detokenize(sent.split()) for sent in translations] ground_truths = [sp_detokenizer.detokenize(sent.split()) for sent in ground_truths] assert len(translations) == len(ground_truths) if self.tgt_language in ['ja']: sacre_bleu = corpus_bleu(translations, [ground_truths], tokenize="ja-mecab") else: sacre_bleu = corpus_bleu(translations, [ground_truths], tokenize="13a") dataset_name = "Validation" if mode == 'val' else "Test" logging.info(f"\n\n\n\n{dataset_name} set size: {len(translations)}") logging.info(f"{dataset_name} Sacre BLEU = {sacre_bleu.score}") logging.info(f"{dataset_name} TRANSLATION EXAMPLES:".upper()) for i in range(0, 3): ind = random.randint(0, len(translations) - 1) logging.info(" " + '\u0332'.join(f"EXAMPLE {i}:")) logging.info(f" Prediction: {translations[ind]}") logging.info(f" Ground Truth: {ground_truths[ind]}") ans = {f"{mode}_loss": eval_loss, f"{mode}_sacreBLEU": sacre_bleu.score} ans['log'] = dict(ans) return ans def validation_epoch_end(self, outputs): """ Called at the end of validation to aggregate outputs. :param outputs: list of individual outputs of each validation step. """ self.log_dict(self.eval_epoch_end(outputs, 'val')) def test_epoch_end(self, outputs): return self.eval_epoch_end(outputs, 'test') def setup_training_data(self, train_data_config: Optional[DictConfig]): self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config) def setup_validation_data(self, val_data_config: Optional[DictConfig]): self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config) def setup_test_data(self, test_data_config: Optional[DictConfig]): self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config) def _setup_dataloader_from_config(self, cfg: DictConfig): if cfg.get("load_from_cached_dataset", False): logging.info('Loading from cached dataset %s' % (cfg.src_file_name)) if cfg.src_file_name != cfg.tgt_file_name: raise ValueError("src must be equal to target for cached dataset") dataset = pickle.load(open(cfg.src_file_name, 'rb')) dataset.reverse_lang_direction = cfg.get("reverse_lang_direction", False) elif cfg.get("load_from_tarred_dataset", False): logging.info('Loading from tarred dataset %s' % (cfg.src_file_name)) if cfg.src_file_name != cfg.tgt_file_name: raise ValueError("src must be equal to target for tarred dataset") if cfg.get("metadata_path", None) is None: raise FileNotFoundError("Could not find metadata path in config") dataset = TarredTranslationDataset( text_tar_filepaths=cfg.src_file_name, metadata_path=cfg.metadata_path, encoder_tokenizer=self.encoder_tokenizer, decoder_tokenizer=self.decoder_tokenizer, shuffle_n=cfg.get("tar_shuffle_n", 100), shard_strategy=cfg.get("shard_strategy", "scatter"), global_rank=self.global_rank, world_size=self.world_size, reverse_lang_direction=cfg.get("reverse_lang_direction", False), ) return torch.utils.data.DataLoader( dataset=dataset, batch_size=1, num_workers=cfg.get("num_workers", 2), pin_memory=cfg.get("pin_memory", False), drop_last=cfg.get("drop_last", False), ) else: dataset = TranslationDataset( dataset_src=str(Path(cfg.src_file_name).expanduser()), dataset_tgt=str(Path(cfg.tgt_file_name).expanduser()), tokens_in_batch=cfg.tokens_in_batch, 
clean=cfg.get("clean", False), max_seq_length=cfg.get("max_seq_length", 512), min_seq_length=cfg.get("min_seq_length", 1), max_seq_length_diff=cfg.get("max_seq_length_diff", 512), max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512), cache_ids=cfg.get("cache_ids", False), cache_data_per_node=cfg.get("cache_data_per_node", False), use_cache=cfg.get("use_cache", False), reverse_lang_direction=cfg.get("reverse_lang_direction", False), ) dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer) if cfg.shuffle: sampler = pt_data.RandomSampler(dataset) else: sampler = pt_data.SequentialSampler(dataset) return torch.utils.data.DataLoader( dataset=dataset, batch_size=1, sampler=sampler, num_workers=cfg.get("num_workers", 2), pin_memory=cfg.get("pin_memory", False), drop_last=cfg.get("drop_last", False), ) @torch.no_grad() def translate(self, text: List[str], source_lang: str = None, target_lang: str = None) -> List[str]: """ Translates list of sentences from source language to target language. Should be regular text, this method performs its own tokenization/de-tokenization Args: text: list of strings to translate source_lang: if not None, corresponding MosesTokenizer and MosesPunctNormalizer will be run target_lang: if not None, corresponding MosesDecokenizer will be run Returns: list of translated strings """ if source_lang is None: source_lang = self.src_language if target_lang is None: target_lang = self.tgt_language mode = self.training tokenizer = MosesTokenizer(lang=source_lang) normalizer = MosesPunctNormalizer(lang=source_lang) detokenizer = MosesDetokenizer(lang=target_lang) try: self.eval() res = [] for txt in text: if source_lang != "None": txt = normalizer.normalize(txt) txt = tokenizer.tokenize(txt, escape=False, return_str=True) ids = self.encoder_tokenizer.text_to_ids(txt) ids = [self.encoder_tokenizer.bos_id] + ids + [self.encoder_tokenizer.eos_id] src = torch.Tensor(ids).long().to(self._device).unsqueeze(0) src_mask = torch.ones_like(src) src_hiddens = self.encoder(input_ids=src, encoder_mask=src_mask) beam_results = self.beam_search(encoder_hidden_states=src_hiddens, encoder_input_mask=src_mask) beam_results = self.filter_predicted_ids(beam_results) translation_ids = beam_results.cpu()[0].numpy() translation = self.decoder_tokenizer.ids_to_text(translation_ids) translation = detokenizer.detokenize(translation.split()) if target_lang in ["ja"]: sp_detokenizer = SentencePieceDetokenizer() translation = sp_detokenizer.detokenize(translation.split()) res.append(translation) finally: self.train(mode=mode) return res @classmethod def list_available_models(cls) -> Optional[Dict[str, str]]: pass def configure_ddp(self, model: LightningModule, device_ids: List[int]) -> DistributedDataParallel: logging.info(f'overriding ddp to set find_unused_parameters to {self._cfg.find_unused_parameters}') model = LightningDistributedDataParallel( model, device_ids=device_ids, find_unused_parameters=self._cfg.find_unused_parameters ) return model def setup(self, stage): if stage == "fit": # Update PTL trainer to use our configure_ddp self._trainer.accelerator_backend.ddp_plugin.configure_ddp = self.configure_ddp