Example #1
 def indexed_dataset(path, dictionary):
     if self.args.raw_text:
         return IndexedRawTextDataset(path, dictionary)
     elif IndexedDataset.exists(path):
         if self.args.lazy_load:
             return IndexedDataset(path, fix_lua_indexing=True)
         else:
             return IndexedCachedDataset(path, fix_lua_indexing=True)
     return None
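The helper above is meant to be nested inside a task's load_dataset, which is why it reads self.args. A minimal sketch of a call site, assuming the usual fairseq layout where prefix + lang names the binarized (or raw) files and self.src_dict / self.tgt_dict are the loaded dictionaries (the variable names here are illustrative, not taken from the example):

    # Hypothetical call site, not part of the example above.
    prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
    src_dataset = indexed_dataset(prefix + src, self.src_dict)
    tgt_dataset = indexed_dataset(prefix + tgt, self.tgt_dict)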
Example #2
 def indexed_dataset(path, dictionary, copy_ext_dict=False, src_dataset=None):
     if self.args.raw_text:
         return IndexedRawTextDataset(path, dictionary, copy_ext_dict=copy_ext_dict, src_dataset=src_dataset)
     elif IndexedDataset.exists(path):
         if self.args.lazy_load:
             return IndexedDataset(path, fix_lua_indexing=True)
         else:
             return IndexedCachedDataset(path, fix_lua_indexing=True)
     return None
Example #3
 def split_exists(split, src, tgt, lang):
     filename = os.path.join(
         self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
     if self.args.raw_text and IndexedRawTextDataset.exists(filename):
         return True
     elif not self.args.raw_text and IndexedCachedDataset.exists(
             filename):
         return True
     return False
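split_exists only probes for files; the caller still has to try both translation directions, because the preprocessed prefix may be written as src-tgt or tgt-src. A hedged sketch of that probing step (variable names are assumptions):

    # Illustrative: check both naming directions before loading.
    if split_exists(split, src, tgt, src):
        prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
    elif split_exists(split, tgt, src, src):
        prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
    else:
        raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))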
Example #4
 def indexed_dataset(path, dictionary, ex_dict=None, is_tgt=False):
     if self.args.segment:
         #if self.args.raw_text:
         return IndexedRawTextSegDataset(path, dictionary, ex_dict, is_tgt)
     else:
         if self.args.raw_text:
             return IndexedRawTextDataset(path, dictionary)
         elif IndexedDataset.exists(path):
             return IndexedCachedDataset(path, fix_lua_indexing=True)
     return None
Example #5
 def indexed_dataset(path):
     assert IndexedDataset.exists(path), f'IndexedDataset.exists({path})'
     # if self.args.raw_text:
     #     return IndexedRawTextDataset(path, dictionary)
     # elif IndexedDataset.exists(path):
     #     if self.args.lazy_load:
     #         return IndexedDataset(path, fix_lua_indexing=True)
     #     else:
     #         return IndexedCachedDataset(path, fix_lua_indexing=True)
     # return None
     return IndexedCachedDataset(path, fix_lua_indexing=True)
Example #6
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        loaded_datasets = []

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(self.args.data, split_k)

            if self.args.raw_text and IndexedRawTextDataset.exists(path):
                ds = IndexedRawTextDataset(path, self.dictionary)
            elif not self.args.raw_text and IndexedDataset.exists(path):
                if self.args.lazy_load:
                    ds = IndexedDataset(path, fix_lua_indexing=True)
                else:
                    ds = IndexedCachedDataset(path, fix_lua_indexing=True)
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))

            loaded_datasets.append(
                TokenBlockDataset(
                    ds, ds.sizes, self.args.tokens_per_sample,
                    pad=self.dictionary.pad(), eos=self.dictionary.eos(),
                    break_mode=self.args.sample_break_mode, include_targets=True,
                )
            )

            print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

        self.datasets[split] = MonolingualDataset(
            dataset, sizes, self.dictionary, self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets, shuffle=True,
            targets=self.targets,
        )
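Once a task implements load_dataset like this, callers normally go through the task API instead of touching the datasets directly; Example #13 below does exactly that. A minimal sketch of the driver side, assuming fairseq's task registry:

    # Illustrative driver code (see Example #13 for a full version).
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
    ).next_epoch_itr(shuffle=False)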
Example #7
    def load_dataset(self, split, combine=False):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        dataset_map = OrderedDict()

        for lang in self.langs2id.keys():
            if self.default_key is None:
                self.default_key = lang
            # Datasets are expected to be in "split.lang" format (Eg: train.en)
            language_split = '{}.{}'.format(split, lang)
            path = os.path.join(self.args.data, language_split)

            if self.args.raw_text and IndexedRawTextDataset.exists(path):
                ds = IndexedRawTextDataset(path, self.dictionary)
            elif not self.args.raw_text and IndexedDataset.exists(path):
                if self.args.lazy_load:
                    ds = IndexedDataset(path, fix_lua_indexing=True)
                else:
                    ds = IndexedCachedDataset(path, fix_lua_indexing=True)
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    language_split, self.args.data))

            # Since we append each block with the classification_token,
            # we need to effectively create blocks of length
            # tokens_per_sample-1
            block_dataset = TokenBlockDataset(
                dataset=ds,
                sizes=ds.sizes,
                block_size=self.args.tokens_per_sample - 1,
                pad=self.dictionary.pad(),
                eos=self.dictionary.eos())

            dataset_map[lang] = MaskedLMDataset(
                dataset=block_dataset,
                sizes=block_dataset.sizes,
                vocab=self.dictionary,
                pad_idx=self.dictionary.pad(),
                mask_idx=self.dictionary.mask(),
                classif_token_idx=self.dictionary.eos(),
                sep_token_idx=self.dictionary.eos(),
                shuffle=getattr(self.args, 'shuffle', False),
                has_pairs=False,
                segment_id=self.langs2id[lang],
                seed=self.seed,
            )

        self.datasets[split] = MultiCorpusSampledDataset(
            dataset_map, default_key=self.default_key)
        print('| {} {} {} examples'.format(self.args.data, split,
                                           len(self.datasets[split])))
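Example #7 expects one corpus per language under args.data, named split.lang. With --raw-text off, IndexedDataset.exists looks for the binarized index/data pair, so the data directory would contain files along these lines (paths are illustrative):

    data-bin/train.en.idx   data-bin/train.en.bin
    data-bin/train.de.idx   data-bin/train.de.bin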
Example #8
 def indexed_dataset(path, dictionary, cached=True, audio=False):
     if self.args.raw_text:
         return IndexedRawTextDataset(path, dictionary)
     elif IndexedDataset.exists(path):
         if cached:
             return IndexedCachedDataset(path,
                                         fix_lua_indexing=True,
                                         audio=audio)
         else:
             return IndexedDataset(path,
                                   fix_lua_indexing=True,
                                   audio=audio)
     return None
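This variant exposes the cached/uncached choice at the call site and adds an audio flag, which appears to be a project-specific extension rather than part of stock fairseq. An illustrative call, keeping large audio features out of the in-memory cache (names are assumptions):

    # Hypothetical: load audio features lazily instead of caching them in memory.
    src_dataset = indexed_dataset(prefix + src, self.src_dict, cached=False, audio=True)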
Example #9
    def _load_single_lang_dataset(self, split):
        loaded_datasets = []

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(self.args.data, split_k)

            if self.args.raw_text and IndexedRawTextDataset.exists(path):
                ds = IndexedRawTextDataset(path, self.dictionary)
            elif not self.args.raw_text and IndexedDataset.exists(path):
                if self.args.lazy_load:
                    ds = IndexedDataset(path, fix_lua_indexing=True)
                else:
                    ds = IndexedCachedDataset(path, fix_lua_indexing=True)
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(
                            split, self.args.data))

            # Since we append each block with the classification_token,
            # we need to effectively create blocks of length
            # tokens_per_sample-1
            loaded_datasets.append(
                TokenBlockDataset(
                    ds,
                    ds.sizes,
                    self.args.tokens_per_sample - 1,
                    pad=self.dictionary.pad(),
                    eos=self.dictionary.eos(),
                ))

            print('| {} {} {} examples'.format(self.args.data, split_k,
                                               len(loaded_datasets[-1])))

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        return dataset, sizes
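The (dataset, sizes) pair returned here is what the masked-LM wrappers in these examples consume. A sketch of a call site that reuses the constructor arguments shown in Example #11 (how the surrounding class wires this up is an assumption):

    # Illustrative: wrap the single-language dataset for masked-LM training.
    dataset, sizes = self._load_single_lang_dataset(split)
    self.datasets[split] = MaskedLMDataset(
        dataset=dataset,
        sizes=sizes,
        vocab=self.dictionary,
        pad_idx=self.dictionary.pad(),
        mask_idx=self.dictionary.mask(),
        classif_token_idx=self.dictionary.cls(),
        sep_token_idx=self.dictionary.sep(),
        shuffle=False,
        seed=self.seed,
    )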
Example #10
 def indexed_dataset(path, dictionary):
     return IndexedCachedDataset(path, fix_lua_indexing=True)
Example #11
    def load_dataset(self, split, combine=False):
        """
        Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        loaded_datasets = []

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(self.args.data, split_k)

            if self.args.raw_text and IndexedRawTextDataset.exists(path):
                ds = IndexedRawTextDataset(path, self.dictionary)
            elif not self.args.raw_text and IndexedDataset.exists(path):
                if self.args.lazy_load:
                    ds = IndexedDataset(path, fix_lua_indexing=True)
                else:
                    ds = IndexedCachedDataset(path, fix_lua_indexing=True)
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(
                            split, self.args.data))
            with data_utils.numpy_seed(self.seed + k):
                loaded_datasets.append(
                    BlockPairDataset(
                        ds,
                        self.dictionary,
                        ds.sizes,
                        self.args.tokens_per_sample,
                        break_mode=self.args.break_mode,
                    ))

            logger.info('{} {} {} examples'.format(self.args.data, split_k,
                                                   len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        self.datasets[split] = MaskedLMDataset(
            dataset=dataset,
            sizes=sizes,
            vocab=self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.dictionary.mask(),
            classif_token_idx=self.dictionary.cls(),
            sep_token_idx=self.dictionary.sep(),
            shuffle=False,
            seed=self.seed,
        )
Example #12
 def indexed_dataset(path):
     assert IndexedCachedDataset.exists(
         path), f'IndexedCachedDataset.exists({path})'
     return IndexedCachedDataset(path, fix_lua_indexing=True)
Example #13
def main(args, checkpoint_name="best"):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu
    torch.manual_seed(args.seed)

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))
    args.taskobj = task

    # Set dictionaries
    #src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    dict = tgt_dict

    # Load decoding strategy
    strategy = strategies.setup_strategy(args)

    # Load ensemble
    if args.path.startswith("nsml://"):
        print("| loading nsml checkpoint", args.path)
        import nsml
        session = args.path.replace("nsml://", "")
        model = task.build_model(args)

        def load(dir_path):
            state = torch.load(os.path.join(dir_path, 'best.pt'))
            state_dict = state["model"]
            model.load_state_dict(state_dict)
            print("loaded")

        nsml.load(args.checkpoint_name, load_fn=load, session=session)
        models = [model.cuda()]
    elif args.path == "pretrain":
        from nsml import DATASET_PATH
        from fairseq import checkpoint_utils
        data_token = "en-de"
        pretrained_path = "{}/train/pretrained_models/maskPredict_{}/checkpoint_best.pt".format(
            DATASET_PATH,
            data_token.split(".")[-1].replace("-", "_"))
        print("| loading", pretrained_path)
        model = task.build_model(args)
        state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_path)
        model.load_state_dict(state["model"], strict=True)
        models = [model.cuda()]
    elif args.path.startswith("wb://"):
        print("| loading wb checkpoint", args.path)
        import wandb
        wandb.restore("best.pt", args.path.replace("wb://", ""), root="/tmp/")
        assert os.path.exists("/tmp/best.pt")
        state = torch.load("/tmp/best.pt")
        model = task.build_model(args)
        model.load_state_dict(state["model"])
        models = [model.cuda()]
    elif args.path.startswith("http://"):
        print("| loading http checkpoint", args.path)
        url = "http://trains.deeplearn.org:8081/{}".format(
            args.path.replace("http://", ""))
        os.system("curl -o /tmp/model.pt {}".format(url))
        state = torch.load("/tmp/model.pt")
        model = task.build_model(args)
        model.load_state_dict(state["model"])
        models = [model.cuda()]
    else:
        print('| loading model(s) from {}'.format(args.path))
        models, _ = utils.load_ensemble_for_inference(
            args.path.split(':'),
            task,
            model_arg_overrides=eval(args.model_overrides))
        models = [model.cuda() for model in models]

    original_target_dataset = None
    assert args.original_target
    if args.original_target:
        original_target_dataset = IndexedCachedDataset(args.original_target,
                                                       fix_lua_indexing=True)

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    results = []
    scorer = pybleu.PyBleuScorer()
    num_sentences = 0
    has_target = True
    timer = TimeMeter()
    rel_reward_log = []

    with progress_bar.build_progress_bar(args, itr) as t:

        translations = generate_batched_itr(
            t,
            strategy,
            models,
            tgt_dict,
            length_beam_size=args.length_beam,
            use_gold_target_len=args.gold_target_len)
        for sample_id, src_tokens, target_tokens, hypos, logp in translations:

            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            # Either retrieve the original sentences or regenerate them from tokens.
            distill_str = dict.string(target_tokens,
                                      args.remove_bpe,
                                      escape_unk=True)
            hypo_str = dict.string(hypos, args.remove_bpe, escape_unk=True)
            hypo_str_bpe = dict.string(hypos, None, escape_unk=True)

            # Compute reward
            original_target_dataset.prefetch([sample_id])
            orig_target = dict.string(original_target_dataset[sample_id],
                                      args.remove_bpe,
                                      escape_unk=True)
            hypo_reward = smoothed_bleu(hypo_str.split(), orig_target.split())
            distill_reward = smoothed_bleu(distill_str.split(),
                                           orig_target.split())
            rel_reward = hypo_reward - distill_reward
            rel_reward_log.append(rel_reward)

            print("{} | {:.4f} | {:.4f} | {}".format(sample_id, rel_reward,
                                                     logp, hypo_str_bpe))
    print("mean rel reward:", np.mean(rel_reward_log))