def test_parse_numberize(self):
    src_dataset = data.InMemoryNumpyDataset()
    trg_dataset = data.InMemoryNumpyDataset()
    # Parse twice to verify that parse() fully reinitializes the dataset.
    for _ in range(2):
        src_dataset.parse(
            self.src_txt_numberized,
            self.d,
            reverse_order=True,
            append_eos=False,
            already_numberized=True,
        )
        trg_dataset.parse(
            self.trg_txt_numberized,
            self.d,
            reverse_order=False,
            append_eos=True,
            already_numberized=True,
        )
    self.assertEqual(self.num_sentences, len(src_dataset))
    self.assertEqual(self.num_sentences, len(trg_dataset))
    for i in range(self.num_sentences):
        self.assertListEqual(self.src_ref[i], src_dataset[i].tolist())
        self.assertListEqual(
            self.trg_ref[i] + [self.d.eos_index], trg_dataset[i].tolist()
        )
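# Note (illustrative, not taken from the test fixtures): with
# already_numberized=True, parse() treats each line of the input file as
# whitespace-separated token IDs already in the dictionary's index space,
# e.g. a line like "21 4 107" is loaded as [21, 4, 107] with no word lookup.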
def test_subsample_pair_dataset(self):
    src_dataset = data.InMemoryNumpyDataset()
    trg_dataset = data.InMemoryNumpyDataset()
    src_dataset.parse(self.src_txt, self.d, reverse_order=True, append_eos=False)
    trg_dataset.parse(self.trg_txt, self.d, reverse_order=False, append_eos=True)
    pair_dataset = LanguagePairDataset(
        src=src_dataset,
        src_sizes=src_dataset.sizes,
        src_dict=self.d,
        tgt=trg_dataset,
        tgt_sizes=trg_dataset.sizes,
        tgt_dict=self.d,
        left_pad_source=False,
    )
    # Subsampling should shrink both sides of the pair dataset to 2 examples.
    data.subsample_pair_dataset(pair_dataset, 2)
    self.assertEqual(len(pair_dataset.src), 2)
    self.assertEqual(pair_dataset.src_sizes.size, 2)
    self.assertEqual(len(pair_dataset.tgt), 2)
    self.assertEqual(pair_dataset.tgt_sizes.size, 2)
def test_parse_multiling(self):
    prepend_dataset = data.InMemoryNumpyDataset()
    append_dataset = data.InMemoryNumpyDataset()
    corpora = [
        data.MultilingualCorpusConfig(
            dialect_id=10, data_file=self.trg_txt, dict=self.d, oversampling=1
        ),
        data.MultilingualCorpusConfig(
            dialect_id=11, data_file=self.trg_txt, dict=self.d, oversampling=1
        ),
    ]
    lang1 = corpora[0].dialect_id
    lang2 = corpora[1].dialect_id
    prepend_dataset.parse_multilingual(
        corpora, reverse_order=False, append_eos=False, prepend_language_id=True
    )
    append_dataset.parse_multilingual(
        corpora, reverse_order=False, append_eos=False, prepend_language_id=False
    )
    # Two corpora over the same file: each contributes num_sentences examples.
    self.assertEqual(2 * self.num_sentences, len(prepend_dataset))
    self.assertEqual(2 * self.num_sentences, len(append_dataset))
    for i in range(self.num_sentences):
        # The language ID token is prepended (or appended) to every sentence.
        self.assertListEqual([lang1] + self.trg_ref[i], prepend_dataset[i].tolist())
        self.assertListEqual(self.trg_ref[i] + [lang1], append_dataset[i].tolist())
        self.assertListEqual(
            [lang2] + self.trg_ref[i],
            prepend_dataset[i + self.num_sentences].tolist(),
        )
        self.assertListEqual(
            self.trg_ref[i] + [lang2],
            append_dataset[i + self.num_sentences].tolist(),
        )
def binarize_text_file(
    text_file: str,
    dictionary: Dictionary,
    output_path: str,
    append_eos: bool,
    reverse_order: bool,
    use_char_data: bool = False,
    embed_bytes: bool = False,
    char_dictionary: Optional[Dictionary] = None,
    already_numberized: bool = False,
) -> str:
    """Binarizes a text corpus into an in-memory dataset, saves it to
    output_path (or a generated temp path), and returns that path."""
    output_path = maybe_generate_temp_file_path(output_path)
    if use_char_data:
        dataset = char_data.InMemoryNumpyWordCharDataset()
        dataset.parse(
            path=text_file,
            word_dict=dictionary,
            char_dict=char_dictionary,
            embed_bytes=embed_bytes,
            reverse_order=reverse_order,
            append_eos=append_eos,
        )
    else:
        dataset = pytorch_translate_data.InMemoryNumpyDataset()
        dataset.parse(
            path=text_file,
            dictionary=dictionary,
            reverse_order=reverse_order,
            append_eos=append_eos,
            already_numberized=already_numberized,
        )
    dataset.save(output_path)
    return output_path
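# Example usage (a minimal sketch; the file paths are hypothetical and
# `Dictionary` is assumed to be fairseq's Dictionary class):
#
#   d = Dictionary.load("/path/to/dict.src.txt")
#   out_path = binarize_text_file(
#       text_file="/path/to/train.src.txt",
#       dictionary=d,
#       output_path="/path/to/train.src.npz",
#       append_eos=False,
#       reverse_order=True,
#   )
#   # out_path is wherever the binarized dataset was saved (possibly a
#   # generated temp file if output_path was empty).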
def binarize_text_file_multilingual(
    corpus_configs: List[pytorch_translate_data.MultilingualCorpusConfig],
    output_path: str,
    append_eos: bool,
    reverse_order: bool,
    prepend_language_id: bool,
    use_char_data: bool = False,
    embed_bytes: bool = False,
    already_numberized: bool = False,
) -> str:
    """Binarizes several corpora (one per language/dialect) into a single
    dataset, saves it to output_path (or a generated temp path), and returns
    that path."""
    output_path = maybe_generate_temp_file_path(output_path)
    if use_char_data:
        dataset = char_data.InMemoryNumpyWordCharDataset()
        dataset.parse_multilingual(
            corpus_configs,
            reverse_order=reverse_order,
            append_eos=append_eos,
            embed_bytes=embed_bytes,
            prepend_language_id=prepend_language_id,
            already_numberized=already_numberized,
        )
    else:
        dataset = pytorch_translate_data.InMemoryNumpyDataset()
        dataset.parse_multilingual(
            corpus_configs,
            append_eos=append_eos,
            reverse_order=reverse_order,
            prepend_language_id=prepend_language_id,
            already_numberized=already_numberized,
        )
    dataset.save(output_path)
    return output_path
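# Example usage (sketch; the dialect IDs, paths, and dictionary `d` are
# hypothetical):
#
#   configs = [
#       pytorch_translate_data.MultilingualCorpusConfig(
#           dialect_id=10, data_file="/path/to/corpus.de", dict=d, oversampling=1
#       ),
#       pytorch_translate_data.MultilingualCorpusConfig(
#           dialect_id=11, data_file="/path/to/corpus.fr", dict=d, oversampling=2
#       ),
#   ]
#   out_path = binarize_text_file_multilingual(
#       corpus_configs=configs,
#       output_path="/path/to/train.multi.npz",
#       append_eos=True,
#       reverse_order=False,
#       prepend_language_id=True,
#   )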
def create_dummy_binarized_dataset(
    num_sentences=13, min_length=5, max_length=10, append_eos=False
):
    index_sequences = []
    for _ in range(num_sentences):
        # Random token IDs in [100, 102] with lengths in [min_length, max_length].
        sequence = [
            np.random.randint(100, 103)
            for _ in range(np.random.randint(min_length, max_length + 1))
        ]
        if append_eos:
            sequence.append(vocab_constants.EOS_ID)
        index_sequences.append(sequence)
    dataset = pytorch_translate_data.InMemoryNumpyDataset()
    dataset.load_from_sequences(index_sequences)
    return dataset
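# Example usage (sketch): build a tiny dataset for a smoke test. With
# append_eos=True every sequence should end in vocab_constants.EOS_ID.
#
#   dataset = create_dummy_binarized_dataset(num_sentences=4, append_eos=True)
#   assert len(dataset) == 4
#   assert all(dataset[i].tolist()[-1] == vocab_constants.EOS_ID for i in range(4))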
def test_subsample_dataset(self):
    """
    Test the InMemoryNumpyDataset.subsample() method, ensuring that the
    examples produced by the dataset are correctly permuted according to
    the indices argument.
    """
    trg_dataset = data.InMemoryNumpyDataset()
    trg_dataset.parse(self.trg_txt, self.d, reverse_order=False, append_eos=True)
    indices = np.random.permutation(len(trg_dataset))[:2]
    token_samples = [trg_dataset[i] for i in indices]
    trg_dataset.subsample(indices)
    for i in range(2):
        assert all(trg_dataset[i].numpy() == token_samples[i].numpy())
def test_parse_oversampling(self):
    dataset = data.InMemoryNumpyDataset()
    factors = [(1, 0), (3, 2), (4, 4)]
    for o1, o2 in factors:
        corpora = [
            data.MultilingualCorpusConfig(
                dialect_id=None,
                data_file=self.trg_txt,
                dict=self.d,
                oversampling=o1,
            ),
            data.MultilingualCorpusConfig(
                dialect_id=None,
                data_file=self.trg_txt,
                dict=self.d,
                oversampling=o2,
            ),
        ]
        dataset.parse_multilingual(corpora)
        self.assertEqual((o1 + o2) * self.num_sentences, len(dataset))
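# Size arithmetic checked above: parse_multilingual() emits each corpus
# oversampling-many times, so two corpora over the same file of N sentences
# yield (o1 + o2) * N examples, e.g. factors (3, 2) over 10 sentences -> 50.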
def _generate_score(models, args, task, dataset):
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load ensemble
    if not args.quiet:
        print("| loading model(s) from {}".format(", ".join(args.path.split(":"))))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=True,
        )

    translator = build_sequence_generator(args, task, models)
    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    print(f"| seed number is {args.max_examples_to_evaluate_seed}")
    if args.max_examples_to_evaluate > 0:
        pytorch_translate_data.subsample_pair_dataset(
            dataset, args.max_examples_to_evaluate, args.max_examples_to_evaluate_seed
        )

    # Keep track of translations: initialize with empty translations
    # and zero probability scores.
    translated_sentences = [""] * len(dataset)
    translated_scores = [0.0] * len(dataset)
    hypos_list = []

    collect_output_hypos = getattr(args, "output_hypos_binary_path", False)
    if collect_output_hypos:
        output_hypos_token_arrays = [None] * len(dataset)

    # Generate and compute BLEU score
    dst_dict = task.target_dictionary
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(), dst_dict.unk())

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    oracle_scorer = None
    if args.report_oracle_bleu:
        oracle_scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(), dst_dict.unk())

    rescorer = None
    num_sentences = 0
    translation_samples = []
    translation_info_list = []
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        gen_timer = StopwatchMeter()
        translations = translator.generate_batched_itr(
            t,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=use_cuda,
            timer=gen_timer,
            prefix_size=1
            if pytorch_translate_data.is_multilingual_many_to_one(args)
            else 0,
        )

        for trans_info in _iter_translations(
            args, task, dataset, translations, align_dict, rescorer
        ):
            if hasattr(scorer, "add_string"):
                scorer.add_string(trans_info.target_str, trans_info.hypo_str)
            else:
                scorer.add(trans_info.target_tokens, trans_info.hypo_tokens)
            if oracle_scorer is not None:
                oracle_scorer.add(trans_info.target_tokens, trans_info.best_hypo_tokens)

            if getattr(args, "translation_output_file", False):
                translated_sentences[trans_info.sample_id] = trans_info.hypo_str
            if getattr(args, "translation_probs_file", False):
                # Record the hypothesis score; without this, the probs file
                # written below would only ever contain the initial zeros.
                translated_scores[trans_info.sample_id] = trans_info.hypo_score
            if getattr(args, "hypotheses_export_path", False):
                hypos_list.append(trans_info.hypos)
            if collect_output_hypos:
                output_hypos_token_arrays[
                    trans_info.sample_id
                ] = trans_info.best_hypo_tokens
            if args.translation_info_export_path is not None:
                # Strip expensive data from hypotheses before saving
                hypos = [
                    {k: v for k, v in hypo.items() if k in ["tokens", "score"]}
                    for hypo in trans_info.hypos
                ]
                # Make sure everything is on cpu before exporting
                hypos = [
                    {"score": hypo["score"], "tokens": hypo["tokens"].cpu()}
                    for hypo in hypos
                ]
                translation_info_list.append(
                    {
                        "src_tokens": trans_info.src_tokens.cpu(),
                        "target_tokens": trans_info.target_tokens,
                        "hypos": hypos,
                    }
                )
            translation_samples.append(
                collections.OrderedDict(
                    {
                        "sample_id": trans_info.sample_id.item(),
                        "src_str": trans_info.src_str,
                        "target_str": trans_info.target_str,
                        "hypo_str": trans_info.hypo_str,
                    }
                )
            )
            wps_meter.update(trans_info.src_tokens.size(0))
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += 1

    # If applicable, save collected hypothesis tokens to binary output file
    if collect_output_hypos:
        output_dataset = pytorch_translate_data.InMemoryNumpyDataset()
        output_dataset.load_from_sequences(output_hypos_token_arrays)
        output_dataset.save(args.output_hypos_binary_path)
    if args.output_source_binary_path:
        dataset.src.save(args.output_source_binary_path)
    if args.translation_info_export_path is not None:
        with open(args.translation_info_export_path, "wb") as f:
            pickle.dump(translation_info_list, f)

    # If applicable, save the translations and hypos to the output file,
    # e.g. for external evaluation
    if getattr(args, "translation_output_file", False):
        with open(args.translation_output_file, "w") as out_file:
            for hypo_str in translated_sentences:
                print(hypo_str, file=out_file)
    if getattr(args, "hypotheses_export_path", False):
        with open(args.hypotheses_export_path, "w") as out_file:
            for hypos in hypos_list:
                for hypo in hypos:
                    print(
                        task.tgt_dict.string(
                            hypo["tokens"], bpe_symbol=args.remove_bpe
                        ),
                        file=out_file,
                    )
    if getattr(args, "translation_probs_file", False):
        with open(args.translation_probs_file, "w") as out_file:
            for hypo_score in translated_scores:
                print(np.exp(hypo_score), file=out_file)

    if oracle_scorer is not None:
        print(f"| Oracle BLEU (best hypo in beam): {oracle_scorer.result_string()}")

    return scorer, num_sentences, gen_timer, translation_samples
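# Example call (sketch, assuming `models`, `args`, `task`, and `dataset` were
# built by the surrounding generate pipeline; result_string() is the summary
# method on fairseq's BLEU scorers):
#
#   scorer, num_sentences, gen_timer, translation_samples = _generate_score(
#       models=models, args=args, task=task, dataset=dataset
#   )
#   print(f"| Translated {num_sentences} sentences; {scorer.result_string()}")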