def indexed_dataset(path, dictionary):
    if self.args.raw_text:
        return IndexedRawTextDataset(path, dictionary)
    elif IndexedDataset.exists(path):
        if self.args.lazy_load:
            return IndexedDataset(path, fix_lua_indexing=True)
        else:
            return IndexedCachedDataset(path, fix_lua_indexing=True)
    return None
def indexed_dataset(path, dictionary, copy_ext_dict=False, src_dataset=None):
    if self.args.raw_text:
        return IndexedRawTextDataset(path, dictionary,
                                     copy_ext_dict=copy_ext_dict,
                                     src_dataset=src_dataset)
    elif IndexedDataset.exists(path):
        if self.args.lazy_load:
            return IndexedDataset(path, fix_lua_indexing=True)
        else:
            return IndexedCachedDataset(path, fix_lua_indexing=True)
    return None
def split_exists(split, src, tgt, lang):
    filename = os.path.join(
        self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
    if self.args.raw_text and IndexedRawTextDataset.exists(filename):
        return True
    elif not self.args.raw_text and IndexedCachedDataset.exists(filename):
        return True
    return False
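# A hedged sketch of how `split_exists` and an `indexed_dataset` helper are
# typically combined to locate a parallel split, in the spirit of fairseq's
# translation task. `find_langpair_prefix` and its arguments are hypothetical
# names introduced for illustration, not part of the snippets above; `os` and
# `IndexedDataset` are assumed to be in scope as elsewhere in this file.
def find_langpair_prefix(data_dir, split, src, tgt):
    """Return the file prefix for a split, trying both direction namings."""
    def exists(s, t):
        filename = os.path.join(data_dir, '{}.{}-{}.{}'.format(split, s, t, src))
        return IndexedDataset.exists(filename)

    if exists(src, tgt):
        return os.path.join(data_dir, '{}.{}-{}.'.format(split, src, tgt))
    elif exists(tgt, src):
        return os.path.join(data_dir, '{}.{}-{}.'.format(split, tgt, src))
    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_dir))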
def indexed_dataset(path, dictionary, ex_dict=None, is_tgt=False):
    if self.args.segment:
        # if self.args.raw_text:
        return IndexedRawTextSegDataset(path, dictionary, ex_dict, is_tgt)
    else:
        if self.args.raw_text:
            return IndexedRawTextDataset(path, dictionary)
        elif IndexedDataset.exists(path):
            return IndexedCachedDataset(path, fix_lua_indexing=True)
    return None
def indexed_dataset(path):
    # The earlier raw-text/lazy-load branches were removed in this variant;
    # the binary dataset is now required to exist and is always cached.
    assert IndexedDataset.exists(path), f'IndexedDataset.exists({path})'
    return IndexedCachedDataset(path, fix_lua_indexing=True)
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    loaded_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        path = os.path.join(self.args.data, split_k)

        if self.args.raw_text and IndexedRawTextDataset.exists(path):
            ds = IndexedRawTextDataset(path, self.dictionary)
        elif not self.args.raw_text and IndexedDataset.exists(path):
            if self.args.lazy_load:
                ds = IndexedDataset(path, fix_lua_indexing=True)
            else:
                ds = IndexedCachedDataset(path, fix_lua_indexing=True)
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))

        loaded_datasets.append(
            TokenBlockDataset(
                ds, ds.sizes, self.args.tokens_per_sample,
                pad=self.dictionary.pad(), eos=self.dictionary.eos(),
                break_mode=self.args.sample_break_mode, include_targets=True,
            )
        )

        print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))

        if not combine:
            break

    if len(loaded_datasets) == 1:
        dataset = loaded_datasets[0]
        sizes = dataset.sizes
    else:
        dataset = ConcatDataset(loaded_datasets)
        sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

    add_eos_for_other_targets = self.args.sample_break_mode is not None \
        and self.args.sample_break_mode != 'none'

    self.datasets[split] = MonolingualDataset(
        dataset, sizes, self.dictionary, self.output_dictionary,
        add_eos_for_other_targets=add_eos_for_other_targets,
        shuffle=True, targets=self.targets,
    )
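# Hypothetical call site for the language-modeling `load_dataset` above,
# assuming a task produced by fairseq's `tasks.setup_task(args)`; with
# combine=True the itertools.count() loop also picks up sharded splits
# named 'train1', 'train2', ...:
#
#     task = tasks.setup_task(args)
#     task.load_dataset('train', combine=True)
#     print(len(task.dataset('train')))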
def load_dataset(self, split, combine=False):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    dataset_map = OrderedDict()

    for lang in self.langs2id.keys():
        if self.default_key is None:
            self.default_key = lang

        # Datasets are expected to be in "split.lang" format (e.g., train.en)
        language_split = '{}.{}'.format(split, lang)
        path = os.path.join(self.args.data, language_split)

        if self.args.raw_text and IndexedRawTextDataset.exists(path):
            ds = IndexedRawTextDataset(path, self.dictionary)
        elif not self.args.raw_text and IndexedDataset.exists(path):
            if self.args.lazy_load:
                ds = IndexedDataset(path, fix_lua_indexing=True)
            else:
                ds = IndexedCachedDataset(path, fix_lua_indexing=True)
        else:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                language_split, self.args.data))

        # Since we append each block with the classification_token,
        # we need to effectively create blocks of length tokens_per_sample - 1
        block_dataset = TokenBlockDataset(
            dataset=ds,
            sizes=ds.sizes,
            block_size=self.args.tokens_per_sample - 1,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos())

        dataset_map[lang] = MaskedLMDataset(
            dataset=block_dataset,
            sizes=block_dataset.sizes,
            vocab=self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.dictionary.mask(),
            classif_token_idx=self.dictionary.eos(),
            sep_token_idx=self.dictionary.eos(),
            shuffle=getattr(self.args, 'shuffle', False),
            has_pairs=False,
            segment_id=self.langs2id[lang],
            seed=self.seed,
        )

    self.datasets[split] = MultiCorpusSampledDataset(
        dataset_map, default_key=self.default_key)
    print('| {} {} {} examples'.format(
        self.args.data, split, len(self.datasets[split])))
def indexed_dataset(path, dictionary, cached=True, audio=False):
    if self.args.raw_text:
        return IndexedRawTextDataset(path, dictionary)
    elif IndexedDataset.exists(path):
        if cached:
            return IndexedCachedDataset(path, fix_lua_indexing=True, audio=audio)
        else:
            return IndexedDataset(path, fix_lua_indexing=True, audio=audio)
    return None
def _load_single_lang_dataset(self, split):
    loaded_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        path = os.path.join(self.args.data, split_k)

        if self.args.raw_text and IndexedRawTextDataset.exists(path):
            ds = IndexedRawTextDataset(path, self.dictionary)
        elif not self.args.raw_text and IndexedDataset.exists(path):
            if self.args.lazy_load:
                ds = IndexedDataset(path, fix_lua_indexing=True)
            else:
                ds = IndexedCachedDataset(path, fix_lua_indexing=True)
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    'Dataset not found: {} ({})'.format(split, self.args.data))

        # Since we append each block with the classification_token,
        # we need to effectively create blocks of length tokens_per_sample - 1
        loaded_datasets.append(
            TokenBlockDataset(
                ds, ds.sizes, self.args.tokens_per_sample - 1,
                pad=self.dictionary.pad(), eos=self.dictionary.eos(),
            ))

        print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))

    if len(loaded_datasets) == 1:
        dataset = loaded_datasets[0]
        sizes = dataset.sizes
    else:
        dataset = ConcatDataset(loaded_datasets)
        sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

    return dataset, sizes
def indexed_dataset(path, dictionary):
    # `dictionary` is unused in this variant; the parameter is kept so the
    # helper matches the signature used by the other variants above.
    return IndexedCachedDataset(path, fix_lua_indexing=True)
def load_dataset(self, split, combine=False):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    loaded_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        path = os.path.join(self.args.data, split_k)

        if self.args.raw_text and IndexedRawTextDataset.exists(path):
            ds = IndexedRawTextDataset(path, self.dictionary)
        elif not self.args.raw_text and IndexedDataset.exists(path):
            if self.args.lazy_load:
                ds = IndexedDataset(path, fix_lua_indexing=True)
            else:
                ds = IndexedCachedDataset(path, fix_lua_indexing=True)
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    'Dataset not found: {} ({})'.format(split, self.args.data))

        with data_utils.numpy_seed(self.seed + k):
            loaded_datasets.append(
                BlockPairDataset(
                    ds,
                    self.dictionary,
                    ds.sizes,
                    self.args.tokens_per_sample,
                    break_mode=self.args.break_mode,
                ))

        logger.info('{} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))

        if not combine:
            break

    if len(loaded_datasets) == 1:
        dataset = loaded_datasets[0]
        sizes = dataset.sizes
    else:
        dataset = ConcatDataset(loaded_datasets)
        sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

    self.datasets[split] = MaskedLMDataset(
        dataset=dataset,
        sizes=sizes,
        vocab=self.dictionary,
        pad_idx=self.dictionary.pad(),
        mask_idx=self.dictionary.mask(),
        classif_token_idx=self.dictionary.cls(),
        sep_token_idx=self.dictionary.sep(),
        shuffle=False,
        seed=self.seed,
    )
def indexed_dataset(path):
    assert IndexedCachedDataset.exists(path), f'IndexedCachedDataset.exists({path})'
    return IndexedCachedDataset(path, fix_lua_indexing=True)
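# A usage sketch contrasting the two binary-backed dataset classes used by
# the helpers above, following fairseq's data API: IndexedDataset reads items
# from disk on access, while IndexedCachedDataset can additionally prefetch a
# set of indices into memory (as the generation loop below does with
# `original_target_dataset.prefetch([sample_id])`).
#
#     ds = IndexedCachedDataset(path, fix_lua_indexing=True)
#     ds.prefetch(range(len(ds)))  # warm the in-memory cache (cached class only)
#     first_item = ds[0]           # a LongTensor of token ids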
def main(args, checkpoint_name="best"): assert args.path is not None, '--path required for generation!' assert not args.sampling or args.nbest == args.beam, \ '--sampling requires --nbest to be equal to --beam' assert args.replace_unk is None or args.raw_text, \ '--replace-unk requires a raw text dataset (--raw-text)' if args.max_tokens is None and args.max_sentences is None: args.max_tokens = 12000 print(args) use_cuda = torch.cuda.is_available() and not args.cpu torch.manual_seed(args.seed) # Load dataset splits task = tasks.setup_task(args) task.load_dataset(args.gen_subset) print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset)))) args.taskobj = task # Set dictionaries #src_dict = task.source_dictionary tgt_dict = task.target_dictionary dict = tgt_dict # Load decoding strategy strategy = strategies.setup_strategy(args) # Load ensemble if args.path.startswith("nsml://"): print("| loading nsml checkpoint", args.path) import nsml session = args.path.replace("nsml://", "") model = task.build_model(args) def load(dir_path): state = torch.load(os.path.join(dir_path, 'best.pt')) state_dict = state["model"] model.load_state_dict(state_dict) print("loaded") nsml.load(args.checkpoint_name, load_fn=load, session=session) models = [model.cuda()] elif args.path == "pretrain": from nsml import DATASET_PATH from fairseq import checkpoint_utils data_token = "en-de" pretrained_path = "{}/train/pretrained_models/maskPredict_{}/checkpoint_best.pt".format( DATASET_PATH, data_token.split(".")[-1].replace("-", "_")) print("| loading", pretrained_path) model = task.build_model(args) state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_path) model.load_state_dict(state["model"], strict=True) models = [model.cuda()] elif args.path.startswith("wb://"): print("| loading wb checkpoint", args.path) import wandb wandb.restore("best.pt", args.path.replace("wb://", ""), root="/tmp/") assert os.path.exists("/tmp/best.pt") state = torch.load("/tmp/best.pt") model = task.build_model(args) model.load_state_dict(state["model"]) models = [model.cuda()] elif args.path.startswith("http://"): print("| loading http checkpoint", args.path) url = "http://trains.deeplearn.org:8081/{}".format( args.path.replace("http://", "")) os.system("curl -o /tmp/model.pt {}".format(url)) state = torch.load("/tmp/model.pt") model = task.build_model(args) model.load_state_dict(state["model"]) models = [model.cuda()] else: print('| loading model(s) from {}'.format(args.path)) models, _ = utils.load_ensemble_for_inference( args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides)) models = [model.cuda() for model in models] original_target_dataset = None assert args.original_target if args.original_target: original_target_dataset = IndexedCachedDataset(args.original_target, fix_lua_indexing=True) # Optimize ensemble for generation for model in models: model.make_generation_fast_( beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, need_attn=args.print_alignment, ) if args.fp16: model.half() # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) align_dict = utils.load_align_dict(args.replace_unk) # Load dataset (possibly sharded) itr = task.get_batch_iterator( dataset=task.dataset(args.gen_subset), max_tokens=args.max_tokens, max_sentences=args.max_sentences, max_positions=utils.resolve_max_positions( task.max_positions(), *[model.max_positions() for model in models]), 
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=8, num_shards=args.num_shards, shard_id=args.shard_id, ).next_epoch_itr(shuffle=False) results = [] scorer = pybleu.PyBleuScorer() num_sentences = 0 has_target = True timer = TimeMeter() rel_reward_log = [] with progress_bar.build_progress_bar(args, itr) as t: translations = generate_batched_itr( t, strategy, models, tgt_dict, length_beam_size=args.length_beam, use_gold_target_len=args.gold_target_len) for sample_id, src_tokens, target_tokens, hypos, logp in translations: has_target = target_tokens is not None target_tokens = target_tokens.int().cpu() if has_target else None # Either retrieve the original sentences or regenerate them from tokens. distill_str = dict.string(target_tokens, args.remove_bpe, escape_unk=True) hypo_str = dict.string(hypos, args.remove_bpe, escape_unk=True) hypo_str_bpe = dict.string(hypos, None, escape_unk=True) # Compute reward original_target_dataset.prefetch([sample_id]) orig_target = dict.string(original_target_dataset[sample_id], args.remove_bpe, escape_unk=True) hypo_reward = smoothed_bleu(hypo_str.split(), orig_target.split()) distill_reward = smoothed_bleu(distill_str.split(), orig_target.split()) rel_reward = hypo_reward - distill_reward rel_reward_log.append(rel_reward) print("{} | {:.4f} | {:.4f} | {}".format(sample_id, rel_reward, logp, hypo_str_bpe)) print("mean rel reward:", np.mean(rel_reward_log))
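# `main` above relies on a `smoothed_bleu` helper that is not defined in this
# file. A minimal sketch of sentence-level BLEU with add-one smoothed n-gram
# precisions and the standard brevity penalty, assuming whitespace-tokenized
# token lists as inputs; the actual helper used by this codebase may differ.
import math
from collections import Counter


def smoothed_bleu(hypo_tokens, ref_tokens, max_order=4):
    """Sentence-level smoothed BLEU in [0, 100]."""
    if len(hypo_tokens) == 0:
        return 0.0

    def ngrams(tokens, n):
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    log_prec = 0.0
    for n in range(1, max_order + 1):
        hypo_counts = ngrams(hypo_tokens, n)
        overlap = sum((hypo_counts & ngrams(ref_tokens, n)).values())
        total = sum(hypo_counts.values())
        # add-one smoothing keeps a zero-overlap order from zeroing the score
        log_prec += math.log((overlap + 1.0) / (total + 1.0))
    log_prec /= max_order

    # brevity penalty: exp(1 - r/c) when the hypothesis is shorter than the reference
    bp = min(0.0, 1.0 - len(ref_tokens) / len(hypo_tokens))
    return 100.0 * math.exp(bp + log_prec)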