def __init__(self, w_words, s_words, direction, lang):
    # moses tokenization before LM tokenization
    # e.g., don't -> don 't, 12/3 -> 12 / 3
    processor = MosesProcessor(lang_id=lang)

    # Build input_words and labels
    input_words, labels = [], []

    # Task Prefix
    if direction == constants.INST_BACKWARD:
        input_words.append(constants.ITN_PREFIX)
    if direction == constants.INST_FORWARD:
        input_words.append(constants.TN_PREFIX)
    labels.append(constants.TASK_TAG)

    # Main Content
    for w_word, s_word in zip(w_words, s_words):
        w_word = processor.tokenize(w_word)
        if s_word not in constants.SPECIAL_WORDS:
            s_word = processor.tokenize(s_word)
        # Update input_words and labels
        if s_word == constants.SIL_WORD and direction == constants.INST_BACKWARD:
            continue
        if s_word in constants.SPECIAL_WORDS:
            input_words.append(w_word)
            labels.append(constants.SAME_TAG)
        else:
            if direction == constants.INST_BACKWARD:
                input_words.append(s_word)
            if direction == constants.INST_FORWARD:
                input_words.append(w_word)
            labels.append(constants.TRANSFORM_TAG)

    self.input_words = input_words
    self.labels = labels
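# --- Illustrative sketch (not part of the original class): a self-contained toy version of the
# --- SAME/TRANSFORM tagging scheme above, using placeholder tag and special-word strings in
# --- place of the real values defined in `constants`.
def _toy_tagger_labels(s_words, special_words=("<self>", "sil")):
    """One placeholder label per written word: SAME if it is kept verbatim, TRANSFORM otherwise."""
    return ["SAME" if s_word in special_words else "TRANSFORM" for s_word in s_words]

# Example (forward/TN direction): for w_words ["on", "$5"] with s_words ["<self>", "five dollars"],
# "$5" is verbalized while "on" is copied through unchanged:
# _toy_tagger_labels(["<self>", "five dollars"]) -> ["SAME", "TRANSFORM"]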
class MTEncDecModel(EncDecNLPModel): """ Encoder-decoder machine translation model. """ def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None): cfg = model_utils.convert_model_config_to_dict_config(cfg) # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0 self.world_size = 1 if trainer is not None: self.world_size = trainer.num_nodes * trainer.num_gpus cfg = model_utils.maybe_update_config_version(cfg) self.src_language: str = cfg.get("src_language", None) self.tgt_language: str = cfg.get("tgt_language", None) # Instantiates tokenizers and register to be saved with NeMo Model archive # After this call, ther will be self.encoder_tokenizer and self.decoder_tokenizer # Which can convert between tokens and token_ids for SRC and TGT languages correspondingly. self.setup_enc_dec_tokenizers( encoder_tokenizer_name=cfg.encoder_tokenizer.tokenizer_name, encoder_tokenizer_model=cfg.encoder_tokenizer.tokenizer_model, encoder_bpe_dropout=cfg.encoder_tokenizer.get('bpe_dropout', 0.0), decoder_tokenizer_name=cfg.decoder_tokenizer.tokenizer_name, decoder_tokenizer_model=cfg.decoder_tokenizer.tokenizer_model, decoder_bpe_dropout=cfg.decoder_tokenizer.get('bpe_dropout', 0.0), ) # After this call, the model will have self.source_processor and self.target_processor objects self.setup_pre_and_post_processing_utils(source_lang=self.src_language, target_lang=self.tgt_language) # TODO: Why is this base constructor call so late in the game? super().__init__(cfg=cfg, trainer=trainer) # TODO: use get_encoder function with support for HF and Megatron self.encoder = TransformerEncoderNM( vocab_size=self.encoder_vocab_size, hidden_size=cfg.encoder.hidden_size, num_layers=cfg.encoder.num_layers, inner_size=cfg.encoder.inner_size, max_sequence_length=cfg.encoder.max_sequence_length if hasattr( cfg.encoder, 'max_sequence_length') else 512, embedding_dropout=cfg.encoder.embedding_dropout if hasattr( cfg.encoder, 'embedding_dropout') else 0.0, learn_positional_encodings=cfg.encoder.learn_positional_encodings if hasattr(cfg.encoder, 'learn_positional_encodings') else False, num_attention_heads=cfg.encoder.num_attention_heads, ffn_dropout=cfg.encoder.ffn_dropout, attn_score_dropout=cfg.encoder.attn_score_dropout, attn_layer_dropout=cfg.encoder.attn_layer_dropout, hidden_act=cfg.encoder.hidden_act, mask_future=cfg.encoder.mask_future, pre_ln=cfg.encoder.pre_ln, ) # TODO: user get_decoder function with support for HF and Megatron self.decoder = TransformerDecoderNM( vocab_size=self.decoder_vocab_size, hidden_size=cfg.decoder.hidden_size, num_layers=cfg.decoder.num_layers, inner_size=cfg.decoder.inner_size, max_sequence_length=cfg.decoder.max_sequence_length if hasattr( cfg.decoder, 'max_sequence_length') else 512, embedding_dropout=cfg.decoder.embedding_dropout if hasattr( cfg.decoder, 'embedding_dropout') else 0.0, learn_positional_encodings=cfg.decoder.learn_positional_encodings if hasattr(cfg.decoder, 'learn_positional_encodings') else False, num_attention_heads=cfg.decoder.num_attention_heads, ffn_dropout=cfg.decoder.ffn_dropout, attn_score_dropout=cfg.decoder.attn_score_dropout, attn_layer_dropout=cfg.decoder.attn_layer_dropout, hidden_act=cfg.decoder.hidden_act, pre_ln=cfg.decoder.pre_ln, ) self.log_softmax = TokenClassifier( hidden_size=self.decoder.hidden_size, num_classes=self.decoder_vocab_size, activation=cfg.head.activation, log_softmax=cfg.head.log_softmax, dropout=cfg.head.dropout, 
use_transformer_init=cfg.head.use_transformer_init, ) self.beam_search = BeamSearchSequenceGenerator( embedding=self.decoder.embedding, decoder=self.decoder.decoder, log_softmax=self.log_softmax, max_sequence_length=self.decoder.max_sequence_length, beam_size=cfg.beam_size, bos=self.decoder_tokenizer.bos_id, pad=self.decoder_tokenizer.pad_id, eos=self.decoder_tokenizer.eos_id, len_pen=cfg.len_pen, max_delta_length=cfg.max_generation_delta, ) # tie weights of embedding and softmax matrices self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight # TODO: encoder and decoder with different hidden size? std_init_range = 1 / self.encoder.hidden_size**0.5 self.apply( lambda module: transformer_weights_init(module, std_init_range)) self.loss_fn = SmoothedCrossEntropyLoss( pad_id=self.decoder_tokenizer.pad_id, label_smoothing=cfg.label_smoothing) self.eval_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True) def filter_predicted_ids(self, ids): ids[ids >= self.decoder_tokenizer.vocab_size] = self.decoder_tokenizer.unk_id return ids @typecheck() def forward(self, src, src_mask, tgt, tgt_mask): src_hiddens = self.encoder(src, src_mask) tgt_hiddens = self.decoder(tgt, tgt_mask, src_hiddens, src_mask) log_probs = self.log_softmax(hidden_states=tgt_hiddens) return log_probs def training_step(self, batch, batch_idx): """ Lightning calls this inside the training loop with the data from the training dataloader passed in as `batch`. """ # forward pass for i in range(len(batch)): if batch[i].ndim == 3: # Dataset returns already batched data and the first dimension of size 1 added by DataLoader # is excess. batch[i] = batch[i].squeeze(dim=0) src_ids, src_mask, tgt_ids, tgt_mask, labels = batch log_probs = self(src_ids, src_mask, tgt_ids, tgt_mask) train_loss = self.loss_fn(log_probs=log_probs, labels=labels) tensorboard_logs = { 'train_loss': train_loss, 'lr': self._optimizer.param_groups[0]['lr'], } return {'loss': train_loss, 'log': tensorboard_logs} def eval_step(self, batch, batch_idx, mode): for i in range(len(batch)): if batch[i].ndim == 3: # Dataset returns already batched data and the first dimension of size 1 added by DataLoader # is excess. 
batch[i] = batch[i].squeeze(dim=0) src_ids, src_mask, tgt_ids, tgt_mask, labels = batch log_probs = self(src_ids, src_mask, tgt_ids, tgt_mask) # this will run encoder twice -- TODO: potentially fix _, translations = self.batch_translate(src=src_ids, src_mask=src_mask) eval_loss = self.loss_fn(log_probs=log_probs, labels=labels) self.eval_loss(loss=eval_loss, num_measurements=log_probs.shape[0] * log_probs.shape[1]) np_tgt = tgt_ids.cpu().numpy() ground_truths = [ self.decoder_tokenizer.ids_to_text(tgt) for tgt in np_tgt ] ground_truths = [ self.target_processor.detokenize(tgt.split(' ')) for tgt in ground_truths ] num_non_pad_tokens = np.not_equal( np_tgt, self.decoder_tokenizer.pad_id).sum().item() return { 'translations': translations, 'ground_truths': ground_truths, 'num_non_pad_tokens': num_non_pad_tokens, } def test_step(self, batch, batch_idx): return self.eval_step(batch, batch_idx, 'test') @rank_zero_only def log_param_stats(self): for name, p in self.named_parameters(): if p.requires_grad: self.trainer.logger.experiment.add_histogram( name + '_hist', p, global_step=self.global_step) self.trainer.logger.experiment.add_scalars( name, { 'mean': p.mean(), 'stddev': p.std(), 'max': p.max(), 'min': p.min() }, global_step=self.global_step, ) def validation_step(self, batch, batch_idx): """ Lightning calls this inside the validation loop with the data from the validation dataloader passed in as `batch`. """ return self.eval_step(batch, batch_idx, 'val') def eval_epoch_end(self, outputs, mode): eval_loss = self.eval_loss.compute() translations = list( itertools.chain(*[x['translations'] for x in outputs])) ground_truths = list( itertools.chain(*[x['ground_truths'] for x in outputs])) assert len(translations) == len(ground_truths) if self.tgt_language in ['ja']: sacre_bleu = corpus_bleu(translations, [ground_truths], tokenize="ja-mecab") elif self.tgt_language in ['zh']: sacre_bleu = corpus_bleu(translations, [ground_truths], tokenize="zh") else: sacre_bleu = corpus_bleu(translations, [ground_truths], tokenize="13a") dataset_name = "Validation" if mode == 'val' else "Test" logging.info(f"\n\n\n\n{dataset_name} set size: {len(translations)}") logging.info(f"{dataset_name} Sacre BLEU = {sacre_bleu.score}") logging.info(f"{dataset_name} TRANSLATION EXAMPLES:".upper()) for i in range(0, 3): ind = random.randint(0, len(translations) - 1) logging.info(" " + '\u0332'.join(f"EXAMPLE {i}:")) logging.info(f" Prediction: {translations[ind]}") logging.info(f" Ground Truth: {ground_truths[ind]}") ans = { f"{mode}_loss": eval_loss, f"{mode}_sacreBLEU": sacre_bleu.score } ans['log'] = dict(ans) return ans def validation_epoch_end(self, outputs): """ Called at the end of validation to aggregate outputs. :param outputs: list of individual outputs of each validation step. 
""" self.log_dict(self.eval_epoch_end(outputs, 'val')) def test_epoch_end(self, outputs): return self.eval_epoch_end(outputs, 'test') def setup_training_data(self, train_data_config: Optional[DictConfig]): self._train_dl = self._setup_dataloader_from_config( cfg=train_data_config) def setup_validation_data(self, val_data_config: Optional[DictConfig]): self._validation_dl = self._setup_dataloader_from_config( cfg=val_data_config) def setup_test_data(self, test_data_config: Optional[DictConfig]): self._test_dl = self._setup_dataloader_from_config( cfg=test_data_config) def _setup_dataloader_from_config(self, cfg: DictConfig): if cfg.get("load_from_cached_dataset", False): logging.info('Loading from cached dataset %s' % (cfg.src_file_name)) if cfg.src_file_name != cfg.tgt_file_name: raise ValueError( "src must be equal to target for cached dataset") dataset = pickle.load(open(cfg.src_file_name, 'rb')) dataset.reverse_lang_direction = cfg.get("reverse_lang_direction", False) elif cfg.get("use_tarred_dataset", False): if cfg.get('tar_files') is None: raise FileNotFoundError("Could not find tarred dataset.") logging.info(f'Loading from tarred dataset {cfg.get("tar_files")}') if cfg.get("metadata_file", None) is None: raise FileNotFoundError( "Could not find metadata path in config") dataset = TarredTranslationDataset( text_tar_filepaths=cfg.tar_files, metadata_path=cfg.metadata_file, encoder_tokenizer=self.encoder_tokenizer, decoder_tokenizer=self.decoder_tokenizer, shuffle_n=cfg.get("tar_shuffle_n", 100), shard_strategy=cfg.get("shard_strategy", "scatter"), global_rank=self.global_rank, world_size=self.world_size, reverse_lang_direction=cfg.get("reverse_lang_direction", False), ) return torch.utils.data.DataLoader( dataset=dataset, batch_size=1, num_workers=cfg.get("num_workers", 2), pin_memory=cfg.get("pin_memory", False), drop_last=cfg.get("drop_last", False), ) else: dataset = TranslationDataset( dataset_src=str(Path(cfg.src_file_name).expanduser()), dataset_tgt=str(Path(cfg.tgt_file_name).expanduser()), tokens_in_batch=cfg.tokens_in_batch, clean=cfg.get("clean", False), max_seq_length=cfg.get("max_seq_length", 512), min_seq_length=cfg.get("min_seq_length", 1), max_seq_length_diff=cfg.get("max_seq_length_diff", 512), max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512), cache_ids=cfg.get("cache_ids", False), cache_data_per_node=cfg.get("cache_data_per_node", False), use_cache=cfg.get("use_cache", False), reverse_lang_direction=cfg.get("reverse_lang_direction", False), ) dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer) if cfg.shuffle: sampler = pt_data.RandomSampler(dataset) else: sampler = pt_data.SequentialSampler(dataset) return torch.utils.data.DataLoader( dataset=dataset, batch_size=1, sampler=sampler, num_workers=cfg.get("num_workers", 2), pin_memory=cfg.get("pin_memory", False), drop_last=cfg.get("drop_last", False), ) def setup_pre_and_post_processing_utils(self, source_lang, target_lang): """ Creates source and target processor objects for input and output pre/post-processing. 
""" self.source_processor, self.target_processor = None, None if (source_lang == 'en' and target_lang == 'ja') or (source_lang == 'ja' and target_lang == 'en'): self.source_processor = EnJaProcessor(source_lang) self.target_processor = EnJaProcessor(target_lang) else: if source_lang == 'zh': self.source_processor = ChineseProcessor() if target_lang == 'zh': self.target_processor = ChineseProcessor() if source_lang is not None and source_lang not in ['ja', 'zh']: self.source_processor = MosesProcessor(source_lang) if target_lang is not None and target_lang not in ['ja', 'zh']: self.target_processor = MosesProcessor(target_lang) @torch.no_grad() def batch_translate( self, src: torch.LongTensor, src_mask: torch.LongTensor, ): """ Translates a minibatch of inputs from source language to target language. Args: src: minibatch of inputs in the src language (batch x seq_len) src_mask: mask tensor indicating elements to be ignored (batch x seq_len) Returns: translations: a list strings containing detokenized translations inputs: a list of string containing detokenized inputs """ mode = self.training try: self.eval() src_hiddens = self.encoder(input_ids=src, encoder_mask=src_mask) beam_results = self.beam_search(encoder_hidden_states=src_hiddens, encoder_input_mask=src_mask) beam_results = self.filter_predicted_ids(beam_results) translations = [ self.decoder_tokenizer.ids_to_text(tr) for tr in beam_results.cpu().numpy() ] inputs = [ self.encoder_tokenizer.ids_to_text(inp) for inp in src.cpu().numpy() ] if self.target_processor is not None: translations = [ self.target_processor.detokenize(translation.split(' ')) for translation in translations ] if self.source_processor is not None: inputs = [ self.source_processor.detokenize(item.split(' ')) for item in inputs ] finally: self.train(mode=mode) return inputs, translations # TODO: We should drop source/target_lang arguments in favor of using self.src/tgt_language @torch.no_grad() def translate(self, text: List[str], source_lang: str = None, target_lang: str = None) -> List[str]: """ Translates list of sentences from source language to target language. Should be regular text, this method performs its own tokenization/de-tokenization Args: text: list of strings to translate source_lang: if not None, corresponding MosesTokenizer and MosesPunctNormalizer will be run target_lang: if not None, corresponding MosesDecokenizer will be run Returns: list of translated strings """ # __TODO__: This will reset both source and target processors even if you want to reset just one. if source_lang is not None or target_lang is not None: self.setup_pre_and_post_processing_utils(source_lang, target_lang) mode = self.training try: self.eval() inputs = [] for txt in text: if self.source_processor is not None: txt = self.source_processor.normalize(txt) txt = self.source_processor.tokenize(txt) ids = self.encoder_tokenizer.text_to_ids(txt) ids = [self.encoder_tokenizer.bos_id ] + ids + [self.encoder_tokenizer.eos_id] inputs.append(ids) max_len = max(len(txt) for txt in inputs) src_ids_ = np.ones( (len(inputs), max_len)) * self.encoder_tokenizer.pad_id for i, txt in enumerate(inputs): src_ids_[i][:len(txt)] = txt src_mask = torch.FloatTensor( (src_ids_ != self.encoder_tokenizer.pad_id)).to(self.device) src = torch.LongTensor(src_ids_).to(self.device) _, translations = self.batch_translate(src, src_mask) finally: self.train(mode=mode) return translations @classmethod def list_available_models(cls) -> Optional[Dict[str, str]]: pass
class MTEncDecModel(EncDecNLPModel): """ Encoder-decoder machine translation model. """ def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None): cfg = model_utils.convert_model_config_to_dict_config(cfg) # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0 self.world_size = 1 if trainer is not None: self.world_size = trainer.num_nodes * trainer.num_gpus cfg = model_utils.maybe_update_config_version(cfg) self.src_language = cfg.get("src_language", None) self.tgt_language = cfg.get("tgt_language", None) self.multilingual = cfg.get("multilingual", False) self.multilingual_ids = [] # Instantiates tokenizers and register to be saved with NeMo Model archive # After this call, ther will be self.encoder_tokenizer and self.decoder_tokenizer # Which can convert between tokens and token_ids for SRC and TGT languages correspondingly. self.setup_enc_dec_tokenizers( encoder_tokenizer_library=cfg.encoder_tokenizer.get('library', 'yttm'), encoder_tokenizer_model=cfg.encoder_tokenizer.get('tokenizer_model'), encoder_bpe_dropout=cfg.encoder_tokenizer.get('bpe_dropout', 0.0) if cfg.encoder_tokenizer.get('bpe_dropout', 0.0) is not None else 0.0, encoder_model_name=cfg.encoder.get('model_name') if hasattr(cfg.encoder, 'model_name') else None, decoder_tokenizer_library=cfg.decoder_tokenizer.get('library', 'yttm'), decoder_tokenizer_model=cfg.decoder_tokenizer.tokenizer_model, decoder_bpe_dropout=cfg.decoder_tokenizer.get('bpe_dropout', 0.0) if cfg.decoder_tokenizer.get('bpe_dropout', 0.0) is not None else 0.0, decoder_model_name=cfg.decoder.get('model_name') if hasattr(cfg.decoder, 'model_name') else None, ) if self.multilingual: if isinstance(self.src_language, ListConfig) and isinstance(self.tgt_language, ListConfig): raise ValueError( "cfg.src_language and cfg.tgt_language cannot both be lists. We only support many-to-one or one-to-many multilingual models." ) elif isinstance(self.src_language, ListConfig): for lng in self.src_language: self.multilingual_ids.append(self.encoder_tokenizer.token_to_id("<" + lng + ">")) elif isinstance(self.tgt_language, ListConfig): for lng in self.tgt_language: self.multilingual_ids.append(self.encoder_tokenizer.token_to_id("<" + lng + ">")) else: raise ValueError( "Expect either cfg.src_language or cfg.tgt_language to be a list when multilingual=True." ) if isinstance(self.src_language, ListConfig): self.tgt_language = [self.tgt_language] * len(self.src_language) else: self.src_language = [self.src_language] * len(self.tgt_language) self.source_processor_list = [] self.target_processor_list = [] for src_lng, tgt_lng in zip(self.src_language, self.tgt_language): src_prcsr, tgt_prscr = self.setup_pre_and_post_processing_utils( source_lang=src_lng, target_lang=tgt_lng ) self.source_processor_list.append(src_prcsr) self.target_processor_list.append(tgt_prscr) else: # After this call, the model will have self.source_processor and self.target_processor objects self.setup_pre_and_post_processing_utils(source_lang=self.src_language, target_lang=self.tgt_language) self.multilingual_ids = [None] # TODO: Why is this base constructor call so late in the game? 
super().__init__(cfg=cfg, trainer=trainer) # encoder from NeMo, Megatron-LM, or HuggingFace encoder_cfg_dict = OmegaConf.to_container(cfg.get('encoder')) encoder_cfg_dict['vocab_size'] = self.encoder_vocab_size library = encoder_cfg_dict.pop('library', 'nemo') model_name = encoder_cfg_dict.pop('model_name', None) pretrained = encoder_cfg_dict.pop('pretrained', False) checkpoint_file = encoder_cfg_dict.pop('checkpoint_file', None) self.encoder = get_transformer( library=library, model_name=model_name, pretrained=pretrained, config_dict=encoder_cfg_dict, encoder=True, pre_ln_final_layer_norm=encoder_cfg_dict.get('pre_ln_final_layer_norm', False), checkpoint_file=checkpoint_file, ) # decoder from NeMo, Megatron-LM, or HuggingFace decoder_cfg_dict = OmegaConf.to_container(cfg.get('decoder')) decoder_cfg_dict['vocab_size'] = self.decoder_vocab_size library = decoder_cfg_dict.pop('library', 'nemo') model_name = decoder_cfg_dict.pop('model_name', None) pretrained = decoder_cfg_dict.pop('pretrained', False) decoder_cfg_dict['hidden_size'] = self.encoder.hidden_size self.decoder = get_transformer( library=library, model_name=model_name, pretrained=pretrained, config_dict=decoder_cfg_dict, encoder=False, pre_ln_final_layer_norm=decoder_cfg_dict.get('pre_ln_final_layer_norm', False), ) self.log_softmax = TokenClassifier( hidden_size=self.decoder.hidden_size, num_classes=self.decoder_vocab_size, activation=cfg.head.activation, log_softmax=cfg.head.log_softmax, dropout=cfg.head.dropout, use_transformer_init=cfg.head.use_transformer_init, ) self.beam_search = BeamSearchSequenceGenerator( embedding=self.decoder.embedding, decoder=self.decoder.decoder, log_softmax=self.log_softmax, max_sequence_length=self.decoder.max_sequence_length, beam_size=cfg.beam_size, bos=self.decoder_tokenizer.bos_id, pad=self.decoder_tokenizer.pad_id, eos=self.decoder_tokenizer.eos_id, len_pen=cfg.len_pen, max_delta_length=cfg.max_generation_delta, ) # tie weights of embedding and softmax matrices self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight # TODO: encoder and decoder with different hidden size? std_init_range = 1 / self.encoder.hidden_size ** 0.5 # initialize weights if not using pretrained encoder/decoder if not self._cfg.encoder.get('pretrained', False): self.encoder.apply(lambda module: transformer_weights_init(module, std_init_range)) if not self._cfg.decoder.get('pretrained', False): self.decoder.apply(lambda module: transformer_weights_init(module, std_init_range)) self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range)) self.loss_fn = SmoothedCrossEntropyLoss( pad_id=self.decoder_tokenizer.pad_id, label_smoothing=cfg.label_smoothing ) self.eval_loss_fn = NLLLoss(ignore_index=self.decoder_tokenizer.pad_id) def filter_predicted_ids(self, ids): ids[ids >= self.decoder_tokenizer.vocab_size] = self.decoder_tokenizer.unk_id return ids @typecheck() def forward(self, src, src_mask, tgt, tgt_mask): src_hiddens = self.encoder(input_ids=src, encoder_mask=src_mask) tgt_hiddens = self.decoder( input_ids=tgt, decoder_mask=tgt_mask, encoder_embeddings=src_hiddens, encoder_mask=src_mask ) log_probs = self.log_softmax(hidden_states=tgt_hiddens) return log_probs def training_step(self, batch, batch_idx): """ Lightning calls this inside the training loop with the data from the training dataloader passed in as `batch`. 
""" # forward pass for i in range(len(batch)): if batch[i].ndim == 3: # Dataset returns already batched data and the first dimension of size 1 added by DataLoader # is excess. batch[i] = batch[i].squeeze(dim=0) src_ids, src_mask, tgt_ids, tgt_mask, labels = batch log_probs = self(src_ids, src_mask, tgt_ids, tgt_mask) train_loss = self.loss_fn(log_probs=log_probs, labels=labels) tensorboard_logs = { 'train_loss': train_loss, 'lr': self._optimizer.param_groups[0]['lr'], } return {'loss': train_loss, 'log': tensorboard_logs} def eval_step(self, batch, batch_idx, mode, dataloader_idx=0): for i in range(len(batch)): if batch[i].ndim == 3: # Dataset returns already batched data and the first dimension of size 1 added by DataLoader # is excess. batch[i] = batch[i].squeeze(dim=0) if self.multilingual: self.source_processor = self.source_processor_list[dataloader_idx] self.target_processor = self.target_processor_list[dataloader_idx] src_ids, src_mask, tgt_ids, tgt_mask, labels = batch log_probs = self(src_ids, src_mask, tgt_ids, tgt_mask) eval_loss = self.eval_loss_fn(log_probs=log_probs, labels=labels) # this will run encoder twice -- TODO: potentially fix _, translations = self.batch_translate(src=src_ids, src_mask=src_mask) if dataloader_idx == 0: getattr(self, f'{mode}_loss')(loss=eval_loss, num_measurements=log_probs.shape[0] * log_probs.shape[1]) else: getattr(self, f'{mode}_loss_{dataloader_idx}')( loss=eval_loss, num_measurements=log_probs.shape[0] * log_probs.shape[1] ) np_tgt = tgt_ids.detach().cpu().numpy() ground_truths = [self.decoder_tokenizer.ids_to_text(tgt) for tgt in np_tgt] ground_truths = [self.target_processor.detokenize(tgt.split(' ')) for tgt in ground_truths] num_non_pad_tokens = np.not_equal(np_tgt, self.decoder_tokenizer.pad_id).sum().item() return { 'translations': translations, 'ground_truths': ground_truths, 'num_non_pad_tokens': num_non_pad_tokens, } def test_step(self, batch, batch_idx, dataloader_idx=0): return self.eval_step(batch, batch_idx, 'test', dataloader_idx) @rank_zero_only def log_param_stats(self): for name, p in self.named_parameters(): if p.requires_grad: self.trainer.logger.experiment.add_histogram(name + '_hist', p, global_step=self.global_step) self.trainer.logger.experiment.add_scalars( name, {'mean': p.mean(), 'stddev': p.std(), 'max': p.max(), 'min': p.min()}, global_step=self.global_step, ) def validation_step(self, batch, batch_idx, dataloader_idx=0): """ Lightning calls this inside the validation loop with the data from the validation dataloader passed in as `batch`. 
""" return self.eval_step(batch, batch_idx, 'val', dataloader_idx) def eval_epoch_end(self, outputs, mode): # if user specifies one validation dataloader, then PTL reverts to giving a list of dictionary instead of a list of list of dictionary if isinstance(outputs[0], dict): outputs = [outputs] loss_list = [] sb_score_list = [] for dataloader_idx, output in enumerate(outputs): if dataloader_idx == 0: eval_loss = getattr(self, f'{mode}_loss').compute() else: eval_loss = getattr(self, f'{mode}_loss_{dataloader_idx}').compute() translations = list(itertools.chain(*[x['translations'] for x in output])) ground_truths = list(itertools.chain(*[x['ground_truths'] for x in output])) assert len(translations) == len(ground_truths) # Gather translations and ground truths from all workers tr_and_gt = [None for _ in range(self.world_size)] # we also need to drop pairs where ground truth is an empty string dist.all_gather_object( tr_and_gt, [(t, g) for (t, g) in zip(translations, ground_truths) if g.strip() != ''] ) if self.global_rank == 0: _translations = [] _ground_truths = [] for rank in range(0, self.world_size): _translations += [t for (t, g) in tr_and_gt[rank]] _ground_truths += [g for (t, g) in tr_and_gt[rank]] if self.tgt_language in ['ja']: sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="ja-mecab") elif self.tgt_language in ['zh']: sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="zh") else: sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="13a") # because the reduction op later is average (over word_size) sb_score = sacre_bleu.score * self.world_size dataset_name = "Validation" if mode == 'val' else "Test" logging.info( f"Dataset name: {dataset_name}, Dataloader index: {dataloader_idx}, Set size: {len(translations)}" ) logging.info( f"Dataset name: {dataset_name}, Dataloader index: {dataloader_idx}, Val Loss = {eval_loss}" ) logging.info( f"Dataset name: {dataset_name}, Dataloader index: {dataloader_idx}, Sacre BLEU = {sb_score / self.world_size}" ) logging.info( f"Dataset name: {dataset_name}, Dataloader index: {dataloader_idx}, Translation Examples:" ) for i in range(0, 3): ind = random.randint(0, len(translations) - 1) logging.info(" " + '\u0332'.join(f"Example {i}:")) logging.info(f" Prediction: {translations[ind]}") logging.info(f" Ground Truth: {ground_truths[ind]}") else: sb_score = 0.0 loss_list.append(eval_loss.cpu().numpy()) sb_score_list.append(sb_score) if dataloader_idx == 0: self.log(f"{mode}_loss", eval_loss, sync_dist=True) self.log(f"{mode}_sacreBLEU", sb_score, sync_dist=True) getattr(self, f'{mode}_loss').reset() else: self.log(f"{mode}_loss_dl_index_{dataloader_idx}", eval_loss, sync_dist=True) self.log(f"{mode}_sacreBLEU_dl_index_{dataloader_idx}", sb_score, sync_dist=True) getattr(self, f'{mode}_loss_{dataloader_idx}').reset() if len(loss_list) > 1: self.log(f"{mode}_loss_avg", np.mean(loss_list), sync_dist=True) self.log(f"{mode}_sacreBLEU_avg", np.mean(sb_score_list), sync_dist=True) def validation_epoch_end(self, outputs): """ Called at the end of validation to aggregate outputs. :param outputs: list of individual outputs of each validation step. 
""" self.eval_epoch_end(outputs, 'val') def test_epoch_end(self, outputs): self.eval_epoch_end(outputs, 'test') def setup_enc_dec_tokenizers( self, encoder_tokenizer_library=None, encoder_tokenizer_model=None, encoder_bpe_dropout=0.0, encoder_model_name=None, decoder_tokenizer_library=None, decoder_tokenizer_model=None, decoder_bpe_dropout=0.0, decoder_model_name=None, ): supported_tokenizers = ['yttm', 'huggingface', 'sentencepiece', 'megatron'] if ( encoder_tokenizer_library not in supported_tokenizers or decoder_tokenizer_library not in supported_tokenizers ): raise NotImplementedError(f"Currently we only support tokenizers in {supported_tokenizers}.") self.encoder_tokenizer = get_nmt_tokenizer( library=encoder_tokenizer_library, tokenizer_model=self.register_artifact("encoder_tokenizer.tokenizer_model", encoder_tokenizer_model), bpe_dropout=encoder_bpe_dropout, model_name=encoder_model_name, vocab_file=None, special_tokens=None, use_fast=False, ) self.decoder_tokenizer = get_nmt_tokenizer( library=decoder_tokenizer_library, tokenizer_model=self.register_artifact("decoder_tokenizer.tokenizer_model", decoder_tokenizer_model), bpe_dropout=decoder_bpe_dropout, model_name=decoder_model_name, vocab_file=None, special_tokens=None, use_fast=False, ) def setup_training_data(self, train_data_config: Optional[DictConfig]): self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config) def setup_multiple_validation_data(self, val_data_config: Union[DictConfig, Dict]): self.setup_validation_data(self._cfg.get('validation_ds')) def setup_multiple_test_data(self, test_data_config: Union[DictConfig, Dict]): self.setup_test_data(self._cfg.get('test_ds')) def setup_validation_data(self, val_data_config: Optional[DictConfig]): self._validation_dl = self._setup_eval_dataloader_from_config(cfg=val_data_config) # instantiate Torchmetric for each val dataloader if self._validation_dl is not None: for dataloader_idx in range(len(self._validation_dl)): if dataloader_idx == 0: setattr( self, f'val_loss', GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True), ) else: setattr( self, f'val_loss_{dataloader_idx}', GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True), ) def setup_test_data(self, test_data_config: Optional[DictConfig]): self._test_dl = self._setup_eval_dataloader_from_config(cfg=test_data_config) # instantiate Torchmetric for each test dataloader if self._test_dl is not None: for dataloader_idx in range(len(self._test_dl)): if dataloader_idx == 0: setattr( self, f'test_loss', GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True), ) else: setattr( self, f'test_loss_{dataloader_idx}', GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True), ) def _setup_dataloader_from_config(self, cfg: DictConfig): if cfg.get("use_tarred_dataset", False): if cfg.get("metadata_file") is None: raise FileNotFoundError("Trying to use tarred data set but could not find metadata path in config.") else: if not self.multilingual: metadata_file_list = [cfg.get('metadata_file')] else: metadata_file_list = cfg.get('metadata_file') datasets = [] for idx, metadata_file in enumerate(metadata_file_list): with open(metadata_file) as metadata_reader: metadata = json.load(metadata_reader) if cfg.get('tar_files') is None: tar_files = metadata.get('tar_files') if tar_files is not None: logging.info(f'Loading from tarred dataset {tar_files}') else: raise FileNotFoundError("Could not find tarred dataset in config or metadata.") else: tar_files = cfg.get('tar_files') if 
self.multilingual: tar_files = tar_files[idx] if metadata.get('tar_files') is not None: logging.info( f'Tar file paths found in both cfg and metadata using one in cfg by default - {tar_files}' ) dataset = TarredTranslationDataset( text_tar_filepaths=tar_files, metadata_path=metadata_file, encoder_tokenizer=self.encoder_tokenizer, decoder_tokenizer=self.decoder_tokenizer, shuffle_n=cfg.get("tar_shuffle_n", 100), shard_strategy=cfg.get("shard_strategy", "scatter"), global_rank=self.global_rank, world_size=self.world_size, reverse_lang_direction=cfg.get("reverse_lang_direction", False), prepend_id=self.multilingual_ids[idx], ) datasets.append(dataset) if self.multilingual: dataset = ConcatDataset( datasets=datasets, sampling_technique=cfg.get('concat_sampling_technique'), sampling_temperature=cfg.get('concat_sampling_temperature'), sampling_probabilities=cfg.get('concat_sampling_probabilities'), global_rank=self.global_rank, world_size=self.world_size, ) else: dataset = datasets[0] return torch.utils.data.DataLoader( dataset=dataset, batch_size=1, num_workers=cfg.get("num_workers", 2), pin_memory=cfg.get("pin_memory", False), drop_last=cfg.get("drop_last", False), ) else: if not self.multilingual: src_file_list = [cfg.src_file_name] tgt_file_list = [cfg.tgt_file_name] else: src_file_list = cfg.src_file_name tgt_file_list = cfg.tgt_file_name if len(src_file_list) != len(tgt_file_list): raise ValueError( 'The same number of filepaths must be passed in for source and target while training multilingual.' ) datasets = [] for idx, src_file in enumerate(src_file_list): dataset = TranslationDataset( dataset_src=str(Path(src_file).expanduser()), dataset_tgt=str(Path(tgt_file_list[idx]).expanduser()), tokens_in_batch=cfg.tokens_in_batch, clean=cfg.get("clean", False), max_seq_length=cfg.get("max_seq_length", 512), min_seq_length=cfg.get("min_seq_length", 1), max_seq_length_diff=cfg.get("max_seq_length_diff", 512), max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512), cache_ids=cfg.get("cache_ids", False), cache_data_per_node=cfg.get("cache_data_per_node", False), use_cache=cfg.get("use_cache", False), reverse_lang_direction=cfg.get("reverse_lang_direction", False), prepend_id=self.multilingual_ids[idx], ) dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer) datasets.append(dataset) if self.multilingual: dataset = ConcatDataset( datasets=datasets, shuffle=cfg.get('shuffle'), sampling_technique=cfg.get('concat_sampling_technique'), sampling_temperature=cfg.get('concat_sampling_temperature'), sampling_probabilities=cfg.get('concat_sampling_probabilities'), global_rank=self.global_rank, world_size=self.world_size, ) return torch.utils.data.DataLoader( dataset=dataset, batch_size=1, num_workers=cfg.get("num_workers", 2), pin_memory=cfg.get("pin_memory", False), drop_last=cfg.get("drop_last", False), ) else: dataset = datasets[0] if cfg.shuffle: sampler = pt_data.RandomSampler(dataset) else: sampler = pt_data.SequentialSampler(dataset) return torch.utils.data.DataLoader( dataset=dataset, batch_size=1, sampler=sampler, num_workers=cfg.get("num_workers", 2), pin_memory=cfg.get("pin_memory", False), drop_last=cfg.get("drop_last", False), ) def replace_beam_with_sampling(self, topk=500): self.beam_search = TopKSequenceGenerator( embedding=self.decoder.embedding, decoder=self.decoder.decoder, log_softmax=self.log_softmax, max_sequence_length=self.beam_search.max_seq_length, beam_size=topk, bos=self.decoder_tokenizer.bos_id, pad=self.decoder_tokenizer.pad_id, eos=self.decoder_tokenizer.eos_id, ) 
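# --- Illustrative sketch (not in the original file): `replace_beam_with_sampling` above swaps the
# --- deterministic beam decoder for stochastic top-k sampling, e.g. to produce diverse candidate
# --- translations for data augmentation. `mt_model`, `src`, and `src_mask` are placeholders for an
# --- instantiated MTEncDecModel and an already-encoded source batch.
# mt_model.replace_beam_with_sampling(topk=500)
# _, sampled_translations = mt_model.batch_translate(src=src, src_mask=src_mask)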
def _setup_eval_dataloader_from_config(self, cfg: DictConfig): src_file_name = cfg.get('src_file_name') tgt_file_name = cfg.get('tgt_file_name') if src_file_name is None or tgt_file_name is None: raise ValueError( 'Validation dataloader needs both cfg.src_file_name and cfg.tgt_file_name to not be None.' ) else: # convert src_file_name and tgt_file_name to list of strings if isinstance(src_file_name, str): src_file_list = [src_file_name] elif isinstance(src_file_name, ListConfig): src_file_list = src_file_name else: raise ValueError("cfg.src_file_name must be string or list of strings") if isinstance(tgt_file_name, str): tgt_file_list = [tgt_file_name] elif isinstance(tgt_file_name, ListConfig): tgt_file_list = tgt_file_name else: raise ValueError("cfg.tgt_file_name must be string or list of strings") if len(src_file_list) != len(tgt_file_list): raise ValueError('The same number of filepaths must be passed in for source and target validation.') dataloaders = [] prepend_idx = 0 for idx, src_file in enumerate(src_file_list): if self.multilingual: prepend_idx = idx dataset = TranslationDataset( dataset_src=str(Path(src_file).expanduser()), dataset_tgt=str(Path(tgt_file_list[idx]).expanduser()), tokens_in_batch=cfg.tokens_in_batch, clean=cfg.get("clean", False), max_seq_length=cfg.get("max_seq_length", 512), min_seq_length=cfg.get("min_seq_length", 1), max_seq_length_diff=cfg.get("max_seq_length_diff", 512), max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512), cache_ids=cfg.get("cache_ids", False), cache_data_per_node=cfg.get("cache_data_per_node", False), use_cache=cfg.get("use_cache", False), reverse_lang_direction=cfg.get("reverse_lang_direction", False), prepend_id=self.multilingual_ids[prepend_idx], ) dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer) if cfg.shuffle: sampler = pt_data.RandomSampler(dataset) else: sampler = pt_data.SequentialSampler(dataset) dataloader = torch.utils.data.DataLoader( dataset=dataset, batch_size=1, sampler=sampler, num_workers=cfg.get("num_workers", 2), pin_memory=cfg.get("pin_memory", False), drop_last=cfg.get("drop_last", False), ) dataloaders.append(dataloader) return dataloaders def setup_pre_and_post_processing_utils(self, source_lang, target_lang): """ Creates source and target processor objects for input and output pre/post-processing. """ self.source_processor, self.target_processor = None, None if (source_lang == 'en' and target_lang == 'ja') or (source_lang == 'ja' and target_lang == 'en'): self.source_processor = EnJaProcessor(source_lang) self.target_processor = EnJaProcessor(target_lang) else: if source_lang == 'zh': self.source_processor = ChineseProcessor() if target_lang == 'zh': self.target_processor = ChineseProcessor() if source_lang is not None and source_lang not in ['ja', 'zh']: self.source_processor = MosesProcessor(source_lang) if target_lang is not None and target_lang not in ['ja', 'zh']: self.target_processor = MosesProcessor(target_lang) return self.source_processor, self.target_processor @torch.no_grad() def batch_translate( self, src: torch.LongTensor, src_mask: torch.LongTensor, ): """ Translates a minibatch of inputs from source language to target language. 
Args: src: minibatch of inputs in the src language (batch x seq_len) src_mask: mask tensor indicating elements to be ignored (batch x seq_len) Returns: translations: a list strings containing detokenized translations inputs: a list of string containing detokenized inputs """ mode = self.training try: self.eval() src_hiddens = self.encoder(input_ids=src, encoder_mask=src_mask) beam_results = self.beam_search(encoder_hidden_states=src_hiddens, encoder_input_mask=src_mask) beam_results = self.filter_predicted_ids(beam_results) translations = [self.decoder_tokenizer.ids_to_text(tr) for tr in beam_results.cpu().numpy()] inputs = [self.encoder_tokenizer.ids_to_text(inp) for inp in src.cpu().numpy()] if self.target_processor is not None: translations = [ self.target_processor.detokenize(translation.split(' ')) for translation in translations ] if self.source_processor is not None: inputs = [self.source_processor.detokenize(item.split(' ')) for item in inputs] finally: self.train(mode=mode) return inputs, translations # TODO: We should drop source/target_lang arguments in favor of using self.src/tgt_language @torch.no_grad() def translate(self, text: List[str], source_lang: str = None, target_lang: str = None) -> List[str]: """ Translates list of sentences from source language to target language. Should be regular text, this method performs its own tokenization/de-tokenization Args: text: list of strings to translate source_lang: if not None, corresponding MosesTokenizer and MosesPunctNormalizer will be run target_lang: if not None, corresponding MosesDecokenizer will be run Returns: list of translated strings """ # __TODO__: This will reset both source and target processors even if you want to reset just one. if source_lang is not None or target_lang is not None: self.setup_pre_and_post_processing_utils(source_lang, target_lang) mode = self.training prepend_ids = [] if self.multilingual: if source_lang is None or target_lang is None: raise ValueError("Expect source_lang and target_lang to infer for multilingual model.") src_symbol = self.encoder_tokenizer.token_to_id('<' + source_lang + '>') tgt_symbol = self.encoder_tokenizer.token_to_id('<' + target_lang + '>') prepend_ids = [src_symbol if src_symbol in self.multilingual_ids else tgt_symbol] try: self.eval() inputs = [] for txt in text: if self.source_processor is not None: txt = self.source_processor.normalize(txt) txt = self.source_processor.tokenize(txt) ids = self.encoder_tokenizer.text_to_ids(txt) ids = prepend_ids + [self.encoder_tokenizer.bos_id] + ids + [self.encoder_tokenizer.eos_id] inputs.append(ids) max_len = max(len(txt) for txt in inputs) src_ids_ = np.ones((len(inputs), max_len)) * self.encoder_tokenizer.pad_id for i, txt in enumerate(inputs): src_ids_[i][: len(txt)] = txt src_mask = torch.FloatTensor((src_ids_ != self.encoder_tokenizer.pad_id)).to(self.device) src = torch.LongTensor(src_ids_).to(self.device) _, translations = self.batch_translate(src, src_mask) finally: self.train(mode=mode) return translations @classmethod def list_available_models(cls) -> Optional[Dict[str, str]]: """ This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. Returns: List of available pre-trained models. """ result = [] model = PretrainedModelInfo( pretrained_model_name="nmt_en_de_transformer12x2", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_en_de_transformer12x2/versions/1.0.0rc1/files/nmt_en_de_transformer12x2.nemo", description="En->De translation model. 
See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_en_de_transformer12x2", ) result.append(model) model = PretrainedModelInfo( pretrained_model_name="nmt_de_en_transformer12x2", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_de_en_transformer12x2/versions/1.0.0rc1/files/nmt_de_en_transformer12x2.nemo", description="De->En translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_de_en_transformer12x2", ) result.append(model) model = PretrainedModelInfo( pretrained_model_name="nmt_en_es_transformer12x2", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_en_es_transformer12x2/versions/1.0.0rc1/files/nmt_en_es_transformer12x2.nemo", description="En->Es translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_en_es_transformer12x2", ) result.append(model) model = PretrainedModelInfo( pretrained_model_name="nmt_es_en_transformer12x2", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_es_en_transformer12x2/versions/1.0.0rc1/files/nmt_es_en_transformer12x2.nemo", description="Es->En translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_es_en_transformer12x2", ) result.append(model) model = PretrainedModelInfo( pretrained_model_name="nmt_en_fr_transformer12x2", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_en_fr_transformer12x2/versions/1.0.0rc1/files/nmt_en_fr_transformer12x2.nemo", description="En->Fr translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_en_fr_transformer12x2", ) result.append(model) model = PretrainedModelInfo( pretrained_model_name="nmt_fr_en_transformer12x2", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_fr_en_transformer12x2/versions/1.0.0rc1/files/nmt_fr_en_transformer12x2.nemo", description="Fr->En translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_fr_en_transformer12x2", ) result.append(model) model = PretrainedModelInfo( pretrained_model_name="nmt_en_ru_transformer6x6", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_en_ru_transformer6x6/versions/1.0.0rc1/files/nmt_en_ru_transformer6x6.nemo", description="En->Ru translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_en_ru_transformer6x6", ) result.append(model) model = PretrainedModelInfo( pretrained_model_name="nmt_ru_en_transformer6x6", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_ru_en_transformer6x6/versions/1.0.0rc1/files/nmt_ru_en_transformer6x6.nemo", description="Ru->En translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_ru_en_transformer6x6", ) result.append(model) model = PretrainedModelInfo( pretrained_model_name="nmt_zh_en_transformer6x6", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_zh_en_transformer6x6/versions/1.0.0rc1/files/nmt_zh_en_transformer6x6.nemo", description="Zh->En translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_zh_en_transformer6x6", ) result.append(model) model = PretrainedModelInfo( pretrained_model_name="nmt_en_zh_transformer6x6", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/nmt_en_zh_transformer6x6/versions/1.0.0rc1/files/nmt_en_zh_transformer6x6.nemo", description="En->Zh translation model. See details here: https://ngc.nvidia.com/catalog/models/nvidia:nemo:nmt_en_zh_transformer6x6", ) result.append(model) return result
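# --- Usage sketch (assumption: NeMo's standard from_pretrained entry point; the model name comes
# --- from the registry returned by list_available_models above).
if __name__ == "__main__":
    nmt_model = MTEncDecModel.from_pretrained("nmt_en_de_transformer12x2")
    print(nmt_model.translate(["NeMo makes it easy to translate with pretrained models."]))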
def __init__(
    self,
    w_words: List[str],
    s_words: List[str],
    inst_dir: str,
    start_idx: int,
    end_idx: int,
    lang: str,
    semiotic_class: str = None,
):
    processor = MosesProcessor(lang_id=lang)
    start_idx = max(start_idx, 0)
    end_idx = min(end_idx, len(w_words))
    ctx_size = constants.DECODE_CTX_SIZE
    extra_id_0 = constants.EXTRA_ID_0
    extra_id_1 = constants.EXTRA_ID_1

    # Extract center words
    c_w_words = w_words[start_idx:end_idx]
    c_s_words = s_words[start_idx:end_idx]

    # Extract context
    w_left = w_words[max(0, start_idx - ctx_size):start_idx]
    w_right = w_words[end_idx:end_idx + ctx_size]
    s_left = s_words[max(0, start_idx - ctx_size):start_idx]
    s_right = s_words[end_idx:end_idx + ctx_size]

    # Process sil words and self words
    for jx in range(len(s_left)):
        if s_left[jx] == constants.SIL_WORD:
            s_left[jx] = ''
        if s_left[jx] == constants.SELF_WORD:
            s_left[jx] = w_left[jx]
    for jx in range(len(s_right)):
        if s_right[jx] == constants.SIL_WORD:
            s_right[jx] = ''
        if s_right[jx] == constants.SELF_WORD:
            s_right[jx] = w_right[jx]
    for jx in range(len(c_s_words)):
        if c_s_words[jx] == constants.SIL_WORD:
            c_s_words[jx] = c_w_words[jx]
            if inst_dir == constants.INST_BACKWARD:
                c_w_words[jx] = ''
                c_s_words[jx] = ''
        if c_s_words[jx] == constants.SELF_WORD:
            c_s_words[jx] = c_w_words[jx]

    # Extract input_words and output_words
    c_w_words = processor.tokenize(' '.join(c_w_words)).split()
    c_s_words = processor.tokenize(' '.join(c_s_words)).split()

    # for cases when nearby words are actually multiple tokens, e.g. '1974,'
    w_left = processor.tokenize(' '.join(w_left)).split()[-constants.DECODE_CTX_SIZE:]
    w_right = processor.tokenize(' '.join(w_right)).split()[:constants.DECODE_CTX_SIZE]

    w_input = w_left + [extra_id_0] + c_w_words + [extra_id_1] + w_right
    s_input = s_left + [extra_id_0] + c_s_words + [extra_id_1] + s_right
    if inst_dir == constants.INST_BACKWARD:
        input_center_words = c_s_words
        input_words = [constants.ITN_PREFIX] + s_input
        output_words = c_w_words
    if inst_dir == constants.INST_FORWARD:
        input_center_words = c_w_words
        input_words = [constants.TN_PREFIX] + w_input
        output_words = c_s_words

    # Finalize
    self.input_str = ' '.join(input_words)
    self.input_center_str = ' '.join(input_center_words)
    self.output_str = ' '.join(output_words)
    self.direction = inst_dir
    self.semiotic_class = semiotic_class
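# --- Illustrative sketch (not from the original source): a self-contained reproduction of the
# --- sentinel layout built above, i.e. left context, <extra_id_0>, the span being (de)normalized,
# --- <extra_id_1>, right context. The sentinel strings are placeholders for the values in
# --- `constants`.
def _toy_decoder_input(left_ctx, center, right_ctx,
                       extra_id_0="<extra_id_0>", extra_id_1="<extra_id_1>"):
    """Join context and center words around the two sentinel tokens."""
    return " ".join(left_ctx + [extra_id_0] + center + [extra_id_1] + right_ctx)

# Example: _toy_decoder_input(["in"], ["1974"], [","]) -> "in <extra_id_0> 1974 <extra_id_1> ,"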
def __init__(self, input_file: str, mode: str, lang: str):
    self.lang = lang
    insts = read_data_file(input_file, lang=lang)
    processor = MosesProcessor(lang_id=lang)

    # Build inputs and targets
    (
        self.directions,
        self.inputs,
        self.targets,
        self.classes,
        self.nb_spans,
        self.span_starts,
        self.span_ends,
    ) = ([], [], [], [], [], [], [])
    for (classes, w_words, s_words) in insts:
        # Extract words that are not punctuations
        for direction in constants.INST_DIRECTIONS:
            if direction == constants.INST_BACKWARD:
                if mode == constants.TN_MODE:
                    continue
                # ITN mode
                (
                    processed_w_words,
                    processed_s_words,
                    processed_classes,
                    processed_nb_spans,
                    processed_s_span_starts,
                    processed_s_span_ends,
                ) = ([], [], [], 0, [], [])
                s_word_idx = 0
                for cls, w_word, s_word in zip(classes, w_words, s_words):
                    if s_word == constants.SIL_WORD:
                        continue
                    elif s_word == constants.SELF_WORD:
                        processed_s_words.append(w_word)
                    else:
                        processed_s_words.append(s_word)

                    s_word_last = processor.tokenize(processed_s_words.pop()).split()
                    processed_s_words.append(" ".join(s_word_last))
                    num_tokens = len(s_word_last)
                    processed_nb_spans += 1
                    processed_classes.append(cls)
                    processed_s_span_starts.append(s_word_idx)
                    s_word_idx += num_tokens
                    processed_s_span_ends.append(s_word_idx)
                    processed_w_words.append(w_word)

                self.span_starts.append(processed_s_span_starts)
                self.span_ends.append(processed_s_span_ends)
                self.classes.append(processed_classes)
                self.nb_spans.append(processed_nb_spans)
                input_words = ' '.join(processed_s_words)
                # Update self.directions, self.inputs, self.targets
                self.directions.append(direction)
                self.inputs.append(input_words)
                self.targets.append(
                    processed_w_words
                )  # is list of lists where inner list contains target tokens (not words)

            # TN mode
            elif direction == constants.INST_FORWARD:
                if mode == constants.ITN_MODE:
                    continue
                (
                    processed_w_words,
                    processed_s_words,
                    processed_classes,
                    processed_nb_spans,
                    w_span_starts,
                    w_span_ends,
                ) = ([], [], [], 0, [], [])
                w_word_idx = 0
                for cls, w_word, s_word in zip(classes, w_words, s_words):
                    # TN forward mode
                    # this is done for cases like `do n't`, this w_word will be treated as 2 tokens
                    w_word = processor.tokenize(w_word).split()
                    num_tokens = len(w_word)
                    if s_word in constants.SPECIAL_WORDS:
                        processed_s_words.append(" ".join(w_word))
                    else:
                        processed_s_words.append(s_word)
                    w_span_starts.append(w_word_idx)
                    w_word_idx += num_tokens
                    w_span_ends.append(w_word_idx)
                    processed_nb_spans += 1
                    processed_classes.append(cls)
                    processed_w_words.extend(w_word)

                self.span_starts.append(w_span_starts)
                self.span_ends.append(w_span_ends)
                self.classes.append(processed_classes)
                self.nb_spans.append(processed_nb_spans)
                input_words = ' '.join(processed_w_words)
                # Update self.directions, self.inputs, self.targets
                self.directions.append(direction)
                self.inputs.append(input_words)
                self.targets.append(
                    processed_s_words
                )  # is list of lists where inner list contains target tokens (not words)

    self.examples = list(
        zip(
            self.directions,
            self.inputs,
            self.targets,
            self.classes,
            self.nb_spans,
            self.span_starts,
            self.span_ends,
        )
    )
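# --- Illustrative sketch (not from the original source): the span bookkeeping above records, for
# --- each semiotic token, where its (possibly multi-token) surface form starts and ends in the
# --- Moses-tokenized input. A minimal reproduction of that cumulative-offset logic:
def _toy_span_offsets(tokens_per_word):
    """Return (span_starts, span_ends) given the number of sub-tokens each word produced."""
    starts, ends, idx = [], [], 0
    for n in tokens_per_word:
        starts.append(idx)
        idx += n
        ends.append(idx)
    return starts, ends

# Example: words tokenized into 1, 3, and 1 sub-tokens -> starts [0, 1, 4], ends [1, 4, 5].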