Example #1
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        if self.args['task']['target_lang'] == 'code_tokens' and self.args['task'].get('code_types', False):
            attrs_mapping = {
                'attr': {self.token_dictionary.index('attr')},
                'num': {self.token_dictionary.index('Num')},
                'name': {self.token_dictionary.index('NameStore'),
                         self.token_dictionary.index('NameLoad')},
                'param': {self.token_dictionary.index('arg'),
                          self.token_dictionary.index('kwarg'),
                          self.token_dictionary.index('vararg')},
            }
        elif self.args['task']['target_lang'] == 'ast' and self.args['task'].get('code_types', False):
            attrs_mapping = {
                'attr': {self.token_dictionary.index('attr')},
                'num': {self.token_dictionary.index('Num')},
                'name': {self.token_dictionary.index('NameStore'),
                         self.token_dictionary.index('NameLoad')},
                'param': {self.token_dictionary.index('NameParam')},
            }
        else:
            attrs_mapping = None

        if attrs_mapping:
            reversed_attrs_mapping = {}
            for k, vs in attrs_mapping.items():
                if len(vs) > 1:
                    for v in vs:
                        reversed_attrs_mapping[v] = k
                else:
                    reversed_attrs_mapping[list(vs)[0]] = k
        else:
            reversed_attrs_mapping = None

        self.datasets[split] = load_token_dataset(
            data_path, split, self.args['task']['target_lang'], self.target_dictionary,
            attrs_mapping=attrs_mapping, reversed_attrs_mapping=reversed_attrs_mapping,
            attrs=self.args['task'].get('code_types', None),
            attr_dict=self.token_dictionary,
            dataset_impl=self.args['dataset']['dataset_impl'],
            truncate_target=self.args['dataset'].get('truncate_target', False),
            max_target_positions=self.max_positions(),
        )
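
A minimal, self-contained sketch of the attrs_mapping inversion above, with hypothetical integer indices standing in for the token_dictionary.index(...) lookups; note that both branches of the loop collapse into a single dict comprehension:

# Sketch: invert {attr_key: {token ids}} into {token id: attr_key}.
# The indices 11-13 are hypothetical stand-ins for token_dictionary lookups.
attrs_mapping = {
    'attr': {11},
    'name': {12, 13},  # e.g. NameStore, NameLoad
}
reversed_attrs_mapping = {v: k for k, vs in attrs_mapping.items() for v in vs}
assert reversed_attrs_mapping == {11: 'attr', 12: 'name', 13: 'name'}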
Example #2
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        self.datasets[split] = load_code_dataset_mlm(self.args, epoch,
                                                     data_path, split,
                                                     self.source_dictionary,
                                                     combine)
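
Every loader in these examples selects its data shard with paths[(epoch - 1) % len(paths)]; a quick sketch of that rotation, assuming colon-separated shard paths as utils.split_paths produces (the paths themselves are hypothetical):

# Sketch: epoch-based shard rotation over colon-separated data paths.
paths = '/data/shard0:/data/shard1:/data/shard2'.split(':')
for epoch in range(1, 7):
    data_path = paths[(epoch - 1) % len(paths)]
    print(epoch, data_path)  # epochs 1..6 cycle shard0, shard1, shard2, ...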
Example #3
    def setup_task(cls, args, **kwargs):
        paths = utils.split_paths(args['task']['data'])
        assert len(paths) > 0
        # load dictionaries
        src_dicts = OrderedDict()
        for lang in args['task']['source_langs']:
            src_dicts[lang] = cls.load_dictionary(os.path.join(paths[0], '{}.dict.json'.format(lang)))
            LOGGER.info('[{}] dictionary: {} types'.format(lang, len(src_dicts[lang]) if lang != 'edges' else 0))
        tgt_dicts = OrderedDict()
        for lang in args['task']['target_langs']:
            tgt_dicts[lang] = cls.load_dictionary(os.path.join(paths[0], '{}.dict.json'.format(lang)))
            LOGGER.info('[{}] dictionary: {} types'.format(lang, len(tgt_dicts[lang])))
        return cls(args, src_dicts, tgt_dicts)
Example #4
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        srcs, tgts = self.args['task']['source_langs'], self.args['task']['target_langs']

        self.datasets[split] = load_langpair_dataset(
            data_path, split, srcs, self.src_dicts, tgts, self.tgt_dicts,
            dataset_impl=self.args['dataset']['dataset_impl'],
            src_max_tokens=self.args['dataset']['src_max_tokens'],
            tgt_max_tokens=self.args['dataset']['tgt_max_tokens'],
        )
Example #5
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """

        paths = utils.split_paths(args['task']['data'])
        assert len(paths) > 0
        # load dictionaries
        dictionary = cls.load_dictionary(
            os.path.join(paths[0],
                         '{}.dict.json'.format(args['task']['target_lang'])))
        LOGGER.info('[{}] dictionary: {} types'.format(
            args['task']['target_lang'], len(dictionary)))
        return cls(args, dictionary)
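
The dictionary files follow a '<lang>.dict.json' naming convention under the first data path; a trivial sketch of the path construction, with hypothetical values:

# Sketch: dictionary path construction (paths and lang are hypothetical).
import os

paths = ['/data/python']
target_lang = 'code_tokens'
dict_file = os.path.join(paths[0], '{}.dict.json'.format(target_lang))
print(dict_file)  # /data/python/code_tokens.dict.json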
Example #6
    def setup_task(cls, args, **kwargs):
        paths = utils.split_paths(args['task']['data'])
        assert len(paths) > 0
        if args['dataset']['joined_dictionary']:
            modalities = sorted(args['task']['source_langs'] + args['task']['target_langs'])
            src_dicts = tgt_dicts = cls.load_dictionary(
                os.path.join(paths[0], '{}.dict.json'.format('_'.join(modalities))))
        else:
            src_dicts = {
                lang: cls.load_dictionary(os.path.join(paths[0], f'{lang}.dict.jsonl'))
                for lang in args['task']['source_langs']
            }
            tgt_dicts = {
                lang: cls.load_dictionary(os.path.join(paths[0], f'{lang}.dict.jsonl'))
                for lang in args['task']['target_langs']
            }
        return cls(args, src_dicts, tgt_dicts)
Example #7
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        src, tgt = self.args['task']['source_lang'], self.args['task'][
            'target_lang']
        if split == 'train':
            src_aux, tgt_aux = self.args['task']['source_aux_lang'], self.args[
                'task']['target_aux_lang']
        else:
            src_aux, tgt_aux = None, None

        if self.args['model']['arch'] in [
                'nbow', 'conv1d_res', 'birnn', 'self_attn'
        ]:
            task_idx = kwargs.get('task_idx', 1)
            if split == 'valid':
                labels = self.args['dataset']['langs'][:task_idx + 1]
            else:
                labels = [self.args['dataset']['langs'][task_idx]]
            self.datasets[split] = load_tokens_dataset(
                data_path,
                split,
                src,
                self.source_dictionary,
                tgt,
                self.target_dictionary,
                dataset_impl=self.args['dataset']['dataset_impl'],
                src_max_tokens=self.args['dataset']['code_max_tokens'],
                tgt_max_tokens=self.args['dataset']['query_max_tokens'],
                src_aux=src_aux,
                tgt_aux=tgt_aux,
                fraction_using_func_name=self.args['task']
                ['fraction_using_func_name'],
                labels=labels,
                shuffle=(split == 'train'),
            )
        else:
            raise NotImplementedError
Example #8
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        if self.args['model']['arch'] == 'seqrnn':
            self.datasets[split] = load_token_dataset(
                data_path,
                split,
                self.args['task']['target_lang'],
                self.target_dictionary,
                dataset_impl=self.args['dataset']['dataset_impl'],
                ext=self.args['task']['ext'],
                max_target_positions=self.max_positions())
Example #9
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.args['task']['source_lang'], self.args['task']['target_lang']

        self.datasets[split] = load_langpair_dataset(
            data_path, split, src, self.src_dict, tgt, self.tgt_dict,
            dataset_impl=self.args['dataset']['dataset_impl'],
            left_pad_source=self.args['task']['left_pad_source'],
            max_source_positions=self.args['task']['max_source_positions'],
            src_aux=self.args['task']['source_aux'],
        )
Example #10
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        srcs, tgts = self.args['task']['source_langs'], self.args['task']['target_langs']
        assert len(self.args['dataset']['max_srcs']) == len(srcs) and \
               len(self.args['dataset']['max_tgts']) == len(tgts)
        self.datasets[split] = load_tokens_dataset(
            data_path, split, srcs, self.source_dictionaries, tgts, self.target_dictionaries,
            dataset_impl=self.args['dataset']['dataset_impl'],
            max_srcs=self.args['dataset']['max_srcs'],
            max_tgts=self.args['dataset']['max_tgts'],
            langs=self.args['dataset']['langs'],
            shuffle=(split == 'train'),
            fraction_using_func_name=self.args['task']['fraction_using_func_name'],
            sample_neg=(self.args['optimization'].get('sample_neg', False))
        )
Example #11
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        src, tgt = self.args['task']['source_lang'], self.args['task']['target_lang']

        self.datasets[split] = load_tokens_dataset(
            data_path, split, src, self.source_dictionary, tgt, self.target_dictionary,
            dataset_impl=self.args['dataset']['dataset_impl'],
            max_source_positions=self.args['dataset']['max_source_positions'],
            max_target_positions=self.args['dataset']['max_target_positions'],
            max_positions=self.args['dataset']['max_positions'],
            append_source_eos=self.args['dataset']['append_source_eos'],
            append_target_eos=self.args['dataset']['append_target_eos'],
            shuffle=(split == 'train'),
        )
Example #12
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        # options.eval_bool
        args['task']['left_pad_source'] = bool(args['task']['left_pad_source'])
        args['task']['left_pad_target'] = bool(args['task']['left_pad_target'])
        paths = utils.split_paths(args['task']['data'])
        assert len(paths) > 0
        # find language pair automatically
        if args['task']['source_lang'] is None or args['task'][
                'target_lang'] is None:
            # args['task'].source_lang, args['task'].target_lang = data_utils.infer_language_pair(args.data[0])
            args['task']['source_lang'], args['task'][
                'target_lang'] = data_utils.infer_language_pair(paths[0])
        if args['task']['source_lang'] is None or args['task'][
                'target_lang'] is None:
            raise Exception(
                'Could not infer language pair, please provide it explicitly')

        # load dictionaries
        src_dict = cls.load_dictionary(
            os.path.join(paths[0],
                         '{}.dict.json'.format(args['task']['source_lang'])))
        tgt_dict = cls.load_dictionary(
            os.path.join(paths[0],
                         '{}.dict.json'.format(args['task']['target_lang'])))

        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        LOGGER.info('[{}] dictionary: {} types'.format(
            args['task']['source_lang'], len(src_dict)))
        LOGGER.info('[{}] dictionary: {} types'.format(
            args['task']['target_lang'], len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)
Example #13
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        paths = utils.split_paths(args['task']['data'])
        assert len(paths) > 0

        dict = args['task'].get('dict', None)
        dict_type = args['task'].get('dict_type', None)
        if dict is None and dict_type is None:

            src_dict = cls.load_dictionary(os.path.join(paths[0], '{}.dict.jsonl'.format(args['task']['source_lang'])))
            tgt_dict = cls.load_dictionary(os.path.join(paths[0], '{}.dict.jsonl'.format(args['task']['target_lang'])))
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
            LOGGER.info('[{}] dictionary: {} types'.format(args['task']['source_lang'], len(src_dict)))
            LOGGER.info('[{}] dictionary: {} types'.format(args['task']['target_lang'], len(tgt_dict)))
        else:
            raise NotImplementedError
        return cls(args, src_dict, tgt_dict)
Example #14
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        paths = utils.split_paths(args['task']['data'])
        assert len(paths) > 0
        # load dictionaries
        src_dicts = OrderedDict()
        for lang in args['task']['source_langs']:
            src_dicts[lang] = cls.load_dictionary(
                os.path.join(paths[0], '{}.dict.json'.format(lang)))
            LOGGER.info('[{}] dictionary: {} types'.format(
                lang,
                len(src_dicts[lang]) if lang != 'edges' else 0))
        tgt_dicts = OrderedDict()
        for lang in args['task']['target_langs']:
            tgt_dicts[lang] = cls.load_dictionary(
                os.path.join(paths[0], '{}.dict.json'.format(lang)))
            LOGGER.info('[{}] dictionary: {} types'.format(
                lang, len(tgt_dicts[lang])))
        return cls(args, src_dicts, tgt_dicts)
Example #15
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """

        paths = utils.split_paths(args['task']['data'])
        assert len(paths) > 0
        # load dictionaries
        dict_file = os.path.join(
            paths[0], '{}.dict.jsonl'.format(args['task']['target_lang']))
        dictionary = cls.load_dictionary(dict_file)
        LOGGER.info('[{}] dictionary: {} types'.format(
            args['task']['target_lang'], len(dictionary)))
        token_file = os.path.join(paths[0], 'code_types.dict.jsonl')
        if os.path.exists(token_file):
            token_dictionary = cls.load_dictionary(token_file)
            LOGGER.info('[code_tokens] dictionary: {} types'.format(
                len(token_dictionary)))
        else:
            token_dictionary = None
        return cls(args, dictionary, token_dictionary)
Example #16
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.args['task']['source_lang'], self.args['task'][
            'target_lang']

        self.datasets[split] = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            self.type_dict,
            tgt,
            self.tgt_dict,
            dataset_impl=self.args['dataset']['dataset_impl'],
            left_pad_source=self.args['task']['left_pad_source'],
            left_pad_target=self.args['task']['left_pad_target'],
            max_source_positions=self.args['task']['max_source_positions'],
            max_target_positions=self.args['task']['max_target_positions'],
            load_alignments=self.args['task']['load_alignments'],
            truncate_source=self.args['task']['truncate_source'],
            truncate_target=self.args['task']['truncate_target'],
            append_eos_to_target=self.args['task']['append_eos_to_target'],
            portion=self.args['dataset'].get('portion', None),
            path_num=self.args['dataset'].get(f'{split}_path_num', 200),
            max_subtoken_len=self.args['dataset'].get('max_subtoken_len',
                                                      None),
            max_path_len=self.args['dataset'].get('max_path_len', None),
        )
Example #17
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.args['task']['source_lang'], self.args['task'][
            'target_lang']

        sp = spm.SentencePieceProcessor()
        sp.load(self.args['dataset']['src_sp'])

        self.datasets[split] = load_codetype_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            sp,
            combine=combine,
            dataset_impl=self.args['dataset']['dataset_impl'],
            # upsample_primary=self.args['task']['upsample_primary'],
            # left_pad_source=self.args['task']['left_pad_source'],
            # left_pad_target=self.args['task']['left_pad_target'],
            max_source_positions=self.args['task']['max_source_positions'],
            # max_target_positions=self.args['task']['max_target_positions'],
            # load_alignments=self.args['task']['load_alignments'],
            # truncate_source=self.args['task']['truncate_source'],
            # append_eos_to_target=self.args['task']['append_eos_to_target'],
        )
Example #18
    def setup_task(cls, args, **kwargs):
        paths = utils.split_paths(args['task']['data'])
        assert len(paths) > 0
        # dictionary = Dictionary.load(os.path.join(paths[0], 'dict.code.txt'))
        if args['dataset']['joined_dictionary']:
            src_dict = cls.load_dictionary(
                os.path.join(paths[0], '{}.dict.txt'.format(
                    args['task']
                    ['source_lang'])))  # args['task']['source_lang']
            tgt_dict = src_dict
        else:
            src_dict = cls.load_dictionary(
                os.path.join(paths[0], '{}.dict.txt'.format(
                    args['task']
                    ['source_lang'])))  # args['task']['source_lang']
            tgt_dict = cls.load_dictionary(
                os.path.join(paths[0], '{}.dict.txt'.format(
                    args['task']['target_lang'])))

        src_dict.add_symbol(constants.S_SEP)
        src_dict.add_symbol(constants.S2S_SEP)
        src_dict.add_symbol(constants.CLS)
        src_dict.add_symbol(constants.T_MASK)
        src_dict.add_symbol(constants.SEP)

        tgt_dict.add_symbol(constants.S2S_BOS)
        tgt_dict.add_symbol(constants.T_MASK)
        tgt_dict.add_symbol(constants.SEP)
        LOGGER.info('<T_MASK> id in src_dict is {}'.format(
            src_dict.index('<T_MASK>')))
        LOGGER.info('<T_MASK> id in tgt_dict is {}'.format(
            tgt_dict.index('<T_MASK>')))

        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()

        return cls(args, src_dict, tgt_dict)
Example #19
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        # paths = self.args.data.split(':')
        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        # split_path = os.path.join(data_path, split)

        if self.langs is None:
            languages = sorted([
                name for name in os.listdir(data_path)
                if os.path.isdir(os.path.join(data_path, name))
            ])
        else:
            languages = self.langs  # .split(',')
            # for name in languages:
            #     assert os.path.exists(os.path.join(data_path, name)), FileNotFoundError(os.path.join(data_path, name))

        LOGGER.info("| Training on {0} languages: {1}".format(
            len(languages), languages))
        LOGGER.info("| Language to id mapping: ",
                    {lang: id
                     for id, lang in enumerate(languages)})

        mask_whole_words = get_whole_word_mask(self.args, self.dictionary)
        lang_datasets = []
        for language in languages:
            # split_path = os.path.join(data_path, language, split)
            if language == 'docstring':
                split_path = os.path.join(data_path, language,
                                          f"{split}.docstring.spm")
            else:
                split_path = os.path.join(data_path, language,
                                          f"{split}.code.spm")
            # split_path = os.path.join(data_path, language, f"{split}.spm.{language}")
            # dataset = data_utils.load_indexed_dataset(
            #     split_path,
            #     self.source_dictionary,
            #     self.args['dataset']['dataset_impl'],
            #     combine=combine,
            # )
            dataset = load_lang_dataset_denoising(
                path=split_path,
                impl=self.args['dataset']['dataset_impl'],
                dict=self.source_dictionary)

            if dataset is None:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, split_path))

            dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(dataset, self.source_dictionary.eos()),
                    self.args['task']['max_source_positions'] -
                    3),  # <lang>, <bos>, <eos>
                token=self.source_dictionary.eos(),
            )

            end_token = self.source_dictionary.index('[{}]'.format(language)) \
                if self.args['task']['add_lang_token'] else self.source_dictionary.eos()

            # create continuous blocks of tokens
            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                self.args['task']['tokens_per_sample'] -
                2,  # one less for <s> and one for </s>
                pad=self.source_dictionary.pad(),
                eos=end_token,
                break_mode=self.args['task']['sample_break_mode'],
                document_sep_len=0,
            )
            LOGGER.info('| loaded {} blocks from: {}'.format(
                len(dataset), split_path))

            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
            dataset = PrependTokenDataset(dataset,
                                          self.source_dictionary.bos())
            dataset = AppendTokenDataset(dataset, end_token)

            lang_dataset = DenoisingDataset(
                dataset,
                dataset.sizes,
                self.dictionary,
                self.mask_idx,
                mask_whole_words,
                shuffle=self.args['dataset']['shuffle_instance'],
                seed=self.seed,
                args=self.args,
                eos=None if not self.args['task']['add_lang_token'] else
                self.source_dictionary.index('[{}]'.format(language)),
            )
            lang_datasets.append(lang_dataset)

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        LOGGER.info('| loaded total {} blocks for all languages'.format(
            dataset_lengths.sum(), ))
        if split == self.args['dataset']['train_subset']:
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(dataset_lengths)
            LOGGER.info(
                "| Sample probability by language: {}".format({
                    lang: "{0:.4f}".format(sample_probs[id])
                    for id, lang in enumerate(languages)
                }))
            size_ratio = (sample_probs *
                          dataset_lengths.sum()) / dataset_lengths
            LOGGER.info(
                "| Up/Down Sampling ratio by language: {}".format({
                    lang: "{0:.2f}".format(size_ratio[id])
                    for id, lang in enumerate(languages)
                }))

            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_datasets[i],
                    size_ratio=size_ratio[i],
                    seed=self.args['common']['seed'],
                    epoch=epoch,
                    replace=size_ratio[i] >= 1.0,
                ) for i, d in enumerate(lang_datasets)
            ]
            dataset = ConcatDataset(resampled_lang_datasets, )
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            # for lang_id, lang_dataset in enumerate(lang_datasets):
            #     split_name = split + '_' + languages[lang_id]
            #     lang_splits.append(split_name)
            #     self.datasets[split_name] = lang_dataset

            if split in self.args['dataset']['valid_subset']:
                self.args['dataset']['valid_subset'] = self.args['dataset'][
                    'valid_subset'].replace(split, ','.join(lang_splits))

        with data_utils.numpy_seed(self.args['common']['seed'] + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
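
_get_sample_prob is not shown in this snippet; a sketch of the usual temperature-based smoothing it likely implements (an assumption, following the common multilingual sampling recipe), together with the size-ratio arithmetic used above:

import numpy as np

# Assumed _get_sample_prob: temperature smoothing, prob_i ~ (n_i / N) ** (1 / T).
def get_sample_prob(dataset_lengths, temperature=1.5):
    prob = dataset_lengths / dataset_lengths.sum()
    smoothed = prob ** (1.0 / temperature)
    return smoothed / smoothed.sum()

lengths = np.array([1000.0, 100.0, 10.0])
sample_probs = get_sample_prob(lengths)
size_ratio = (sample_probs * lengths.sum()) / lengths  # >1 upsamples, <1 downsamples
print(sample_probs.round(4), size_ratio.round(2))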
Example #20
def main(args, out_file=None):
    use_cuda = torch.cuda.is_available() and not args['common']['cpu']

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args['dataset']['gen_subset'])

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    LOGGER.info('loading model(s) from {}'.format(args['eval']['path']))
    models, _ = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args['eval']['path']),
        arg_overrides=eval(args['eval']['model_overrides']),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None
            if args['eval']['no_beamable_mm'] else args['eval']['beam'],
            need_attn=args['eval']['print_alignment'],
        )

        if use_cuda:
            device = os.environ.get('CUDA_VISIBLE_DEVICES',
                                    '0').split(',')[0]  # first visible device
            torch.cuda.set_device(f'cuda:{device}')
            model = model.cuda()
        if args['common']['fp16'] and use_cuda:
            model.half()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args['dataset']['gen_subset']),
        max_tokens=args['dataset']['max_tokens'],
        max_sentences=args['eval']['max_sentences'],
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args['dataset']
        ['skip_invalid_size_inputs_valid_test'],
        required_batch_size_multiple=args['dataset']
        ['required_batch_size_multiple'],
        num_shards=args['dataset']['num_shards'],
        shard_id=args['dataset']['shard_id'],
        num_workers=args['dataset']['num_workers'],
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args['common']['log_format'],
        log_interval=args['common']['log_interval'],
        default_log_format=('tqdm' if not args['common']['no_progress_bar']
                            else 'none'),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(models, args)

    sources, hypotheses, references = dict(), dict(), dict()

    for sample in progress:
        torch.cuda.empty_cache()

        sample = move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        gen_timer.start()
        hypos = task.inference_step(generator,
                                    models,
                                    sample,
                                    bos_token=tgt_dict.bos())
        num_generated_tokens = sum(len(h[0]['tokens'])
                                   for h in hypos)  # TODO: warning
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample['id'].tolist()):
            has_target = sample['target'] is not None

            # Remove padding
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :],
                                                tgt_dict.pad()).int().cpu()

            hypos_tokens = utils.strip_eos(hypos[i][0]['tokens'],
                                           tgt_dict.eos()).int().cpu()
            # Either retrieve the original sentences or regenerate them from tokens.
            if src_dict is not None:
                src_str = src_dict.string(src_tokens,
                                          args['eval']['remove_bpe'])
            else:
                src_str = "0"
            if has_target:
                target_str = tgt_dict.string(target_tokens,
                                             args['eval']['remove_bpe'],
                                             escape_unk=True)

            hypo_str = tgt_dict.string(hypos_tokens,
                                       args['eval']['remove_bpe'])

            sources[sample_id] = [src_str]
            hypotheses[sample_id] = [hypo_str]
            references[sample_id] = [target_str]

    bleu, rouge_l, meteor = \
        summarization_metrics.eval_accuracies(hypotheses, references, filename=out_file, mode='test')
    LOGGER.info('BLEU: {:.2f}\t ROUGE-L: {:.2f}\t METEOR: {:.2f}'.format(
        bleu, rouge_l, meteor))
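
utils.strip_pad and utils.strip_eos trim special tokens before detokenization; a toy sketch of the padding strip, assuming the usual fairseq-style definition:

import torch

# Sketch of strip_pad: keep every token id that is not <pad>.
def strip_pad(tensor, pad):
    return tensor[tensor.ne(pad)]

tokens = torch.tensor([4, 8, 15, 1, 1])  # 1 = pad
print(strip_pad(tokens, pad=1).tolist())  # [4, 8, 15]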
Example #21
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        if self.args['task']['target_lang'] == 'code_tokens' and self.args[
                'task'].get('code_types', False):
            attrs_mapping = {
                'attr': {self.token_dictionary.index('attr')},
                'num': {self.token_dictionary.index('Num')},
                'name': {
                    self.token_dictionary.index('NameStore'),
                    self.token_dictionary.index('NameLoad')
                },
                'param': {
                    self.token_dictionary.index('arg'),
                    self.token_dictionary.index('kwarg'),
                    self.token_dictionary.index('vararg')
                },
            }
        elif self.args['task']['target_lang'] == 'ast' and self.args[
                'task'].get('code_types', False):
            attrs_mapping = {
                'attr': {self.token_dictionary.index('attr')},
                'num': {self.token_dictionary.index('Num')},
                'name': {
                    self.token_dictionary.index('NameStore'),
                    self.token_dictionary.index('NameLoad')
                },
                'param': {self.token_dictionary.index('NameParam')},
            }
        else:
            attrs_mapping = None

        if attrs_mapping:
            reversed_attrs_mapping = {}
            for k, vs in attrs_mapping.items():
                if len(vs) > 1:
                    for v in vs:
                        reversed_attrs_mapping[v] = k
                else:
                    reversed_attrs_mapping[list(vs)[0]] = k
        else:
            reversed_attrs_mapping = None

        task_idx = kwargs.get('task_idx', 1)
        if split == 'train':
            cur_task = self.args['task']['task_pipeline'][task_idx]
            init_from_scratch = task_idx == 1 and self.args['task'][
                'task_pipeline'][0] == 'scratch'
            sample_portion = self.args['task']['sample_portion']
            if sample_portion is not None and not init_from_scratch:
                # 1) sample a portion of data from previous tasks, and 2) the first task is not scratch
                prev_tasks = self.args['task']['task_pipeline'][:task_idx]
            else:
                prev_tasks = []

            self.datasets[split] = load_kd_token_dataset(
                data_path,
                split,
                self.args['task']['target_lang'],
                self.target_dictionary,
                attrs_mapping=attrs_mapping,
                reversed_attrs_mapping=reversed_attrs_mapping,
                attrs=self.args['task'].get('code_types', None),
                attr_dict=self.token_dictionary,
                dataset_impl=self.args['dataset']['dataset_impl'],
                truncate_target=self.args['dataset'].get(
                    'truncate_target', False),
                max_target_positions=self.max_positions(),
                # lifelong
                cur_task=cur_task,
                prev_tasks=prev_tasks,
                sample_portion=self.args['task']['sample_portion'],
                # kd
                topk=self.args['kd']['gen_topk'],
                distill_topk=self.args['kd']['distill_topk'],
                teacher_out_dir=self.args['kd']['teacher_out_dir'],
            )
        elif split == 'test':
            data_paths = [
                os.path.join(data_path,
                             self.args['task']['task_pipeline'][task_idx])
            ]
            self.datasets[split] = load_inference_token_dataset(
                data_paths,
                split,
                self.args['task']['target_lang'],
                self.target_dictionary,
                attrs_mapping=attrs_mapping,
                reversed_attrs_mapping=reversed_attrs_mapping,
                attrs=self.args['task'].get('code_types', None),
                attr_dict=self.token_dictionary,
                dataset_impl=self.args['dataset']['dataset_impl'],
                truncate_target=self.args['dataset'].get(
                    'truncate_target', False),
                max_target_positions=self.max_positions(),
            )
        else:
            # data_paths = [
            #     os.path.join(data_path, task_name)
            #     for task_name in self.args['task']['task_pipeline'][:task_idx + 1]
            # ]
            # data_paths = [
            #     os.path.join(data_path, self.args['task']['task_pipeline'][task_idx])
            # ]
            cur_task = self.args['task']['task_pipeline'][task_idx]
            # self.datasets[split] = load_inference_token_dataset(
            #     data_paths, split, self.args['task']['target_lang'], self.target_dictionary,
            #     attrs_mapping=attrs_mapping, reversed_attrs_mapping=reversed_attrs_mapping,
            #     attrs=self.args['task'].get('code_types', None),
            #     attr_dict=self.token_dictionary,
            #     dataset_impl=self.args['dataset']['dataset_impl'],
            #     truncate_target=self.args['dataset'].get('truncate_target', False),
            #     max_target_positions=self.max_positions(),
            # )
            self.datasets[split] = load_kd_token_dataset(
                data_path,
                split,
                self.args['task']['target_lang'],
                self.target_dictionary,
                attrs_mapping=attrs_mapping,
                reversed_attrs_mapping=reversed_attrs_mapping,
                attrs=self.args['task'].get('code_types', None),
                attr_dict=self.token_dictionary,
                dataset_impl=self.args['dataset']['dataset_impl'],
                truncate_target=self.args['dataset'].get(
                    'truncate_target', False),
                max_target_positions=self.max_positions(),
                # lifelong
                cur_task=cur_task,
                prev_tasks=[],
                sample_portion=0,
                # kd
                topk=self.args['kd']['gen_topk'],
                distill_topk=self.args['kd']['distill_topk'],
                teacher_out_dir=self.args['kd']['teacher_out_dir'],
            )
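
The lifelong branches above slice the task pipeline around task_idx; a toy sketch of the cur_task / prev_tasks selection, with a hypothetical pipeline:

# Sketch: lifelong task selection (pipeline names are hypothetical).
task_pipeline = ['scratch', 'java', 'python', 'go']
task_idx = 2
cur_task = task_pipeline[task_idx]  # 'python'
init_from_scratch = task_idx == 1 and task_pipeline[0] == 'scratch'
sample_portion = 0.2
if sample_portion is not None and not init_from_scratch:
    prev_tasks = task_pipeline[:task_idx]  # replay data from earlier tasks
else:
    prev_tasks = []
print(cur_task, prev_tasks)  # python ['scratch', 'java']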
Example #22
def _main(args, output_file):
    if args['dataset']['max_tokens'] is None and args['dataset'][
            'max_sentences'] is None:
        args['dataset']['max_tokens'] = 12000
    LOGGER.info(args)

    use_cuda = torch.cuda.is_available() and not args['common']['cpu']

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args['dataset']['gen_subset'])

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    LOGGER.info('loading model(s) from {}'.format(args['eval']['path']))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args['eval']['path']),
        arg_overrides=eval(args['eval']['model_overrides']),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None
            if args['eval']['no_beamable_mm'] else args['eval']['beam'],
            need_attn=args['eval']['print_alignment'],
        )
        if _model_args['common']['fp16']:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args['eval']['replace_unk'])

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args['dataset']['gen_subset']),
        max_tokens=args['dataset']['max_tokens'],
        max_sentences=args['eval']['max_sentences'],
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=_model_args['dataset']
        ['skip_invalid_size_inputs_valid_test'],
        required_batch_size_multiple=_model_args['dataset']
        ['required_batch_size_multiple'],
        num_shards=_model_args['dataset']['num_shards'],
        shard_id=_model_args['dataset']['shard_id'],
        num_workers=_model_args['dataset']['num_workers'],
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=_model_args['common']['log_format'],
        log_interval=_model_args['common']['log_interval'],
        default_log_format=('tqdm'
                            if not _model_args['common']['no_progress_bar']
                            else 'none'),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    # for sample in tqdm(progress, total=len(progress)):
    sources, hypotheses, references = dict(), dict(), dict()

    for sample in progress:
        torch.cuda.empty_cache()
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        # prefix_tokens = None
        # if args['eval']['prefix_size'] > 0:
        #     prefix_tokens = sample['target'][:, :args['eval']['prefix_size']]

        gen_timer.start()
        hypos = task.inference_step(generator, models, sample)
        # gen_out = task.sequence_generator.generate(model, sample)
        num_generated_tokens = sum(len(h[0]['tokens'])
                                   for h in hypos)  # TODO: warning
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample['id'].tolist()):
            has_target = sample['target'] is not None

            # Remove padding
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :],
                                                tgt_dict.pad()).int().cpu()

            hypos_tokens = utils.strip_eos(hypos[i][0]['tokens'],
                                           tgt_dict.eos()).int().cpu()
            # Either retrieve the original sentences or regenerate them from tokens.
            # if align_dict is not None:
            #     src_str = task.dataset(args['dataset']['gen_subset']).src.get_original_text(sample_id)
            #     target_str = task.dataset(args['dataset']['gen_subset']).tgt.get_original_text(sample_id)
            # else:
            if src_dict is not None:
                src_str = src_dict.string(src_tokens,
                                          args['eval']['remove_bpe'])
            else:
                src_str = ""
            if has_target:
                target_str = tgt_dict.string(target_tokens,
                                             args['eval']['remove_bpe'],
                                             escape_unk=True)

            # hypo_tokens = tgt_dict.encode_line(hypo_str, add_if_not_exist=True)
            hypo_str = tgt_dict.string(hypos_tokens,
                                       args['eval']['remove_bpe'])

            sources[sample_id] = [src_str]
            hypotheses[sample_id] = [hypo_str]
            references[sample_id] = [target_str]

            if not args['eval']['quiet']:
                if src_dict is not None:
                    print('S-{}\t{}'.format(sample_id, src_str),
                          file=output_file)
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str),
                          file=output_file)

                print('H-{}\t{}'.format(sample_id, hypo_str), file=output_file)

    filename = os.path.join(os.path.dirname(__file__), 'config',
                            'predict.json')
    LOGGER.info('write predicted file at {}'.format(filename))
    bleu, rouge_l, meteor = eval_utils.eval_accuracies(hypotheses,
                                                       references,
                                                       filename=filename,
                                                       mode='test')
    LOGGER.info('BLEU: {:.2f}\t ROUGE-L: {:.2f}\t METEOR: {:.2f}'.format(
        bleu, rouge_l, meteor))
Example #23
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt = self.args['task']['source_lang'], self.args['task'][
            'target_lang']

        task_idx = kwargs.get('task_idx', 1)
        if split == 'train':
            cur_task = self.args['task']['task_pipeline'][task_idx]
            init_from_scratch = task_idx == 1 and self.args['task'][
                'task_pipeline'][0] == 'scratch'
            sample_portion = self.args['task']['sample_portion']
            if sample_portion is not None and not init_from_scratch:
                # 1) sample a portion of data from previous tasks, and 2) the first task is not scratch
                prev_tasks = self.args['task']['task_pipeline'][:task_idx]
            else:
                prev_tasks = []
            self.datasets[split] = load_langpair_dataset(
                data_path,
                split,
                src,
                self.src_dict,
                tgt,
                self.tgt_dict,
                dataset_impl=self.args['dataset']['dataset_impl'],
                left_pad_source=self.args['task']['left_pad_source'],
                left_pad_target=self.args['task']['left_pad_target'],
                max_source_positions=self.args['task']['max_source_positions'],
                max_target_positions=self.args['task']['max_target_positions'],
                load_alignments=self.args['task']['load_alignments'],
                truncate_source=self.args['task']['truncate_source'],
                truncate_target=self.args['task']['truncate_target'],
                prepend_bos=kwargs.get('prepend_bos', True),
                # lifelong
                prev_tasks=prev_tasks,
                cur_task=cur_task,
                sample_portion=sample_portion,
            )
        elif split == 'test':
            data_paths = [
                os.path.join(data_path,
                             self.args['task']['task_pipeline'][task_idx])
            ]
            self.datasets[split] = load_langpair_dataset(
                data_paths,
                split,
                src,
                self.src_dict,
                tgt,
                self.tgt_dict,
                dataset_impl=self.args['dataset']['dataset_impl'],
                left_pad_source=self.args['task']['left_pad_source'],
                left_pad_target=self.args['task']['left_pad_target'],
                max_source_positions=self.args['task']['max_source_positions'],
                max_target_positions=self.args['task']['max_target_positions'],
                load_alignments=self.args['task']['load_alignments'],
                truncate_source=self.args['task']['truncate_source'],
                truncate_target=self.args['task']['truncate_target'],
                prepend_bos=kwargs.get('prepend_bos', True),
            )
        else:
            data_paths = [
                os.path.join(data_path, task_name)
                for task_name in self.args['task']['task_pipeline'][:task_idx +
                                                                    1]
            ]
            self.datasets[split] = load_inference_langpair_dataset(
                data_paths,
                split,
                src,
                self.src_dict,
                tgt,
                self.tgt_dict,
                dataset_impl=self.args['dataset']['dataset_impl'],
                left_pad_source=self.args['task']['left_pad_source'],
                left_pad_target=self.args['task']['left_pad_target'],
                max_source_positions=self.args['task']['max_source_positions'],
                max_target_positions=self.args['task']['max_target_positions'],
                load_alignments=self.args['task']['load_alignments'],
                truncate_source=self.args['task']['truncate_source'],
                truncate_target=self.args['task']['truncate_target'],
                prepend_bos=kwargs.get('prepend_bos', True),
            )
Example #24
def _main(args, output_file):
    if args['dataset']['max_tokens'] is None and args['dataset'][
            'max_sentences'] is None:
        args['dataset']['max_tokens'] = 12000

    use_cuda = torch.cuda.is_available() and not args['common']['cpu']
    if use_cuda:
        device = os.environ.get('CUDA_VISIBLE_DEVICES',
                                '0').split(',')[0]  # first visible device
        torch.cuda.set_device(f'cuda:{device}')

    # Load dataset splits
    task = tasks.setup_task(args)

    # Load ensemble
    LOGGER.info('loading model(s) from {}'.format(args['eval']['path']))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args['eval']['path']),
        arg_overrides=eval(args['eval']['model_overrides']),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        if _model_args['common']['fp16']:
            model.half()
        if use_cuda:
            model.cuda()

    sequence_completor = task.build_completor(models, args)

    subsets = [
        args['dataset']['train_subset'],
        args['dataset']['valid_subset'],
        args['dataset']['gen_subset'],
    ]
    for subset in subsets:
        task.load_dataset(subset, shuffle=False)
        task.dataset(subset).shuffle = False

        # Load dataset (possibly sharded)
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args['dataset']['max_tokens'],
            max_sentences=args['eval']['max_sentences_eval'],
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[model.max_positions() for model in models]),
            ignore_invalid_inputs=_model_args['dataset']
            ['skip_invalid_size_inputs_valid_test'],
            required_batch_size_multiple=_model_args['dataset']
            ['required_batch_size_multiple'],
            num_shards=_model_args['dataset']['num_shards'],
            shard_id=_model_args['dataset']['shard_id'],
            num_workers=_model_args['dataset']['num_workers'],
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=_model_args['common']['log_format'],
            log_interval=_model_args['common']['log_interval'],
            default_log_format=('tqdm'
                                if not _model_args['common']['no_progress_bar']
                                else 'none'),
        )

        topk = args['kd']['gen_topk']
        out_idx, out_prob = [], []
        with torch.no_grad():
            for sample in progress:
                torch.cuda.empty_cache()
                sample = move_to_cuda(sample) if use_cuda else sample
                if 'net_input' not in sample:
                    continue
                net_output = sequence_completor.generate(models,
                                                         sample,
                                                         prefix_tokens=None)
                topk_prob, topk_ids = torch.topk(net_output[0], topk, dim=-1)
                # ignore pad
                non_padding_mask = sample['net_input'][
                    'src_tokens'] != task.target_dictionary.pad()
                if use_cuda:
                    topk_prob, topk_ids = topk_prob.cpu(), topk_ids.cpu()
                    non_padding_mask = non_padding_mask.cpu()
                for idx in range(topk_prob.size(0)):
                    out_idx.append(
                        topk_ids[idx,
                                 ...][non_padding_mask[idx,
                                                       ...]].view(-1).tolist())
                    out_prob.append(topk_prob[idx, ...][non_padding_mask[
                        idx, ...]].view(-1).tolist())
        assert len(out_idx) == len(out_prob) == len(task.dataset(subset)), \
            Exception(len(out_idx), len(out_prob), len(task.dataset(subset)))
        TeacherOutDataset.save_bin(
            prefix=os.path.join(args['checkpoint']['save_dir'],
                                f'{subset}.top{topk}_idx'),
            data_list=out_idx,
            dtype=np.int32,
        )
        TeacherOutDataset.save_bin(
            prefix=os.path.join(args['checkpoint']['save_dir'],
                                f'{subset}.top{topk}_prob'),
            data_list=out_prob,
            dtype=float,  # np.float is a deprecated alias of the builtin float
        )
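
The per-sample masking above flattens top-k teacher outputs through a padding mask; a small sketch of that indexing, with toy tensors:

import torch

# Sketch: top-k over the vocab dim, then keep only non-pad positions.
logits = torch.randn(2, 4, 10)  # (batch, seq_len, vocab)
topk_prob, topk_ids = torch.topk(logits, k=3, dim=-1)
src_tokens = torch.tensor([[5, 6, 0, 0],  # 0 = pad
                           [7, 8, 9, 0]])
non_padding_mask = src_tokens != 0
row_ids = topk_ids[0][non_padding_mask[0]]  # (num_non_pad, k)
print(row_ids.view(-1).tolist())  # 2 positions x k=3 -> 6 ids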
Example #25
def main(args, **unused_kwargs):
    assert args['eval']['path'] is not None, '--path required for evaluation!'

    if torch.cuda.is_available() and not args['common']['cpu']:
        torch.cuda.set_device(args['distributed_training']['device_id'])

    LOGGER.info(args)
    # during evaluation, set fraction_using_func_name = 0, i.e., do not sample from func_name
    args['task']['fraction_using_func_name'] = 0.
    use_cuda = torch.cuda.is_available() and not args['common']['cpu']
    if use_cuda:
        device = os.environ.get('CUDA_VISIBLE_DEVICES',
                                '0').split(',')[0]  # first visible device
        torch.cuda.set_device(f'cuda:{device}')
    task = tasks.setup_task(args)

    # Load ensemble
    LOGGER.info('loading model(s) from {}'.format(args['eval']['path']))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args['eval']['path']),
        arg_overrides=eval(args['eval']['model_overrides']),
        task=task,
    )

    for lang in deepcopy(args['dataset']['langs']):
        args['dataset']['langs'] = [lang]
        # Load dataset splits
        LOGGER.info(f'Evaluating {lang} dataset')
        task.load_dataset(args['dataset']['gen_subset'])
        dataset = task.dataset(args['dataset']['gen_subset'])

        # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
        for model in models:
            model.make_generation_fast_()
            if args['common']['fp16']:
                model.half()
            if use_cuda:
                model.cuda()

        assert len(models) > 0

        LOGGER.info('num. model params: {}'.format(
            sum(p.numel() for p in models[0].parameters())))

        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args['dataset']['max_tokens'] or 36000,
            max_sentences=args['eval']['max_sentences'],
            max_positions=utils.resolve_max_positions(
                *[model.max_positions() for model in models]),
            ignore_invalid_inputs=True,
            num_shards=args['dataset']['num_shards'],
            shard_id=args['dataset']['shard_id'],
            num_workers=args['dataset']['num_workers'],
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args['common']['log_format'],
            log_interval=args['common']['log_interval'],
            default_log_format=('tqdm' if not args['common']['no_progress_bar']
                                else 'none'),
        )

        code_reprs, query_reprs = [], []
        for sample in progress:
            if 'net_input' not in sample:
                continue
            sample = move_to_cuda(sample) if use_cuda else sample
            batch_code_reprs, batch_query_reprs = models[0](
                **sample['net_input'])

            if use_cuda:
                batch_code_reprs = batch_code_reprs.cpu().detach()
                batch_query_reprs = batch_query_reprs.cpu().detach()

            code_reprs.append(batch_code_reprs)
            query_reprs.append(batch_query_reprs)
        code_reprs = torch.cat(code_reprs, dim=0)
        query_reprs = torch.cat(query_reprs, dim=0)

        assert code_reprs.shape == query_reprs.shape, (code_reprs.shape,
                                                       query_reprs.shape)
        eval_size = args['eval']['eval_size']
        if eval_size == -1:
            eval_size = len(code_reprs)

        MRR = []  # reciprocal ranks; this example does not keep top-k candidates
        for idx in range(len(dataset) // eval_size):
            start = idx * eval_size  # step by chunk, not by one
            code_emb = code_reprs[start:start + eval_size, :]
            query_emb = query_reprs[start:start + eval_size, :]

            if use_cuda:
                code_emb = code_emb.cuda()
                query_emb = query_emb.cuda()

            if args['criterion'] == 'search_cosine':
                src_emb_norm = torch.norm(code_emb, dim=-1, keepdim=True) + 1e-10
                tgt_emb_norm = torch.norm(query_emb, dim=-1, keepdim=True) + 1e-10
                logits = (query_emb / tgt_emb_norm) @ (code_emb / src_emb_norm).t()
            elif args['criterion'] == 'search_softmax':
                logits = query_emb @ code_emb.t()
            else:
                raise NotImplementedError

            correct_scores = logits.diag()
            compared_scores = logits >= correct_scores.unsqueeze(dim=-1)
            mrr = 1 / compared_scores.sum(dim=-1).float()
            MRR.extend(mrr.tolist())

        if len(dataset) % eval_size:
            code_emb = code_reprs[-eval_size:, :]
            query_emb = query_reprs[-eval_size:, :]

            if use_cuda:
                code_emb = code_emb.cuda()
                query_emb = query_emb.cuda()

            if args['criterion'] == 'search_cosine':
                src_emb_norm = torch.norm(code_emb, dim=-1, keepdim=True) + 1e-10
                tgt_emb_norm = torch.norm(query_emb, dim=-1, keepdim=True) + 1e-10
                logits = (query_emb / tgt_emb_norm) @ (code_emb / src_emb_norm).t()
            elif args['criterion'] == 'search_softmax':
                logits = query_emb @ code_emb.t()
            else:
                raise NotImplementedError

            correct_scores = logits.diag()
            compared_scores = logits >= correct_scores.unsqueeze(dim=-1)
            last_ids = len(code_reprs) % eval_size
            mrr = 1 / compared_scores.sum(dim=-1).float()[-last_ids:]
            MRR.extend(mrr.tolist())

        print('{}, mrr: {:.4f}'.format(lang, np.mean(MRR)))
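Example #25's scoring loop reduces to one identity: a query's reciprocal rank is 1 over the number of candidates scoring at least as high as its true match (the diagonal entry). A self-contained sketch of just that reduction (batch_mrr is a name introduced here for illustration):

import torch

def batch_mrr(logits: torch.Tensor) -> float:
    # logits[i, j] = similarity(query i, code j); logits[i, i] is the true pair
    correct_scores = logits.diag()
    ranks = (logits >= correct_scores.unsqueeze(dim=-1)).sum(dim=-1).float()
    return (1.0 / ranks).mean().item()

# sanity check: a strongly diagonal matrix gives a perfect MRR of 1.0
print(batch_mrr(torch.eye(4) * 10 + torch.rand(4, 4)))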
Example #26
0
def main(args, **unused_kwargs):
    assert args['eval']['path'] is not None, '--path required for evaluation!'

    if torch.cuda.is_available() and not args['common']['cpu']:
        torch.cuda.set_device(args['distributed_training']['device_id'])

    LOGGER.info(args)
    # during evaluation, set fraction_using_func_name = 0, i.e. never sample the query from the function name
    args['task']['fraction_using_func_name'] = 0.
    use_cuda = torch.cuda.is_available() and not args['common']['cpu']
    task = tasks.setup_task(args)

    # Load ensemble
    LOGGER.info('loading model(s) from {}'.format(args['eval']['path']))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args['eval']['path']),
        arg_overrides=eval(args['eval']['model_overrides']),
        task=task,
    )

    # Load dataset splits
    task.load_dataset(args['dataset']['gen_subset'])
    dataset = task.dataset(args['dataset']['gen_subset'])

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args['common']['fp16']:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    LOGGER.info('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args['dataset']['max_tokens'] or 36000,
        max_sentences=args['eval']['max_sentences'],
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args['dataset']['num_shards'],
        shard_id=args['dataset']['shard_id'],
        num_workers=args['dataset']['num_workers'],
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args['common']['log_format'],
        log_interval=args['common']['log_interval'],
        default_log_format=('tqdm' if not args['common']['no_progress_bar']
                            else 'none'),
    )

    code_reprs, query_reprs = [], []
    for sample in progress:
        if 'net_input' not in sample:
            continue
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        batch_code_reprs, batch_query_reprs = models[0](**sample['net_input'])

        code_reprs.extend(batch_code_reprs.tolist())
        query_reprs.extend(batch_query_reprs.tolist())
    code_reprs = np.asarray(code_reprs, dtype=np.float32)
    query_reprs = np.asarray(query_reprs, dtype=np.float32)

    assert code_reprs.shape == query_reprs.shape, (code_reprs.shape,
                                                   query_reprs.shape)
    eval_size = args['eval']['eval_size']
    if eval_size == -1:
        eval_size = len(code_reprs)

    k, MRR, topk_idx, topk_prob = 3, [], [], []
    for idx in range(len(dataset) // eval_size):
        start = idx * eval_size  # step by chunk, not by one
        code_emb = torch.from_numpy(code_reprs[start:start + eval_size, :])
        query_emb = torch.from_numpy(query_reprs[start:start + eval_size, :])
        if use_cuda:
            code_emb, query_emb = code_emb.cuda(), query_emb.cuda()
        logits = query_emb @ code_emb.t()

        correct_scores = logits.diag()
        compared_scores = logits >= correct_scores.unsqueeze(dim=-1)
        mrr = 1 / compared_scores.sum(dim=-1).float()
        MRR.extend(mrr.tolist())
        batch_topk_prob, batch_topk_idx = logits.softmax(dim=-1).topk(k)
        batch_topk_idx = batch_topk_idx + start  # local -> global candidate indices
        topk_idx.extend(batch_topk_idx.tolist())
        topk_prob.extend(batch_topk_prob.tolist())

    if len(dataset) % eval_size:
        code_emb = torch.from_numpy(code_reprs[-eval_size:, :])
        query_emb = torch.from_numpy(query_reprs[-eval_size:, :])
        if use_cuda:
            code_emb, query_emb = code_emb.cuda(), query_emb.cuda()
        logits = query_emb @ code_emb.t()

        correct_scores = logits.diag()
        compared_scores = logits >= correct_scores.unsqueeze(dim=-1)
        last_ids = len(code_reprs) % eval_size
        mrr = 1 / compared_scores.sum(dim=-1).float()[-last_ids:]
        MRR.extend(mrr.tolist())
        batch_topk_prob, batch_topk_idx = logits[-last_ids:].softmax(
            dim=-1).topk(k)
        batch_topk_idx = batch_topk_idx + len(code_reprs) - eval_size
        topk_idx.extend(batch_topk_idx.tolist())
        topk_prob.extend(batch_topk_prob.tolist())

    print('mrr: {:.4f}'.format(np.mean(MRR)))

    for idx, mrr in enumerate(MRR):
        if mrr == 1.0 and topk_prob[idx][0] > 0.8:
            print(
                np.asarray(topk_idx[idx]) + 1,
                [round(prob, 4) for prob in topk_prob[idx]])
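Example #26 additionally keeps the retrieved candidates: because each logits matrix only scores candidates within the current chunk, the local top-k indices must be shifted by the chunk's global offset before they can index the full corpus. A small sketch of that bookkeeping (offset_topk is a hypothetical helper):

import torch

def offset_topk(logits: torch.Tensor, k: int, chunk_start: int):
    probs, local_idx = logits.softmax(dim=-1).topk(k)
    return probs, local_idx + chunk_start  # local -> global candidate indices

probs, global_idx = offset_topk(torch.rand(8, 8), k=3, chunk_start=16)
assert int(global_idx.min()) >= 16  # indices now refer to the full corpus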
Example #27
0
    def setup_task(cls, args, **kwargs):
        paths = utils.split_paths(args['task']['data'])
        assert len(paths) > 0
        dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt'))
        LOGGER.info('dictionary: {} types'.format(len(dictionary)))
        return cls(args, dictionary)
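setup_task relies on utils.split_paths to accept several shard directories in a single 'data' string, with the dictionary read from the first shard. Assuming it splits on the OS path separator, as fairseq-style codebases usually do, the behavior is roughly:

import os

def split_paths(paths: str):
    # assumption: shards are joined with os.pathsep, e.g. 'shard0:shard1' on Linux
    return paths.split(os.pathsep)

paths = split_paths('data-bin/shard0' + os.pathsep + 'data-bin/shard1')
assert len(paths) > 0
print(os.path.join(paths[0], 'dict.txt'))  # where Dictionary.load looks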
Example #28
0
def _main(args, output_file):
    if args['dataset']['max_tokens'] is None and args['dataset']['max_sentences'] is None:
        args['dataset']['max_tokens'] = 12000
    LOGGER.info(args)

    use_cuda = torch.cuda.is_available() and not args['common']['cpu']

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args['dataset']['gen_subset'])

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    LOGGER.info('loading model(s) from {}'.format(args['eval']['path']))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args['eval']['path']),
        arg_overrides=eval(args['eval']['model_overrides']),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        if _model_args['common']['fp16']:
            model.half()
        if use_cuda:
            model.cuda()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args['dataset']['gen_subset']),
        max_tokens=args['dataset']['max_tokens'],
        max_sentences=args['dataset']['max_sentences'],
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=_model_args['dataset']['skip_invalid_size_inputs_valid_test'],
        required_batch_size_multiple=_model_args['dataset']['required_batch_size_multiple'],
        num_shards=_model_args['dataset']['num_shards'],
        shard_id=_model_args['dataset']['shard_id'],
        num_workers=_model_args['dataset']['num_workers'],
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=_model_args['common']['log_format'],
        log_interval=_model_args['common']['log_interval'],
        default_log_format=('tqdm' if not _model_args['common']['no_progress_bar'] else 'none'),
    )

    """
    nohup python -m run.completion.seqrnn.eval > run/completion/seqrnn/case.log 2>&1 &
    """
    sequence_completor = task.build_completor([model], args)
    for sample in progress:
        torch.cuda.empty_cache()
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        non_pad_idx = sample['net_input']['src_tokens'] > task.target_dictionary.pad()

        with torch.no_grad():
            net_output = sequence_completor.generate([model], sample, prefix_tokens=None)
        lprobs = model.get_normalized_probs(net_output, log_probs=True)

        rank = torch.argmax(lprobs, dim=-1)
        target = model.get_targets(sample, net_output)
        accuracy = 1.0 * ((rank == target) & non_pad_idx).sum(dim=-1) / non_pad_idx.sum(dim=-1)
        for idx, (data_idx, acc) in enumerate(zip(sample['id'], accuracy)):
            if acc > 0.9:
                LOGGER.info(f"{data_idx}: {task.target_dictionary.string(sample['net_input']['src_tokens'][idx, :])}")
Example #29
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        paths = utils.split_paths(self.args['task']['data'])
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            path=split_path,
            dictionary=self.source_dictionary,
            dataset_impl=self.args['dataset']['dataset_impl'],
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, split_path))

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args['task']['tokens_per_sample'] - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args['task']['sample_break_mode'],
        )
        LOGGER.info('loaded {} blocks from: {}'.format(len(dataset),
                                                       split_path))

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

        # create masked input and targets
        mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
            if self.args['task']['mask_whole_words'] else None

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args['common']['seed'],
            mask_prob=self.args['task']['mask_prob'],
            leave_unmasked_prob=self.args['task']['leave_unmasked_prob'],
            random_token_prob=self.args['task']['random_token_prob'],
            freq_weighted_replacement=self.args['task']['freq_weighted_replacement'],
            mask_whole_words=mask_whole_words,
        )

        with data_utils.numpy_seed(self.args['common']['seed'] + epoch):
            shuffle = np.random.permutation(len(src_dataset))

        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    'id': IdDataset(),
                    'net_input': {
                        'src_tokens': PadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                            left_pad=False,
                        ),
                        'src_lengths': NumelDataset(src_dataset, reduce=False),
                    },
                    'target': PadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'nsentences': NumSamplesDataset(),
                    'ntokens': NumelDataset(src_dataset, reduce=True),
                },
                sizes=[src_dataset.sizes],
            ),
            sort_order=[
                shuffle,
                src_dataset.sizes,
            ],
        )
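MaskTokensDataset.apply_mask above implements BERT-style corruption: mask_prob of the positions become prediction targets, and of those, leave_unmasked_prob keep the original token and random_token_prob receive a random vocabulary item, with the remainder replaced by mask_idx. The sketch below reproduces that split in plain NumPy; it is illustrative only, not the dataset's actual implementation.

import numpy as np

def mask_tokens(tokens, vocab_size, mask_idx, mask_prob=0.15,
                leave_unmasked_prob=0.1, random_token_prob=0.1, seed=0):
    rng = np.random.default_rng(seed)
    tokens = np.asarray(tokens)
    sel = rng.random(len(tokens)) < mask_prob     # positions to predict
    target = np.full_like(tokens, -1)             # -1 = not a prediction target
    target[sel] = tokens[sel]
    src = tokens.copy()
    u = rng.random(len(tokens))
    src[sel & (u < 1 - leave_unmasked_prob - random_token_prob)] = mask_idx
    rand_pos = sel & (u >= 1 - random_token_prob)  # middle band stays unmasked
    src[rand_pos] = rng.integers(0, vocab_size, rand_pos.sum())
    return src, target

src, tgt = mask_tokens(np.arange(10, 30), vocab_size=100, mask_idx=3)
print(src, tgt)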
Example #30
0
def _main(args, output_file):
    if args['dataset']['max_tokens'] is None and args['dataset']['max_sentences'] is None:
        args['dataset']['max_tokens'] = 12000
    LOGGER.info(args)

    use_cuda = torch.cuda.is_available() and not args['common']['cpu']

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args['dataset']['gen_subset'])

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    LOGGER.info('loading model(s) from {}'.format(args['eval']['path']))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args['eval']['path']),
        arg_overrides=eval(args['eval']['model_overrides']),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None
            if args['eval']['no_beamable_mm'] else args['eval']['beam'],
            need_attn=args['eval']['print_alignment'],
        )
        if _model_args['common']['fp16']:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args['eval']['replace_unk'])

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args['dataset']['gen_subset']),
        max_tokens=args['dataset']['max_tokens'],
        max_sentences=args['eval']['max_sentences'],
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=_model_args['dataset']['skip_invalid_size_inputs_valid_test'],
        required_batch_size_multiple=_model_args['dataset']['required_batch_size_multiple'],
        num_shards=_model_args['dataset']['num_shards'],
        shard_id=_model_args['dataset']['shard_id'],
        num_workers=_model_args['dataset']['num_workers'],
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=_model_args['common']['log_format'],
        log_interval=_model_args['common']['log_interval'],
        default_log_format=('tqdm'
                            if not _model_args['common']['no_progress_bar']
                            else 'none'),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    scorer = OrderedDict()
    if args['eval']['sacrebleu']:
        scorer['bleu'] = bleu_scorer.SacrebleuScorer()
    elif args['eval']['nltk_bleu']:
        scorer['bleu'] = bleu_scorer.NLTKBleuScorer()
    else:
        scorer['bleu'] = bleu_scorer.Scorer(tgt_dict.pad(), tgt_dict.eos(),
                                            tgt_dict.unk())
    # Optionally add a ROUGE scorer
    if args['eval']['rouge']:
        scorer['rouge'] = rouge_scorer.RougeScorer()
    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        torch.cuda.empty_cache()
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        prefix_tokens = None
        if args['eval']['prefix_size'] > 0:
            prefix_tokens = sample['target'][:, :args['eval']['prefix_size']]

        gen_timer.start()
        hypos = task.inference_step(generator, models, sample, prefix_tokens)
        num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample['id'].tolist()):
            has_target = sample['target'] is not None

            # Remove padding
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :],
                                                tgt_dict.pad()).int().cpu()

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(
                    args['dataset']['gen_subset']).src.get_original_text(
                        sample_id)
                target_str = task.dataset(
                    args['dataset']['gen_subset']).tgt.get_original_text(
                        sample_id)
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens,
                                              args['eval']['remove_bpe'])
                else:
                    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(target_tokens,
                                                 args['eval']['remove_bpe'],
                                                 escape_unk=True)

            if not args['eval']['quiet']:
                if src_dict is not None:
                    print('S-{}\t{}'.format(sample_id, src_str),
                          file=output_file)
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str),
                          file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:args['eval']['nbest']]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args['eval']['remove_bpe'],
                )

                if hypo_str == '.':
                    # ROUGE cannot score a hypothesis that is just '.'
                    continue

                if not args['eval']['quiet']:
                    score = hypo['score'] / math.log(2)  # convert to base 2
                    print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str),
                          file=output_file)
                    # positional scores: convert from base e to base 2
                    pos_scores = hypo['positional_scores'].div_(math.log(2)).tolist()
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join('{:.4f}'.format(x) for x in pos_scores)),
                        file=output_file)

                    if args['eval']['print_alignment']:
                        print('A-{}\t{}'.format(
                            sample_id, ' '.join([
                                '{}-{}'.format(src_idx, tgt_idx)
                                for src_idx, tgt_idx in alignment
                            ])),
                              file=output_file)

                    if args['eval']['print_step']:
                        print('I-{}\t{}'.format(sample_id, hypo['steps']),
                              file=output_file)

                    if args['eval']['retain_iter_history']:
                        for step, h in enumerate(hypo['history']):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h['tokens'].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print('E-{}_{}\t{}'.format(sample_id, step, h_str),
                                  file=output_file)

                # Score only the top hypothesis
                if has_target and j == 0:
                    if align_dict is not None or args['eval']['remove_bpe'] is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(
                            target_str, add_if_not_exist=True)
                    for metric in scorer:
                        if hasattr(scorer[metric], 'add_string'):
                            scorer[metric].add_string(target_str, hypo_str)
                        else:
                            scorer[metric].add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({'wps': round(wps_meter.avg)})
        num_sentences += sample['nsentences']

    LOGGER.info('NOTE: hypothesis and token scores are output in base 2')
    LOGGER.info(
        'Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        LOGGER.info('Generate {} with beam={}: {}'.format(
            args['dataset']['gen_subset'], args['eval']['beam'], ''.join(
                '\n{}:\n{}'.format(metric.upper(), value.score())
                for metric, value in scorer.items()
            )))

    return scorer
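The H- and P- lines above report scores in base 2: the models produce natural-log probabilities, so each score is divided by ln(2) to convert to bits, exactly as the "convert to base 2" comments state. A two-line check:

import math
import torch

ln_probs = torch.tensor([math.log(0.5), math.log(0.25)])  # natural-log probs
print((ln_probs / math.log(2)).tolist())                   # [-1.0, -2.0] in base 2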