def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    dataset_map = OrderedDict()
    for lang in self.langs2id.keys():
        # Datasets are expected to be in "split.lang" format (e.g., train.en)
        language_split = "{}.{}".format(split, lang)
        block_dataset, sizes = self._load_single_lang_dataset(
            split=language_split, epoch=epoch
        )
        dataset_map[lang] = MaskedLMDataset(
            dataset=block_dataset,
            sizes=sizes,
            vocab=self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.dictionary.mask(),
            classif_token_idx=self.dictionary.eos(),
            sep_token_idx=self.dictionary.eos(),
            shuffle=getattr(self.args, "shuffle", False),
            has_pairs=False,
            segment_id=self.langs2id[lang],
            seed=self.seed,
        )
    self.datasets[split] = MultiCorpusSampledDataset(dataset_map)
    # index the shard with a modulo, matching _load_single_lang_dataset, so the
    # log line cannot run past the number of data paths when epoch > num shards
    paths = utils.split_paths(self.args.data)
    logger.info(
        "{} {} {} examples".format(
            paths[(epoch - 1) % len(paths)],
            split,
            len(self.datasets[split]),
        )
    )
def setup_task(cls, cfg: TranslationConfig, **kwargs):
    """Setup the task (e.g., load dictionaries).

    Args:
        cfg (TranslationConfig): configuration for this task
    """
    paths = utils.split_paths(cfg.data)
    assert len(paths) > 0
    # find language pair automatically
    if cfg.source_lang is None or cfg.target_lang is None:
        cfg.source_lang, cfg.target_lang = data_utils.infer_language_pair(paths[0])
    if cfg.source_lang is None or cfg.target_lang is None:
        raise Exception("Could not infer language pair, please provide it explicitly")

    # load dictionaries
    src_dict = cls.load_dictionary(
        os.path.join(paths[0], "dict.{}.txt".format(cfg.source_lang))
    )
    tgt_dict = cls.load_dictionary(
        os.path.join(paths[0], "dict.{}.txt".format(cfg.target_lang))
    )
    assert src_dict.pad() == tgt_dict.pad()
    assert src_dict.eos() == tgt_dict.eos()
    assert src_dict.unk() == tgt_dict.unk()
    logger.info("[{}] dictionary: {} types".format(cfg.source_lang, len(src_dict)))
    logger.info("[{}] dictionary: {} types".format(cfg.target_lang, len(tgt_dict)))

    return cls(cfg, src_dict, tgt_dict)
def prepare(cls, args, **kargs): cls.update_args(args) sorted_langs = sorted( list({ x for lang_pair in args.lang_pairs for x in lang_pair.split("-") })) if args.source_lang is not None or args.target_lang is not None: training = False else: training = True # load dictionaries dicts = OrderedDict() for lang in sorted_langs: paths = utils.split_paths(args.data) assert len(paths) > 0 dicts[lang] = cls.load_dictionary( os.path.join(paths[0], "dict.{}.txt".format(lang))) if len(dicts) > 0: assert dicts[lang].pad() == dicts[sorted_langs[0]].pad() assert dicts[lang].eos() == dicts[sorted_langs[0]].eos() assert dicts[lang].unk() == dicts[sorted_langs[0]].unk() if args.encoder_langtok is not None or args.decoder_langtok: for lang_to_add in sorted_langs: dicts[lang].add_symbol(_lang_token(lang_to_add)) logger.info("[{}] dictionary: {} types".format( lang, len(dicts[lang]))) return dicts, training
def get_split_num_data_shards(self, split):
    if split in self._num_shards_dict:
        return self._num_shards_dict[split]
    num_shards_dict = {}
    data_paths, lang_pairs = self.get_data_paths_and_lang_pairs(split)

    for data_category, paths in data_paths.items():
        if data_category not in lang_pairs:
            continue
        paths = utils.split_paths(paths)
        shards_dict = self._get_shard_num_dict(split, paths)
        lang_dirs = [lang_pair.split("-") for lang_pair in lang_pairs[data_category]]
        lang_dirs = [x if len(x) > 1 else (x[0], x[0]) for x in lang_dirs]
        for src, tgt in lang_dirs:
            key = self.get_dataset_key(data_category, src, tgt)
            if "mono_" in data_category:
                # monolingual data requires tgt only
                assert src is None or src == tgt, (
                    f"error: src={src}, tgt={tgt} for data_category={data_category}"
                )
                num_shards_dict[key] = shards_dict[tgt]
            else:
                if f"{src}-{tgt}" in shards_dict:
                    num_shards_dict[key] = shards_dict[f"{src}-{tgt}"]
                elif f"{tgt}-{src}" in shards_dict:
                    # follow the fairseq tradition to use reversed direction data if it is not available
                    num_shards_dict[key] = shards_dict[f"{tgt}-{src}"]

    self._num_shards_dict[split] = num_shards_dict
    logger.info(f"[{split}] num of shards: {num_shards_dict}")
    return num_shards_dict
def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ paths = utils.split_paths(self.args.data) assert len(paths) > 0 data_path = paths[(epoch - 1) % len(paths)] # infer langcode src, tgt = self.args.source_lang, self.args.target_lang self.datasets[split] = load_langpair_dataset( data_path, split, src, self.src_dict, tgt, self.tgt_dict, combine=combine, dataset_impl=self.args.dataset_impl, upsample_primary=self.args.upsample_primary, left_pad_source=self.args.left_pad_source, left_pad_target=self.args.left_pad_target, max_source_positions=getattr(self.args, "max_source_positions", 1024), max_target_positions=getattr(self.args, "max_target_positions", 1024), load_alignments=self.args.load_alignments, prepend_bos=getattr(self.args, "prepend_bos", False), append_source_id=True, )
def setup_task(cls, args, **kwargs):
    """Setup the task."""
    paths = utils.split_paths(args.data)
    assert len(paths) > 0
    dictionary = BertDictionary.load(os.path.join(paths[0], "dict.txt"))
    logger.info("dictionary: {} types".format(len(dictionary)))
    return cls(args, dictionary)
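# Illustrative sketch: how task classmethods like setup_task above are typically
# driven from a script. The cfg object and the "valid" split name are assumed
# placeholders; the calls mirror the tasks.setup_task / task.load_dataset usage
# in the CLI entry points further below.
from fairseq import tasks


def _example_task_setup(cfg):
    task = tasks.setup_task(cfg.task)   # loads dictionaries, as in setup_task above
    task.load_dataset("valid")          # populates task.datasets["valid"]
    return task.dataset("valid")        # the dataset used to build batch iterators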
def get_split_data_param_list(self, split, epoch, shard_epoch=None):
    # TODO: to extend with extra datasets and keys and loop over different shard data paths
    param_list = []
    data_paths, lang_pairs = self.get_data_paths_and_lang_pairs(split)
    logger.info(f"langtoks settings: {self.args.langtoks}")
    split_num_shards_dict = self.get_split_num_data_shards(split)
    for data_category, paths in data_paths.items():
        if data_category not in lang_pairs:
            continue
        paths = utils.split_paths(paths)
        assert len(paths) > 0
        if len(paths) > 1:
            self._has_sharded_data = True
        if split != getattr(self.args, "train_subset", None):
            # if not training data set, use the first shard for valid and test
            paths = paths[:1]

        if data_category in self.args.langtoks:
            lang_tok_spec = self.args.langtoks[data_category]
        else:
            # default to None
            lang_tok_spec = (None, None)

        # infer langcode
        lang_dirs = [lang_pair.split("-") for lang_pair in lang_pairs[data_category]]
        lang_dirs = [x if len(x) > 1 else (x[0], x[0]) for x in lang_dirs]
        for src, tgt in lang_dirs:
            assert src is not None or data_category == "mono_dae", (
                f"error: src={src}, tgt={tgt} for data_category={data_category}"
            )
            # logger.info(f"preparing param for {data_category}: {src} - {tgt}")
            key = self.get_dataset_key(data_category, src, tgt)
            data_path = self.get_split_data_path(
                paths, epoch, shard_epoch, split_num_shards_dict[key]
            )
            param_list.append(
                {
                    "key": key,
                    "data_path": data_path,
                    "split": split,
                    "src": src,
                    "src_dict": self.get_source_dictionary(src)
                    if src and data_category != "mono_dae"
                    else None,
                    "tgt": tgt,
                    "tgt_dict": self.get_target_dictionary(tgt),
                    "data_category": data_category,
                    "langtok_spec": lang_tok_spec,
                }
            )
    return param_list
def load_dataset( self, split: str, epoch=1, combine=False, **kwargs ) -> MonolingualDataset: """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ paths = utils.split_paths(self.args.data) assert len(paths) > 0 data_path = paths[(epoch - 1) % len(paths)] split_path = os.path.join(data_path, split) # each process has its own copy of the raw data (likely to be an np.memmap) dataset = data_utils.load_indexed_dataset( split_path, self.dictionary, self.args.dataset_impl, combine=combine ) if dataset is None: raise FileNotFoundError(f"Dataset not found: {split} ({split_path})") dataset = maybe_shorten_dataset( dataset, split, self.args.shorten_data_split_list, self.args.shorten_method, self.args.tokens_per_sample, self.args.seed, ) dataset = TokenBlockDataset( dataset, dataset.sizes, self.args.tokens_per_sample, pad=self.dictionary.pad(), eos=self.dictionary.eos(), break_mode=self.args.sample_break_mode, include_targets=True, use_plasma_view=self.args.use_plasma_view, split_path=split_path, plasma_path=self.args.plasma_path, ) add_eos_for_other_targets = ( self.args.sample_break_mode is not None and self.args.sample_break_mode != "none" ) self.datasets[split] = MonolingualDataset( dataset=dataset, sizes=dataset.sizes, src_vocab=self.dictionary, tgt_vocab=self.output_dictionary, add_eos_for_other_targets=add_eos_for_other_targets, shuffle=True, targets=self.targets, add_bos_token=self.args.add_bos_token, )
def load_all_dictionaries(cls, args, language_list, load_dictionary, training): dicts = OrderedDict() if args.source_dict is not None: dicts[SRC_DICT_NAME] = load_dictionary(args.source_dict) if args.target_dict is not None: dicts[TGT_DICT_NAME] = load_dictionary(args.target_dict) if training: extra_lang_pairs = (list({ p for _, v in args.extra_lang_pairs.items() for p in v.split(",") }) if args.extra_lang_pairs else []) src_langs_to_load_dicts = sorted({ p.split("-")[0] for p in (args.lang_pairs + extra_lang_pairs) }) tgt_langs_to_load_dicts = sorted({ p.split("-")[1] for p in (args.lang_pairs + extra_lang_pairs) }) else: src_langs_to_load_dicts = [args.source_lang] tgt_langs_to_load_dicts = [args.target_lang] paths = utils.split_paths(args.data) assert len(paths) > 0 def load_dicts(langs_to_load_dicts): for lang in langs_to_load_dicts: dicts[lang] = load_dictionary( os.path.join(paths[0], "dict.{}.txt".format(lang))) if len(dicts) > 0: dict0 = next(iter(dicts.values())) assert dicts[lang].pad() == dict0.pad() assert dicts[lang].eos() == dict0.eos() assert dicts[lang].unk() == dict0.unk() logger.info("[{}] dictionary: {} types".format( lang, len(dicts[lang]))) if args.fixed_dictionary is not None: fixed_dict = load_dictionary(args.fixed_dictionary) dicts = { lang: fixed_dict for lang in src_langs_to_load_dicts + tgt_langs_to_load_dicts } else: if args.source_dict is None: load_dicts(src_langs_to_load_dicts) if args.target_dict is None: load_dicts(tgt_langs_to_load_dicts) return dicts
def setup_dictionary(cls, args, **kwargs):
    dictionary = None
    output_dictionary = None
    if args.data:
        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
        logger.info("dictionary: {} types".format(len(dictionary)))
        output_dictionary = dictionary
        if args.output_dictionary_size >= 0:
            output_dictionary = TruncatedDictionary(
                dictionary, args.output_dictionary_size
            )
    return (dictionary, output_dictionary)
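# Illustrative sketch: setup_dictionary above optionally wraps the loaded
# Dictionary in a TruncatedDictionary so the output (softmax) vocabulary can be
# smaller than the input vocabulary, with out-of-range indices resolving to the
# unknown symbol. The import path and the 10,000-type cutoff are assumptions
# made for the example.
from fairseq.data.dictionary import Dictionary, TruncatedDictionary


def _example_truncated_output_dictionary(dict_path):
    dictionary = Dictionary.load(dict_path)                      # full vocabulary for the inputs
    output_dictionary = TruncatedDictionary(dictionary, 10000)   # restricted prediction space
    return dictionary, output_dictionary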
def _load_single_lang_dataset(self, split, epoch): loaded_datasets = [] paths = utils.split_paths(self.args.data) assert len(paths) > 0 data_path = paths[(epoch - 1) % len(paths)] for k in itertools.count(): split_k = split + (str(k) if k > 0 else "") path = os.path.join(data_path, split_k) ds = data_utils.load_indexed_dataset(path, self.dictionary, self.args.dataset_impl) if ds is None: if k > 0: break else: raise FileNotFoundError( "Dataset not found: {} ({})".format(split, data_path)) # Since we append each block with the classification_token, # we need to effectively create blocks of length # tokens_per_sample-1 loaded_datasets.append( TokenBlockDataset( ds, ds.sizes, self.args.tokens_per_sample - 1, pad=self.dictionary.pad(), eos=self.dictionary.eos(), )) logger.info("{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1]))) if len(loaded_datasets) == 1: dataset = loaded_datasets[0] sizes = dataset.sizes else: dataset = ConcatDataset(loaded_datasets) sizes = np.concatenate([ds.sizes for ds in loaded_datasets]) return dataset, sizes
def load_dataset(self, split, epoch=1, **kwargs): """Load a dataset split.""" paths = utils.split_paths(self.args.data) assert len(paths) > 0 data_path = paths[(epoch - 1) % len(paths)] def language_pair_dataset(lang_pair): src, tgt = lang_pair.split("-") langpair_dataset = load_langpair_dataset( data_path, split, src, self.dicts[src], tgt, self.dicts[tgt], combine=True, dataset_impl=self.args.dataset_impl, upsample_primary=self.args.upsample_primary, left_pad_source=self.args.left_pad_source, left_pad_target=self.args.left_pad_target, max_source_positions=self.args.max_source_positions, max_target_positions=self.args.max_target_positions, ) return self.alter_dataset_langtok( langpair_dataset, src_eos=self.dicts[src].eos(), src_lang=src, tgt_eos=self.dicts[tgt].eos(), tgt_lang=tgt, ) self.datasets[split] = RoundRobinZipDatasets( OrderedDict([(lang_pair, language_pair_dataset(lang_pair)) for lang_pair in self.lang_pairs]), eval_key=None if self.training else "%s-%s" % (self.args.source_lang, self.args.target_lang), )
def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ paths = utils.split_paths(self.cfg.data) assert len(paths) > 0 if split != self.cfg.train_subset: # if not training data set, use the first shard for valid and test paths = paths[:1] data_path = paths[(epoch - 1) % len(paths)] # infer langcode src, tgt = self.cfg.source_lang, self.cfg.target_lang self.datasets[split] = load_langpair_dataset( data_path, split, src, self.src_dict, tgt, self.tgt_dict, combine=combine, dataset_impl=self.cfg.dataset_impl, upsample_primary=self.cfg.upsample_primary, left_pad_source=self.cfg.left_pad_source, left_pad_target=self.cfg.left_pad_target, max_source_positions=self.cfg.max_source_positions, max_target_positions=self.cfg.max_target_positions, load_alignments=self.cfg.load_alignments, truncate_source=self.cfg.truncate_source, num_buckets=self.cfg.num_batch_buckets, shuffle=(split != "test"), pad_to_multiple=self.cfg.required_seq_len_multiple, )
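# Illustrative sketch: the shard-selection convention shared by the load_dataset
# implementations in this file. --data may contain several paths (colon-separated
# on POSIX, split by utils.split_paths), and epoch e reads shard (e - 1) % num_shards,
# so training cycles through the shards round-robin. The example path string is an
# assumption.
from fairseq import utils


def _example_shard_for_epoch(data="shard0:shard1:shard2", epoch=1):
    paths = utils.split_paths(data)          # ["shard0", "shard1", "shard2"]
    assert len(paths) > 0
    return paths[(epoch - 1) % len(paths)]   # epoch 1 -> shard0, epoch 4 -> shard0 again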
def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ paths = utils.split_paths(self.args.data) assert len(paths) > 0 data_path = paths[(epoch - 1) % len(paths)] split_path = os.path.join(data_path, split) dataset = data_utils.load_indexed_dataset( split_path, self.dictionary, self.args.dataset_impl, combine=combine, ) if dataset is None: raise FileNotFoundError("Dataset not found: {} ({})".format( split, split_path)) dataset = StripTokenDataset(dataset, self.dictionary.eos()) dataset = maybe_shorten_dataset( dataset, split, self.args.shorten_data_split_list, self.args.shorten_method, self.args.tokens_per_sample, self.args.seed, ) # create continuous blocks of tokens dataset = TokenBlockDataset( dataset, dataset.sizes, self.args.tokens_per_sample - 2, # one less for <s> and one for </s> pad=self.dictionary.pad(), eos=self.dictionary.eos(), break_mode=self.args.sample_break_mode, document_sep_len=0, ) # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT) dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) dataset = AppendTokenDataset(dataset, self.source_dictionary.eos()) mask_whole_words = (get_whole_word_mask(self.args, self.source_dictionary) if self.args.mask_length != "subword" else None) self.datasets[split] = DenoisingDataset( dataset, dataset.sizes, self.dictionary, self.mask_idx, mask_whole_words, shuffle=self.args.shuffle_instance, seed=self.seed, args=self.args, ) logger.info( "Split: {0}, Loaded {1} samples of denoising_dataset".format( split, len(self.datasets[split]), ))
def main(cfg: FairseqConfig): if isinstance(cfg, Namespace): cfg = convert_namespace_to_omegaconf(cfg) start_time = time.time() total_translate_time = 0 utils.import_user_module(cfg.common) if cfg.interactive.buffer_size < 1: cfg.interactive.buffer_size = 1 if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: cfg.dataset.batch_size = 1 assert ( not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam ), "--sampling requires --nbest to be equal to --beam" assert ( not cfg.dataset.batch_size or cfg.dataset.batch_size <= cfg.interactive.buffer_size ), "--batch-size cannot be larger than --buffer-size" logger.info(cfg) # Fix seed for stochastic decoding if cfg.common.seed is not None and not cfg.generation.no_seed_provided: np.random.seed(cfg.common.seed) utils.set_torch_seed(cfg.common.seed) use_cuda = torch.cuda.is_available() and not cfg.common.cpu # Setup task, e.g., translation task = tasks.setup_task(cfg.task) # Load ensemble overrides = ast.literal_eval(cfg.common_eval.model_overrides) logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, _model_args = checkpoint_utils.load_model_ensemble( utils.split_paths(cfg.common_eval.path), arg_overrides=overrides, task=task, suffix=cfg.checkpoint.checkpoint_suffix, strict=(cfg.checkpoint.checkpoint_shard_count == 1), num_shards=cfg.checkpoint.checkpoint_shard_count, ) # Set dictionaries src_dict = task.source_dictionary tgt_dict = task.target_dictionary # Optimize ensemble for generation for model in models: if model is None: continue if cfg.common.fp16: model.half() if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() model.prepare_for_inference_(cfg) # Initialize generator generator = task.build_generator(models, cfg.generation) # Handle tokenization and BPE tokenizer = task.build_tokenizer(cfg.tokenizer) bpe = task.build_bpe(cfg.bpe) def encode_fn(x): if tokenizer is not None: x = tokenizer.encode(x) if bpe is not None: x = bpe.encode(x) return x def decode_fn(x): if bpe is not None: x = bpe.decode(x) if tokenizer is not None: x = tokenizer.decode(x) return x # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) align_dict = utils.load_align_dict(cfg.generation.replace_unk) max_positions = utils.resolve_max_positions( task.max_positions(), *[model.max_positions() for model in models] ) if cfg.generation.constraints: logger.warning( "NOTE: Constrained decoding currently assumes a shared subword vocabulary." 
) if cfg.interactive.buffer_size > 1: logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size) logger.info("NOTE: hypothesis and token scores are output in base 2") logger.info("Type the input sentence and press return:") start_id = 0 for inputs in buffered_read(cfg.interactive.input, cfg.interactive.buffer_size): results = [] for batch in make_batches(inputs, cfg, task, max_positions, encode_fn): bsz = batch.src_tokens.size(0) src_tokens = batch.src_tokens src_lengths = batch.src_lengths constraints = batch.constraints if use_cuda: src_tokens = src_tokens.cuda() src_lengths = src_lengths.cuda() if constraints is not None: constraints = constraints.cuda() sample = { "net_input": { "src_tokens": src_tokens, "src_lengths": src_lengths, }, } translate_start_time = time.time() translations = task.inference_step( generator, models, sample, constraints=constraints ) translate_time = time.time() - translate_start_time total_translate_time += translate_time list_constraints = [[] for _ in range(bsz)] if cfg.generation.constraints: list_constraints = [unpack_constraints(c) for c in constraints] for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)): src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad()) constraints = list_constraints[i] results.append( ( start_id + id, src_tokens_i, hypos, { "constraints": constraints, "time": translate_time / len(translations), }, ) ) # sort output to match input order for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]): src_str = '' if src_dict is not None: src_str = src_dict.string(src_tokens, cfg.common_eval.post_process) print("S-{}\t{}".format(id_, src_str)) print("W-{}\t{:.3f}\tseconds".format(id_, info["time"])) for constraint in info["constraints"]: print( "C-{}\t{}".format( id_, tgt_dict.string(constraint, cfg.common_eval.post_process) ) ) # Process top predictions for hypo in hypos[: min(len(hypos), cfg.generation.nbest)]: hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo["tokens"].int().cpu(), src_str=src_str, alignment=hypo["alignment"], align_dict=align_dict, tgt_dict=tgt_dict, remove_bpe=cfg.common_eval.post_process, extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) detok_hypo_str = decode_fn(hypo_str) score = hypo["score"] / math.log(2) # convert to base 2 # original hypothesis (after tokenization and BPE) print("H-{}\t{}\t{}".format(id_, score, hypo_str)) # detokenized hypothesis print("D-{}\t{}\t{}".format(id_, score, detok_hypo_str)) print( "P-{}\t{}".format( id_, " ".join( map( lambda x: "{:.4f}".format(x), # convert from base e to base 2 hypo["positional_scores"].div_(math.log(2)).tolist(), ) ), ) ) if cfg.generation.print_alignment: alignment_str = " ".join( ["{}-{}".format(src, tgt) for src, tgt in alignment] ) print("A-{}\t{}".format(id_, alignment_str)) # update running id_ counter start_id += len(inputs) logger.info( "Total time: {:.3f} seconds; translation time: {:.3f}".format( time.time() - start_time, total_translate_time ) )
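# Illustrative sketch: the interactive loop above funnels raw text through
# encode_fn (tokenizer.encode, then bpe.encode) before batching, and hypotheses
# back through decode_fn (bpe.decode, then tokenizer.decode). A minimal
# round-trip, assuming both a tokenizer and a BPE object were built by the task
# (either may be None in practice, in which case that step is skipped above):
def _example_roundtrip(tokenizer, bpe, line):
    encoded = bpe.encode(tokenizer.encode(line))      # what make_batches consumes
    decoded = tokenizer.decode(bpe.decode(encoded))   # what is printed as D-*
    return encoded, decoded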
def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) """ paths = utils.split_paths(self.args.data) assert len(paths) > 0 data_path = paths[(epoch - 1) % len(paths)] split_path = os.path.join(data_path, split) dataset = data_utils.load_indexed_dataset( split_path, self.source_dictionary, self.args.dataset_impl, combine=combine, ) if dataset is None: raise FileNotFoundError("Dataset not found: {} ({})".format( split, split_path)) dataset = maybe_shorten_dataset( dataset, split, self.args.shorten_data_split_list, self.args.shorten_method, self.args.tokens_per_sample, self.args.seed, ) # create continuous blocks of tokens dataset = TokenBlockDataset( dataset, dataset.sizes, self.args.tokens_per_sample - 1, # one less for <s> pad=self.source_dictionary.pad(), eos=self.source_dictionary.eos(), break_mode=self.args.sample_break_mode, ) logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT) dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) # create masked input and targets mask_whole_words = (get_whole_word_mask(self.args, self.source_dictionary) if self.args.mask_whole_words else None) src_dataset, tgt_dataset = MaskTokensDataset.apply_mask( dataset, self.source_dictionary, pad_idx=self.source_dictionary.pad(), mask_idx=self.mask_idx, seed=self.args.seed, mask_prob=self.args.mask_prob, leave_unmasked_prob=self.args.leave_unmasked_prob, random_token_prob=self.args.random_token_prob, freq_weighted_replacement=self.args.freq_weighted_replacement, mask_whole_words=mask_whole_words, mask_multiple_length=self.args.mask_multiple_length, mask_stdev=self.args.mask_stdev, ) with data_utils.numpy_seed(self.args.seed): shuffle = np.random.permutation(len(src_dataset)) self.datasets[split] = SortDataset( NestedDictionaryDataset( { "id": IdDataset(), "net_input": { "src_tokens": RightPadDataset( src_dataset, pad_idx=self.source_dictionary.pad(), ), "src_lengths": NumelDataset(src_dataset, reduce=False), }, "target": RightPadDataset( tgt_dataset, pad_idx=self.source_dictionary.pad(), ), "nsentences": NumSamplesDataset(), "ntokens": NumelDataset(src_dataset, reduce=True), }, sizes=[src_dataset.sizes], ), sort_order=[ shuffle, src_dataset.sizes, ], )
def load_dataset(self, split, epoch=1, combine=False):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    loaded_datasets = []

    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]
    logger.info("data_path %s", data_path)

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")
        path = os.path.join(data_path, split_k)
        ds = indexed_dataset.make_dataset(
            path,
            impl=self.args.dataset_impl,
            fix_lua_indexing=True,
            dictionary=self.dictionary,
        )

        if ds is None:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        with data_utils.numpy_seed(self.seed + k):
            loaded_datasets.append(
                BlockPairDataset(
                    ds,
                    self.dictionary,
                    ds.sizes,
                    self.args.tokens_per_sample,
                    break_mode=self.args.break_mode,
                    doc_break_size=1,
                )
            )

        logger.info(
            "{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1]))
        )

        if not combine:
            break

    if len(loaded_datasets) == 1:
        dataset = loaded_datasets[0]
        sizes = dataset.sizes
    else:
        dataset = ConcatDataset(loaded_datasets)
        sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

    self.datasets[split] = MaskedLMDataset(
        dataset=dataset,
        sizes=sizes,
        vocab=self.dictionary,
        pad_idx=self.dictionary.pad(),
        mask_idx=self.dictionary.mask(),
        classif_token_idx=self.dictionary.cls(),
        sep_token_idx=self.dictionary.sep(),
        shuffle=self.args.shuffle_dataset,
        seed=self.seed,
    )
def load_dataset(self, split, epoch=1, **kwargs): """Load a dataset split.""" paths = utils.split_paths(self.args.data) assert len(paths) > 0 data_path = paths[(epoch - 1) % len(paths)] def split_exists(split, src, tgt, lang): if src is not None: filename = os.path.join( data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)) else: filename = os.path.join( data_path, "{}.{}-None.{}".format(split, src, tgt)) return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl) def load_indexed_dataset(path, dictionary): return data_utils.load_indexed_dataset(path, dictionary, self.args.dataset_impl) # load parallel datasets src_datasets, tgt_datasets = {}, {} if (self.lambda_parallel > 0.0 or self.lambda_parallel_steps is not None or not split.startswith("train")): for lang_pair in self.lang_pairs: src, tgt = lang_pair.split("-") if split_exists(split, src, tgt, src): prefix = os.path.join(data_path, "{}.{}-{}.".format(split, src, tgt)) elif split_exists(split, tgt, src, src): prefix = os.path.join(data_path, "{}.{}-{}.".format(split, tgt, src)) else: continue src_datasets[lang_pair] = load_indexed_dataset( prefix + src, self.dicts[src]) tgt_datasets[lang_pair] = load_indexed_dataset( prefix + tgt, self.dicts[tgt]) logger.info("parallel-{} {} {} examples".format( data_path, split, len(src_datasets[lang_pair]))) if len(src_datasets) == 0: raise FileNotFoundError("Dataset not found: {} ({})".format( split, data_path)) # back translation datasets backtranslate_datasets = {} if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) and split.startswith("train"): for lang_pair in self.lang_pairs: src, tgt = lang_pair.split("-") if not split_exists(split, tgt, None, tgt): raise FileNotFoundError( "Dataset not found: backtranslation {} ({})".format( split, data_path)) filename = os.path.join( data_path, "{}.{}-None.{}".format(split, tgt, tgt)) dataset = load_indexed_dataset(filename, self.dicts[tgt]) lang_pair_dataset_tgt = LanguagePairDataset( dataset, dataset.sizes, self.dicts[tgt], left_pad_source=self.args.left_pad_source, left_pad_target=self.args.left_pad_target, ) lang_pair_dataset = LanguagePairDataset( dataset, dataset.sizes, src_dict=self.dicts[src], tgt=dataset, tgt_sizes=dataset.sizes, tgt_dict=self.dicts[tgt], left_pad_source=self.args.left_pad_source, left_pad_target=self.args.left_pad_target, ) backtranslate_datasets[lang_pair] = BacktranslationDataset( tgt_dataset=self.alter_dataset_langtok( lang_pair_dataset_tgt, src_eos=self.dicts[tgt].eos(), src_lang=tgt, tgt_lang=src, ), backtranslation_fn=self.backtranslators[lang_pair], src_dict=self.dicts[src], tgt_dict=self.dicts[tgt], output_collater=self.alter_dataset_langtok( lang_pair_dataset=lang_pair_dataset, src_eos=self.dicts[src].eos(), src_lang=src, tgt_eos=self.dicts[tgt].eos(), tgt_lang=tgt, ).collater, ) logger.info("backtranslate-{}: {} {} {} examples".format( tgt, data_path, split, len(backtranslate_datasets[lang_pair]), )) self.backtranslate_datasets[ lang_pair] = backtranslate_datasets[lang_pair] # denoising autoencoder noising_datasets = {} if (self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None) and split.startswith("train"): for lang_pair in self.lang_pairs: _, tgt = lang_pair.split("-") if not split_exists(split, tgt, None, tgt): continue filename = os.path.join( data_path, "{}.{}-None.{}".format(split, tgt, tgt)) tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt]) tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt]) noising_dataset = NoisingDataset( 
tgt_dataset1, self.dicts[tgt], seed=1, max_word_shuffle_distance=self.args. max_word_shuffle_distance, word_dropout_prob=self.args.word_dropout_prob, word_blanking_prob=self.args.word_blanking_prob, ) noising_datasets[lang_pair] = self.alter_dataset_langtok( LanguagePairDataset( noising_dataset, tgt_dataset1.sizes, self.dicts[tgt], tgt_dataset2, tgt_dataset2.sizes, self.dicts[tgt], left_pad_source=self.args.left_pad_source, left_pad_target=self.args.left_pad_target, ), src_eos=self.dicts[tgt].eos(), src_lang=tgt, tgt_eos=self.dicts[tgt].eos(), tgt_lang=tgt, ) logger.info("denoising-{}: {} {} {} examples".format( tgt, data_path, split, len(noising_datasets[lang_pair]), )) def language_pair_dataset(lang_pair): src, tgt = lang_pair.split("-") src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[ lang_pair] return self.alter_dataset_langtok( LanguagePairDataset( src_dataset, src_dataset.sizes, self.dicts[src], tgt_dataset, tgt_dataset.sizes, self.dicts[tgt], left_pad_source=self.args.left_pad_source, left_pad_target=self.args.left_pad_target, ), self.dicts[src].eos(), src, self.dicts[tgt].eos(), tgt, ) self.datasets[split] = RoundRobinZipDatasets( OrderedDict( [(lang_pair, language_pair_dataset(lang_pair)) for lang_pair in src_datasets.keys()] + [(_get_bt_dataset_key(lang_pair), dataset) for lang_pair, dataset in backtranslate_datasets.items()] + [(_get_denoising_dataset_key(lang_pair), dataset) for lang_pair, dataset in noising_datasets.items()]), eval_key=None if self.training else "%s-%s" % (self.args.source_lang, self.args.target_lang), )
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    data_path = paths[(epoch - 1) % len(paths)]

    languages = sorted(
        name
        for name in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, name))
    )

    logger.info("Training on {0} languages: {1}".format(len(languages), languages))
    logger.info(
        "Language to id mapping: %s",
        {lang: id for id, lang in enumerate(languages)},
    )

    mask_whole_words = self._get_whole_word_mask()
    lang_datasets = []
    for lang_id, language in enumerate(languages):
        split_path = os.path.join(data_path, language, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )
        logger.info("loaded {} blocks from: {}".format(len(dataset), split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        lang_dataset = NestedDictionaryDataset(
            {
                "net_input": {
                    "src_tokens": PadDataset(
                        src_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    "src_lengths": NumelDataset(src_dataset, reduce=False),
                },
                "target": PadDataset(
                    tgt_dataset,
                    pad_idx=self.source_dictionary.pad(),
                    left_pad=False,
                ),
                "nsentences": NumSamplesDataset(),
                "ntokens": NumelDataset(src_dataset, reduce=True),
                "lang_id": RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]),
            },
            sizes=[src_dataset.sizes],
        )
        lang_datasets.append(lang_dataset)

    dataset_lengths = np.array(
        [len(d) for d in lang_datasets],
        dtype=float,
    )
    logger.info(
        "loaded total {} blocks for all languages".format(dataset_lengths.sum())
    )
    if split == self.args.train_subset:
        # For train subset, additionally up or down sample languages.
        sample_probs = self._get_sample_prob(dataset_lengths)
        logger.info(
            "Sample probability by language: %s",
            {
                lang: "{0:.4f}".format(sample_probs[id])
                for id, lang in enumerate(languages)
            },
        )
        size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
        logger.info(
            "Up/Down Sampling ratio by language: %s",
            {
                lang: "{0:.2f}".format(size_ratio[id])
                for id, lang in enumerate(languages)
            },
        )

        resampled_lang_datasets = [
            ResamplingDataset(
                lang_datasets[i],
                size_ratio=size_ratio[i],
                seed=self.args.seed,
                epoch=epoch,
                replace=size_ratio[i] >= 1.0,
            )
            for i, d in enumerate(lang_datasets)
        ]
        dataset = ConcatDataset(resampled_lang_datasets)
    else:
        dataset = ConcatDataset(lang_datasets)
        lang_splits = [split]
        for lang_id, lang_dataset in enumerate(lang_datasets):
            split_name = split + "_" + languages[lang_id]
            lang_splits.append(split_name)
            self.datasets[split_name] = lang_dataset

        # [TODO]: This is hacky for now to print validation ppl for each
        # language individually. Maybe need task API changes to allow it
        # in more generic ways.
        if split in self.args.valid_subset:
            self.args.valid_subset = self.args.valid_subset.replace(
                split, ",".join(lang_splits)
            )

    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.datasets[split] = SortDataset(
        dataset,
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )
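# Illustrative sketch: _get_sample_prob, called above, is not shown in this
# excerpt. A common choice (and what fairseq's stock multilingual masked-LM task
# does, via args.multilang_sampling_alpha) is to smooth the empirical language
# distribution with an exponent alpha; alpha < 1 upsamples low-resource
# languages. The default alpha below is an assumption for the example.
import numpy as np


def _example_sample_prob(dataset_lengths, alpha=0.7):
    prob = dataset_lengths / dataset_lengths.sum()   # empirical share per language
    smoothed = prob ** alpha
    return smoothed / smoothed.sum()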
def _main(cfg: DictConfig, output_file): logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=os.environ.get("LOGLEVEL", "INFO").upper(), stream=output_file, ) logger = logging.getLogger("fairseq_cli.generate") utils.import_user_module(cfg.common) if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: cfg.dataset.max_tokens = 12000 logger.info(cfg) # Fix seed for stochastic decoding if cfg.common.seed is not None and not cfg.generation.no_seed_provided: np.random.seed(cfg.common.seed) utils.set_torch_seed(cfg.common.seed) use_cuda = torch.cuda.is_available() and not cfg.common.cpu # Load dataset splits task = tasks.setup_task(cfg.task) # Set dictionaries try: src_dict = getattr(task, "source_dictionary", None) except NotImplementedError: src_dict = None tgt_dict = task.target_dictionary overrides = ast.literal_eval(cfg.common_eval.model_overrides) # START ba_st_ch code eval_file_name, checkpoint = get_evaluation_file_name( str(cfg.common_eval.path)) print(eval_file_name) eval_manifest = {c: [] for c in MANIFEST_COLUMNS} eval_file_name = resource_path + eval_file_name # END ba_st_ch code # Load ensemble logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, saved_cfg = checkpoint_utils.load_model_ensemble( utils.split_paths(cfg.common_eval.path), arg_overrides=overrides, task=task, suffix=cfg.checkpoint.checkpoint_suffix, strict=(cfg.checkpoint.checkpoint_shard_count == 1), num_shards=cfg.checkpoint.checkpoint_shard_count, ) # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task) if cfg.generation.lm_path is not None: overrides["data"] = cfg.task.data try: lms, _ = checkpoint_utils.load_model_ensemble( [cfg.generation.lm_path], arg_overrides=overrides, task=None) except: logger.warning( f"Failed to load language model! 
Please make sure that the language model dict is the same " f"as target dict and is located in the data dir ({cfg.task.data})" ) raise assert len(lms) == 1 else: lms = [None] # Optimize ensemble for generation for model in chain(models, lms): if model is None: continue if cfg.common.fp16: model.half() if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() model.prepare_for_inference_(cfg) # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) align_dict = utils.load_align_dict(cfg.generation.replace_unk) # Load dataset (possibly sharded) itr = task.get_batch_iterator( dataset=task.dataset(cfg.dataset.gen_subset), max_tokens=cfg.dataset.max_tokens, max_sentences=cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( task.max_positions(), *[m.max_positions() for m in models]), ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, seed=cfg.common.seed, num_shards=cfg.distributed_training.distributed_world_size, shard_id=cfg.distributed_training.distributed_rank, num_workers=cfg.dataset.num_workers, data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, log_format=cfg.common.log_format, log_interval=cfg.common.log_interval, default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), ) # Initialize generator gen_timer = StopwatchMeter() extra_gen_cls_kwargs = { "lm_model": lms[0], "lm_weight": cfg.generation.lm_weight } generator = task.build_generator(models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs) # Handle tokenization and BPE tokenizer = task.build_tokenizer(cfg.tokenizer) bpe = task.build_bpe(cfg.bpe) def decode_fn(x): if bpe is not None: x = bpe.decode(x) if tokenizer is not None: x = tokenizer.decode(x) return x scorer = scoring.build_scorer(cfg.scoring, tgt_dict) num_sentences = 0 has_target = True wps_meter = TimeMeter() for sample in progress: sample = utils.move_to_cuda(sample) if use_cuda else sample if "net_input" not in sample: continue prefix_tokens = None if cfg.generation.prefix_size > 0: prefix_tokens = sample["target"][:, :cfg.generation.prefix_size] constraints = None if "constraints" in sample: constraints = sample["constraints"] gen_timer.start() hypos = task.inference_step( generator, models, sample, prefix_tokens=prefix_tokens, constraints=constraints, ) num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) gen_timer.stop(num_generated_tokens) for i, sample_id in enumerate(sample["id"].tolist()): has_target = sample["target"] is not None # Remove padding if "src_tokens" in sample["net_input"]: src_tokens = utils.strip_pad( sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()) else: src_tokens = None target_tokens = None if has_target: target_tokens = (utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()) # Either retrieve the original sentences or regenerate them from tokens. 
if align_dict is not None: src_str = task.dataset( cfg.dataset.gen_subset).src.get_original_text(sample_id) target_str = task.dataset( cfg.dataset.gen_subset).tgt.get_original_text(sample_id) else: if src_dict is not None: src_str = src_dict.string(src_tokens, cfg.common_eval.post_process) else: src_str = "" if has_target: target_str = tgt_dict.string( target_tokens, cfg.common_eval.post_process, escape_unk=True, extra_symbols_to_ignore= get_symbols_to_strip_from_output(generator), ) # START ba_st_ch code print("-------------") test_dataset = corpus_path + str(cfg.dataset.gen_subset) + ".tsv" df = pd.read_table(test_dataset) id = df.loc[[sample_id], ["id"]] print(id) id = str(id).partition('\n')[2] print("id: ", id) print("-------------") # END ba_st_ch code src_str = decode_fn(src_str) if has_target: target_str = decode_fn(target_str) if not cfg.common_eval.quiet: if src_dict is not None: print("S-{}\t{}".format(sample_id, src_str), file=output_file) if has_target: print("T-{}\t{}".format(sample_id, target_str), file=output_file) # Process top predictions for j, hypo in enumerate(hypos[i][:cfg.generation.nbest]): hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo["tokens"].int().cpu(), src_str=src_str, alignment=hypo["alignment"], align_dict=align_dict, tgt_dict=tgt_dict, remove_bpe=cfg.common_eval.post_process, extra_symbols_to_ignore=get_symbols_to_strip_from_output( generator), ) detok_hypo_str = decode_fn(hypo_str) # START ba_st_ch code eval_manifest["id"].append(id) eval_manifest["target_txt"].append(target_str) eval_manifest["prediction"].append(detok_hypo_str) # END ba_st_ch code if not cfg.common_eval.quiet: score = hypo["score"] / math.log(2) # convert to base 2 # original hypothesis (after tokenization and BPE) print( "H-{}\t{}\t{}".format(sample_id, score, hypo_str), file=output_file, ) # detokenized hypothesis print( "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str), file=output_file, ) print( "P-{}\t{}".format( sample_id, " ".join( map( lambda x: "{:.4f}".format(x), # convert from base e to base 2 hypo["positional_scores"].div_(math.log(2) ).tolist(), )), ), file=output_file, ) if cfg.generation.print_alignment == "hard": print( "A-{}\t{}".format( sample_id, " ".join([ "{}-{}".format(src_idx, tgt_idx) for src_idx, tgt_idx in alignment ]), ), file=output_file, ) if cfg.generation.print_alignment == "soft": print( "A-{}\t{}".format( sample_id, " ".join([ ",".join(src_probs) for src_probs in alignment ]), ), file=output_file, ) if cfg.generation.print_step: print( "I-{}\t{}".format(sample_id, hypo["steps"]), file=output_file, ) if cfg.generation.retain_iter_history: for step, h in enumerate(hypo["history"]): _, h_str, _ = utils.post_process_prediction( hypo_tokens=h["tokens"].int().cpu(), src_str=src_str, alignment=None, align_dict=None, tgt_dict=tgt_dict, remove_bpe=None, ) print( "E-{}_{}\t{}".format(sample_id, step, h_str), file=output_file, ) # Score only the top hypothesis if has_target and j == 0: if align_dict is not None or cfg.common_eval.post_process is not None: # Convert back to tokens for evaluation with unk replacement and/or without BPE target_tokens = tgt_dict.encode_line( target_str, add_if_not_exist=True) hypo_tokens = tgt_dict.encode_line( detok_hypo_str, add_if_not_exist=True) if hasattr(scorer, "add_string"): scorer.add_string(target_str, detok_hypo_str) else: scorer.add(target_tokens, hypo_tokens) wps_meter.update(num_generated_tokens) progress.log({"wps": round(wps_meter.avg)}) num_sentences += (sample["nsentences"] if 
"nsentences" in sample else sample["id"].numel()) logger.info("NOTE: hypothesis and token scores are output in base 2") logger.info( "Translated {:,} sentences ({:,} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)" .format( num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1.0 / gen_timer.avg, )) if has_target: if cfg.bpe and not cfg.generation.sacrebleu: if cfg.common_eval.post_process: logger.warning( "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization" ) else: logger.warning( "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization" ) # use print to be consistent with other main outputs: S-, H-, T-, D- and so on print( "Generate {} with beam={}: {}".format(cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string()), file=output_file, ) # START ba_st_ch code bleu_score = scorer.result_string().rsplit(' = ')[1].split(" ")[0] # print(bleu_score) # print(float(bleu_score) / 2) f.write(checkpoint + ";" + bleu_score + '\n') generate_manifest(eval_manifest, eval_file_name) # END ba_st_ch code return scorer