def __init__(self, timer_kwargs=None):
    # Use None instead of a mutable default argument; fall back to an empty dict.
    timer_kwargs = timer_kwargs or {}
    self.timer = timers.NamedTimer(**timer_kwargs)
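
# Hedged sketch of driving the callback's timer (assumption: NamedTimer exposes
# start(name)/stop(name) to pair with the export() calls used in eval_step and
# translate below; confirm against nemo.utils.timers before relying on it).
def _example_timer_usage(callback):
    callback.timer.start("train_step")  # begin timing a named region (assumed API)
    ...  # timed work goes here
    callback.timer.stop("train_step")  # end timing the same region (assumed API)
    return callback.timer.export()  # e.g. {"train_step": <elapsed time>}
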
def eval_step(self, batch, batch_idx, mode, dataloader_idx=0):
    if self.log_timing:
        timer = timers.NamedTimer()
    else:
        timer = None

    for i in range(len(batch)):
        if batch[i].ndim == 3:
            # The dataset returns already-batched data, so the leading dimension
            # of size 1 added by the DataLoader is excess; drop it.
            batch[i] = batch[i].squeeze(dim=0)

    if self.multilingual:
        self.source_processor = self.source_processor_list[dataloader_idx]
        self.target_processor = self.target_processor_list[dataloader_idx]

    src_ids, src_mask, tgt_ids, tgt_mask, labels = batch
    z, z_mean, z_logv, z_mask, tgt_log_probs = self(src_ids, src_mask, tgt_ids, tgt_mask, timer=timer)
    eval_loss, info_dict = self.loss(
        z=z,
        z_mean=z_mean,
        z_logv=z_logv,
        z_mask=z_mask,
        tgt_log_probs=tgt_log_probs,
        tgt=tgt_ids,
        tgt_mask=tgt_mask,
        tgt_labels=labels,
        train=False,
        return_info=True,
    )

    # Pass a cache to the sampler so the encoder's output is reused instead of
    # being recomputed during translation.
    cache = dict(
        z=z,
        z_mean=z_mean,
        z_mask=z_mask,
        timer=timer,
    )
    inputs, translations = self.batch_translate(src=src_ids, src_mask=src_mask, cache=cache)

    num_measurements = labels.shape[0] * labels.shape[1]
    if dataloader_idx == 0:
        getattr(self, f'{mode}_loss')(
            loss=eval_loss, num_measurements=num_measurements,
        )
    else:
        getattr(self, f'{mode}_loss_{dataloader_idx}')(
            loss=eval_loss, num_measurements=num_measurements,
        )

    np_tgt = tgt_ids.detach().cpu().numpy()
    ground_truths = [self.decoder_tokenizer.ids_to_text(tgt) for tgt in np_tgt]
    ground_truths = [self.target_processor.detokenize(tgt.split(' ')) for tgt in ground_truths]
    num_non_pad_tokens = np.not_equal(np_tgt, self.decoder_tokenizer.pad_id).sum().item()

    # Collect logs, moving any tensors to host memory for aggregation.
    log_dict = {k: v.detach().cpu().numpy() if torch.is_tensor(v) else v for k, v in info_dict.items()}

    # Add timing information if it was requested.
    if timer is not None:
        for k, v in timer.export().items():
            log_dict[f"{k}_timing"] = v

    return {
        'inputs': inputs,
        'translations': translations,
        'ground_truths': ground_truths,
        'num_non_pad_tokens': num_non_pad_tokens,
        'log': log_dict,
    }
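
# One plausible wiring of eval_step into the PyTorch Lightning hooks (the hook
# signatures follow the Lightning API; whether this class delegates exactly
# this way is an assumption).
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # Delegate to the shared eval_step with mode='val' so that validation and
    # test paths reuse the same evaluation logic and metric naming.
    return self.eval_step(batch, batch_idx, 'val', dataloader_idx)

def test_step(self, batch, batch_idx, dataloader_idx=0):
    return self.eval_step(batch, batch_idx, 'test', dataloader_idx)
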
def translate(
    self,
    text: List[str],
    source_lang: str = None,
    target_lang: str = None,
    return_beam_scores: bool = False,
    log_timing: bool = False,
) -> List[str]:
    """
    Translates a list of sentences from the source language to the target language.
    Expects plain text; this method performs its own tokenization/de-tokenization.

    Args:
        text: list of strings to translate
        source_lang: if not "ignore", the corresponding MosesTokenizer and MosesPunctNormalizer are run
        target_lang: if not "ignore", the corresponding MosesDetokenizer is run
        return_beam_scores: if True, returns a list of translations and their corresponding beam scores
        log_timing: if True, returns timing information alongside the translations

    Returns:
        list of translated strings (a tuple with beam scores and/or timing appended when requested)
    """
    # __TODO__: This will reset both source and target processors even if you want to reset just one.
    # NOTE: For multilingual models, this sets up the appropriate source and target processors
    # for the given src/tgt language pair instead of creating a list of them.
    if source_lang is not None or target_lang is not None:
        self.source_processor, self.target_processor = MTEncDecModel.setup_pre_and_post_processing_utils(
            source_lang, target_lang, self.encoder_tokenizer_library, self.decoder_tokenizer_library
        )

    mode = self.training
    prepend_ids = []
    if self.multilingual:
        if source_lang is None or target_lang is None:
            raise ValueError("source_lang and target_lang are required to run inference with a multilingual model.")
        src_symbol = self.encoder_tokenizer.token_to_id('<' + source_lang + '>')
        tgt_symbol = self.encoder_tokenizer.token_to_id('<' + target_lang + '>')
        if src_symbol in self.multilingual_ids:
            prepend_ids = [src_symbol]
        elif tgt_symbol in self.multilingual_ids:
            prepend_ids = [tgt_symbol]

    if log_timing:
        timer = timers.NamedTimer()
    else:
        timer = None

    cache = {
        "timer": timer,
    }

    try:
        self.eval()
        src, src_mask = MTEncDecModel.prepare_inference_batch(
            text=text,
            prepend_ids=prepend_ids,
            target=False,
            source_processor=self.source_processor,
            target_processor=self.target_processor,
            encoder_tokenizer=self.encoder_tokenizer,
            decoder_tokenizer=self.decoder_tokenizer,
            device=self.device,
        )
        # Generate up to source length + max_generation_delta tokens.
        # TODO: implement better stopping once every hypothesis hits <EOS>.
        predicted_tokens_ids, _ = self.decode(
            src,
            src_mask,
            src.size(1) + self._cfg.max_generation_delta,
            tokenizer=self.decoder_tokenizer,
        )
        best_translations = self.postprocess_outputs(
            outputs=predicted_tokens_ids, tokenizer=self.decoder_tokenizer, processor=self.target_processor
        )
        return_val = best_translations
    finally:
        self.train(mode=mode)

    if log_timing:
        timing = timer.export()
        timing["mean_src_length"] = src_mask.sum().cpu().item() / src_mask.shape[0]
        tgt, tgt_mask = self.prepare_inference_batch(
            text=best_translations,
            prepend_ids=prepend_ids,
            target=True,
            source_processor=self.source_processor,
            target_processor=self.target_processor,
            encoder_tokenizer=self.encoder_tokenizer,
            decoder_tokenizer=self.decoder_tokenizer,
            device=self.device,
        )
        timing["mean_tgt_length"] = tgt_mask.sum().cpu().item() / tgt_mask.shape[0]

        if isinstance(return_val, tuple):
            return_val = return_val + (timing,)
        else:
            return_val = (return_val, timing)

    return return_val
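
# Usage sketch for translate (hypothetical: assumes `model` is a trained,
# restored instance of this class; the language codes and input text are
# placeholders).
def _example_translate_usage(model):
    # Plain strings in, plain strings out; translate handles (de)tokenization.
    translations = model.translate(
        ["Hello, how are you?"],
        source_lang="en",
        target_lang="de",
    )
    # With log_timing=True the return value becomes a (translations, timing) tuple,
    # where timing includes mean_src_length and mean_tgt_length.
    translations, timing = model.translate(
        ["Hello, how are you?"],
        source_lang="en",
        target_lang="de",
        log_timing=True,
    )
    return translations, timing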