Esempio n. 1
0
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test).

        Reads the ``input0`` (required) and ``input1`` (optional) indexed
        datasets found under ``self.args.data``, builds the nested sample
        dictionary, attaches a ``target`` entry when label data is present,
        and registers the result in ``self.datasets[split]``.

        Args:
            split (str): name of the split to load.
            combine (bool): forwarded to ``data_utils.load_indexed_dataset``.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            The dataset stored in ``self.datasets[split]``: a
            ``NestedDictionaryDataset``, wrapped in a ``SortDataset`` unless
            ``self.args.no_shuffle`` is set.

        Raises:
            AssertionError: if loading ``input0`` yields ``None``, or if a
                regression label line does not hold exactly
                ``self.args.num_classes`` values.
        """
        def get_path(key, split):
            # Data is laid out as <data>/<key>/<split>.
            return os.path.join(self.args.data, key, split)

        def make_dataset(key, dictionary):
            return data_utils.load_indexed_dataset(
                get_path(key, split),
                dictionary,
                self.args.dataset_impl,
                combine=combine,
            )

        input0 = make_dataset('input0', self.source_dictionary)
        # Report the path that was actually looked up.
        assert input0 is not None, 'could not find dataset: {}'.format(get_path('input0', split))
        input1 = make_dataset('input1', self.source_dictionary)

        if self.args.init_token is not None:
            input0 = PrependTokenDataset(input0, self.args.init_token)

        if input1 is None:
            src_tokens = input0
        else:
            if self.args.separator_token is not None:
                input1 = PrependTokenDataset(input1, self.args.separator_token)

            src_tokens = ConcatSentencesDataset(input0, input1)

        # The permutation is drawn under a fixed seed, before any shortening.
        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_tokens))

        src_tokens = maybe_shorten_dataset(
            src_tokens,
            split,
            self.args.shorten_data_split_whitelist,
            self.args.shorten_method,
            self.args.max_positions,
            self.args.seed,
        )

        dataset = {
            'id': IdDataset(),
            'net_input': {
                'src_tokens': RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad(),
                ),
                'src_lengths': NumelDataset(src_tokens, reduce=False),
            },
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(src_tokens, reduce=True),
        }

        if self.args.add_prev_output_tokens:
            prev_tokens_dataset = RightPadDataset(
                RollDataset(src_tokens, 1),
                pad_idx=self.dictionary.pad(),
            )
            dataset['net_input'].update(
                prev_output_tokens=prev_tokens_dataset,
            )

        if not self.args.regression_target:
            # Classification: labels come from an indexed dataset; eos is
            # stripped and ids are shifted down by the number of special
            # symbols in the label dictionary.
            label_dataset = make_dataset('label', self.label_dictionary)
            if label_dataset is not None:
                dataset.update(
                    target=OffsetTokensDataset(
                        StripTokenDataset(
                            label_dataset,
                            id_to_strip=self.label_dictionary.eos(),
                        ),
                        offset=-self.label_dictionary.nspecial,
                    )
                )
        else:
            # Regression: labels are whitespace-separated floats, one sample
            # per line, in "<data>/label/<split>.label".
            label_path = "{0}.label".format(get_path('label', split))
            if os.path.exists(label_path):
                def parse_regression_target(i, line):
                    values = line.split()
                    assert len(values) == self.args.num_classes, \
                        f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"'
                    return [float(x) for x in values]

                # Context manager so the label file is closed after reading.
                with open(label_path) as h:
                    dataset.update(
                        target=RawLabelDataset([
                            parse_regression_target(i, line.strip())
                            for i, line in enumerate(h)
                        ])
                    )

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[src_tokens.sizes],
        )

        if self.args.no_shuffle:
            dataset = nested_dataset
        else:
            dataset = SortDataset(
                nested_dataset,
                # shuffle
                sort_order=[shuffle],
            )

        logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]
Esempio n. 2
0
    def load_dataset(self, split, combine=False, **kwargs):
        """Build the dataset for ``split`` and register it on the task.

        ``input0`` is mandatory and ``input1`` optional; both are read from
        sub-directories of ``self.cfg.data``. A ``target`` entry is added when
        label data is available (indexed labels, or a ``.label`` text file
        when ``self.cfg.regression_target`` is set).

        Args:
            split (str): name of the split (e.g., train, valid, test).
            combine (bool): forwarded to ``data_utils.load_indexed_dataset``.
            **kwargs: accepted but not used here.

        Returns:
            The dataset stored under ``self.datasets[split]``.
        """
        def get_path(key, split):
            return os.path.join(self.cfg.data, key, split)

        def make_dataset(key, dictionary):
            split_path = get_path(key, split)
            try:
                return data_utils.load_indexed_dataset(
                    split_path, dictionary, combine=combine)
            except Exception as e:
                # Only the "path not found" storage error is tolerated;
                # anything else propagates unchanged.
                if "StorageException: [404] Path not found" not in str(e):
                    raise e
                logger.warning(f"dataset {e} not found")
                return None

        first_sentence = make_dataset("input0", self.source_dictionary)
        assert first_sentence is not None, "could not find dataset: {}".format(
            get_path("input0", split))
        second_sentence = make_dataset("input1", self.source_dictionary)

        if self.cfg.init_token is not None:
            first_sentence = PrependTokenDataset(
                first_sentence, self.cfg.init_token)

        if second_sentence is not None:
            if self.cfg.separator_token is not None:
                second_sentence = PrependTokenDataset(
                    second_sentence, self.cfg.separator_token)
            src_tokens = ConcatSentencesDataset(first_sentence, second_sentence)
        else:
            src_tokens = first_sentence

        # Seeded permutation, drawn before the dataset is shortened.
        with data_utils.numpy_seed(self.cfg.seed):
            order = np.random.permutation(len(src_tokens))

        src_tokens = maybe_shorten_dataset(
            src_tokens,
            split,
            self.cfg.shorten_data_split_list,
            self.cfg.shorten_method,
            self.max_positions(),
            self.cfg.seed,
        )

        sample = {"id": IdDataset()}
        sample["net_input"] = {
            "src_tokens": RightPadDataset(
                src_tokens, pad_idx=self.source_dictionary.pad()),
            "src_lengths": NumelDataset(src_tokens, reduce=False),
        }
        sample["nsentences"] = NumSamplesDataset()
        sample["ntokens"] = NumelDataset(src_tokens, reduce=True)

        if self.cfg.add_prev_output_tokens:
            sample["net_input"]["prev_output_tokens"] = RightPadDataset(
                RollDataset(src_tokens, 1),
                pad_idx=self.dictionary.pad(),
            )

        if self.cfg.regression_target:
            label_path = "{0}.label".format(get_path("label", split))
            if os.path.exists(label_path):

                def parse_regression_target(i, line):
                    values = line.split()
                    assert (
                        len(values) == self.cfg.num_classes
                    ), f'expected num_classes={self.cfg.num_classes} regression target values on line {i}, found: "{line}"'
                    return [float(x) for x in values]

                with open(label_path) as label_file:
                    targets = [
                        parse_regression_target(i, line.strip())
                        for i, line in enumerate(label_file)
                    ]
                    sample["target"] = RawLabelDataset(targets)
        else:
            label_dataset = make_dataset("label", self.label_dictionary)
            if label_dataset is not None:
                stripped = StripTokenDataset(
                    label_dataset,
                    id_to_strip=self.label_dictionary.eos(),
                )
                sample["target"] = OffsetTokensDataset(
                    stripped,
                    offset=-self.label_dictionary.nspecial,
                )

        nested = NestedDictionaryDataset(sample, sizes=[src_tokens.sizes])

        # Keep the natural order when shuffling is disabled; otherwise sort
        # by the seeded permutation.
        if self.cfg.no_shuffle:
            result = nested
        else:
            result = SortDataset(nested, sort_order=[order])

        logger.info("Loaded {0} with #samples: {1}".format(split, len(result)))

        self.datasets[split] = result
        return self.datasets[split]
Esempio n. 3
0
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test).

        Reads the ``data`` indexed dataset (and the ``label`` one when it is
        available) from sub-directories of ``self.args.data``, builds the
        nested sample dictionary and registers it in ``self.datasets[split]``.

        Args:
            split (str): name of the split to load.
            combine (bool): forwarded to ``data_utils.load_indexed_dataset``.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            The dataset stored in ``self.datasets[split]``.

        Raises:
            AssertionError: if loading the ``data`` dataset yields ``None``.
        """
        def get_path(key, split):
            return os.path.join(self.args.data, key, split)

        def make_dataset(key, dictionary):
            return data_utils.load_indexed_dataset(
                get_path(key, split),
                dictionary,
                self.args.dataset_impl,
                combine=combine,
            )

        src_tokens = make_dataset('data', self.source_dictionary)
        # Fail with a clear message here instead of an opaque error further
        # down when the input split is missing.
        assert src_tokens is not None, 'could not find dataset: {}'.format(
            get_path('data', split))

        if self.args.init_token is not None:
            src_tokens = PrependTokenDataset(src_tokens, self.args.init_token)

        # The permutation is drawn under a fixed seed, before truncation.
        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_tokens))

        if self.args.truncate_sequence:
            src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

        dataset = {
            'id': IdDataset(),
            'net_input': {
                'src_tokens':
                RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad(),
                ),
                'src_lengths':
                NumelDataset(src_tokens, reduce=False),
            },
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(src_tokens, reduce=True),
        }

        # Labels are optional: eos is stripped and ids are shifted down by
        # the number of special symbols in the label dictionary.
        label_dataset = make_dataset('label', self.label_dictionary)
        if label_dataset is not None:
            dataset.update(target=OffsetTokensDataset(
                StripTokenDataset(
                    label_dataset,
                    id_to_strip=self.label_dictionary.eos(),
                ),
                offset=-self.label_dictionary.nspecial,
            ))

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[src_tokens.sizes],
        )

        if self.args.no_shuffle:
            dataset = nested_dataset
        else:
            dataset = SortDataset(
                nested_dataset,
                # shuffle
                sort_order=[shuffle],
            )

        logger.info("Loaded {0} with #samples: {1}".format(
            split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]
Esempio n. 4
0
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test).

        Reads ``input0`` (required) and ``input1`` (optional) from
        sub-directories of ``self.args.data``, builds the nested sample
        dictionary, attaches a ``target`` entry when label data is present,
        and registers the result in ``self.datasets[split]``.

        Args:
            split (str): name of the split to load.
            combine (bool): forwarded to ``data_utils.load_indexed_dataset``.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            The dataset stored in ``self.datasets[split]``.

        Raises:
            AssertionError: if loading ``input0`` yields ``None``.
            ValueError: if a regression label line is not a valid float.
        """
        def get_path(key, split):
            return os.path.join(self.args.data, key, split)

        def make_dataset(key, dictionary):
            # Use the dictionary the caller asked for, so that labels are
            # loaded with the target dictionary rather than the source one.
            return data_utils.load_indexed_dataset(
                get_path(key, split),
                dictionary,
                self.args.dataset_impl,
                combine=combine,
            )

        input0 = make_dataset('input0', self.source_dictionary)
        # Report the path that was actually looked up.
        assert input0 is not None, 'could not find dataset: {}'.format(
            get_path('input0', split))
        input1 = make_dataset('input1', self.source_dictionary)

        if self.args.init_token is not None:
            input0 = PrependTokenDataset(input0, self.args.init_token)

        if input1 is None:
            src_tokens = input0
        else:
            if self.args.separator_token is not None:
                input1 = PrependTokenDataset(input1, self.args.separator_token)

            src_tokens = ConcatSentencesDataset(input0, input1)

        # The permutation is drawn under a fixed seed, before truncation.
        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_tokens))

        if self.args.truncate_sequence:
            src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

        dataset = {
            'id': IdDataset(),
            'net_input': {
                'src_tokens':
                RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad(),
                ),
                'src_lengths':
                NumelDataset(src_tokens, reduce=False),
            },
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(src_tokens, reduce=True),
        }

        if not self.args.regression_target:
            label_dataset = make_dataset('label', self.target_dictionary)
            if label_dataset is not None:
                dataset.update(target=OffsetTokensDataset(
                    StripTokenDataset(
                        label_dataset,
                        id_to_strip=self.target_dictionary.eos(),
                    ),
                    offset=-self.target_dictionary.nspecial,
                ))
        else:
            # Regression: one float per line in "<data>/label/<split>.label".
            label_path = f"{get_path('label', split)}.label"
            if os.path.exists(label_path):
                # Context manager so the label file is closed after reading.
                with open(label_path) as h:
                    dataset.update(target=RawLabelDataset(
                        [float(x.strip()) for x in h]))

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[src_tokens.sizes],
        )

        if self.args.no_shuffle:
            dataset = nested_dataset
        else:
            dataset = SortDataset(
                nested_dataset,
                # shuffle
                sort_order=[shuffle],
            )

        print(f"| Loaded {split} with #samples: {len(dataset)}")

        self.datasets[split] = dataset
        return self.datasets[split]
Esempio n. 5
0
        def lang_dataset(lang):
            """Build the dataset for one language of the current split.

            Relies on ``self``, ``split``, ``make_dataset`` and ``get_path``
            from the enclosing scope (not visible in this block).

            Args:
                lang: language identifier forwarded to ``make_dataset`` and
                    ``get_path``.

            Returns:
                A ``NestedDictionaryDataset``, wrapped in a ``SortDataset``
                unless ``self.args.no_shuffle`` is set.

            Raises:
                AssertionError: if loading ``input0`` yields ``None``.
                ValueError: if a regression label line is not a valid float.
            """
            input0 = make_dataset('input0', lang, self.source_dictionary)
            assert input0 is not None, 'could not find dataset: {}'.format(
                get_path('input0', lang, split))
            input1 = make_dataset('input1', lang, self.source_dictionary)

            if self.args.init_token is not None:
                input0 = PrependTokenDataset(input0, self.args.init_token)

            if input1 is None:
                src_tokens = input0
            else:
                if self.args.separator_token is not None:
                    input1 = PrependTokenDataset(input1,
                                                 self.args.separator_token)

                src_tokens = ConcatSentencesDataset(input0, input1)

            # The permutation is drawn under a fixed seed, before truncation.
            with data_utils.numpy_seed(self.args.seed):
                shuffle = np.random.permutation(len(src_tokens))

            if self.args.truncate_sequence:
                src_tokens = TruncateDataset(src_tokens,
                                             self.args.max_positions)

            dataset = {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens':
                    RightPadDataset(
                        src_tokens,
                        pad_idx=self.source_dictionary.pad(),
                    ),
                    'src_lengths':
                    NumelDataset(src_tokens, reduce=False),
                },
                'nsentences': NumSamplesDataset(),
                'ntokens': NumelDataset(src_tokens, reduce=True),
            }

            if not self.args.regression_target:
                label_dataset = make_dataset('label', lang,
                                             self.target_dictionary)
                if label_dataset is not None:
                    dataset.update(target=OffsetTokensDataset(
                        StripTokenDataset(
                            label_dataset,
                            id_to_strip=self.target_dictionary.eos(),
                        ),
                        offset=-self.target_dictionary.nspecial,
                    ))
            else:
                # Regression: one float per line in the ".label" text file.
                label_path = "{0}.label".format(get_path('label', lang, split))
                if os.path.exists(label_path):
                    # Context manager so the label file is closed after
                    # reading.
                    with open(label_path) as h:
                        dataset.update(target=RawLabelDataset(
                            [float(x.strip()) for x in h]))

            nested_dataset = NestedDictionaryDataset(
                dataset,
                sizes=[src_tokens.sizes],
            )

            if self.args.no_shuffle:
                dataset = nested_dataset
            else:
                dataset = SortDataset(
                    nested_dataset,
                    # shuffle
                    sort_order=[shuffle],
                )

            print("| Loaded {0} with #samples: {1}".format(
                split, len(dataset)))
            return dataset
Esempio n. 6
0
def load_glue_data(
        task_path,
        mydict,
        mode='train'
):
    """Load the token and label datasets for one split of a GLUE-style task.

    ``input0`` (required) and ``input1`` (optional) are read with the
    ``'mmap'`` implementation from ``<task_path>/input{0,1}/<mode>``.
    ``mydict.bos()`` is prepended to ``input0`` and ``mydict.eos()`` to
    ``input1`` before the two are concatenated.

    When ``task_path`` contains ``'STS-B'`` the labels are read as one float
    per line from ``<task_path>/label/<mode>.label``; otherwise they are read
    as an indexed dataset using ``<task_path>/label/dict.txt``.

    Args:
        task_path (str): directory of the task.
        mydict: dictionary passed to ``data_utils.load_indexed_dataset`` for
            the input datasets.
        mode (str): split name (default ``'train'``).

    Returns:
        tuple: ``(src_tokens, label)`` datasets of equal length.

    Raises:
        AssertionError: if ``input0`` or the labels cannot be found, if a
            regression line does not hold exactly one value, or if the
            number of samples and labels differ.
    """
    # NOTE(review): an earlier comment here described the result as a list of
    # per-document matrices of node values; that does not match what this
    # function returns now.
    input0_path = os.path.join(task_path, 'input0', mode)
    input0 = data_utils.load_indexed_dataset(
        input0_path,
        mydict,
        'mmap',
        combine=False,
    )
    assert input0 is not None, 'could not find dataset: {}'.format(
        input0_path)

    input1 = data_utils.load_indexed_dataset(
        os.path.join(task_path, 'input1', mode),
        mydict,
        'mmap',
        combine=False,
    )
    input0 = PrependTokenDataset(input0, mydict.bos())
    if input1 is None:
        src_tokens = input0
    else:
        input1 = PrependTokenDataset(input1, mydict.eos())
        src_tokens = ConcatSentencesDataset(input0, input1)

    label_prefix = os.path.join(task_path, 'label', mode)
    # Stays None when no label data is found, so the failure below is
    # explicit rather than an unbound-variable error.
    label = None

    if 'STS-B' not in task_path:
        label_dictionary = Dictionary.load(
            os.path.join(task_path, 'label', 'dict.txt'))
        label_dictionary.add_symbol('<mask>')
        label_dataset = data_utils.load_indexed_dataset(
            label_prefix,
            label_dictionary,
            'mmap',
            combine=False,
        )
        if label_dataset is not None:
            # Strip eos and shift ids down by the number of special symbols.
            label = OffsetTokensDataset(
                StripTokenDataset(
                    label_dataset,
                    id_to_strip=label_dictionary.eos(),
                ),
                offset=-label_dictionary.nspecial,
            )
    else:
        label_path = "{0}.label".format(label_prefix)
        if os.path.exists(label_path):

            def parse_regression_target(i, line):
                values = line.split()
                assert len(values) == 1, \
                    f'expected 1 regression target value on line {i}, found: "{line}"'
                return [float(x) for x in values]

            with open(label_path) as h:
                label = RawLabelDataset([
                    parse_regression_target(i, line.strip())
                    for i, line in enumerate(h)
                ])

    assert label is not None, 'could not find labels: {}'.format(label_prefix)
    print('data size: ', len(src_tokens), len(label))
    assert len(src_tokens) == len(label)

    return src_tokens, label
Esempio n. 7
0
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test).

        Reads four indexed datasets from sub-directories of
        ``self.args.data``: ``input0`` (source tokens), ``input1`` (segment
        tokens), ``label0`` and ``label1`` (two target sequences). All
        padded fields are padded to ``self._max_positions``. The result is
        registered in ``self.datasets[split]``.

        Args:
            split (str): name of the split to load.
            combine (bool): forwarded to ``data_utils.load_indexed_dataset``.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            The dataset stored in ``self.datasets[split]``.

        Raises:
            AssertionError: if any of the four datasets loads as ``None``.
        """

        def get_path(key, split):
            return os.path.join(self.args.data, key, split)

        def make_dataset(key, dictionary):
            dataset = data_utils.load_indexed_dataset(
                get_path(key, split),
                dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            assert dataset is not None, "could not find dataset: {}".format(
                get_path(key, split))
            return dataset

        def make_target(label_dataset, dictionary):
            """Build one padded target field from an indexed label dataset."""
            return RightPadDataset(  # pad with -1 (presumably masked out when computing the loss)
                ReplaceDataset(  # replace eos and existing padding (used when some tokens should not be predicted) with -1
                    OffsetTokensDataset(  # offset tokens to get the targets to the correct range (0,1,2,...)
                        label_dataset,
                        offset=-dictionary.nspecial,
                    ),
                    replace_map={
                        dictionary.eos() - dictionary.nspecial: -1,
                        dictionary.pad() - dictionary.nspecial: -1,
                    },
                    # Builtin ``int`` is what the removed ``np.int`` alias
                    # referred to, so the dtype is unchanged.
                    offsets=np.zeros(len(label_dataset), dtype=int),
                ),
                pad_idx=-1,
                pad_to_length=self._max_positions,
            )

        src_tokens = make_dataset("input0", self.source_dictionary)
        pos_tokens = make_dataset("input1", self.pos_dictionary)

        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_tokens))

        label0_dataset = make_dataset("label0", self.label0_dictionary)
        label1_dataset = make_dataset("label1", self.label1_dictionary)

        dataset = {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad(),
                    pad_to_length=self._max_positions,
                ),
                "src_lengths": NumelDataset(src_tokens, reduce=False),
            },
            "segments": {
                "seg_tokens": RightPadDataset(
                    pos_tokens,
                    pad_idx=self.pos_dictionary.pad(),
                    pad_to_length=self._max_positions,
                ),
                "seg_lengths": NumelDataset(pos_tokens, reduce=False),
            },
            "target0": make_target(label0_dataset, self.label0_dictionary),
            "target1": make_target(label1_dataset, self.label1_dictionary),
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens, reduce=True),
        }

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[src_tokens.sizes],
        )

        if self.args.no_shuffle:
            dataset = nested_dataset
        else:
            dataset = SortDataset(
                nested_dataset,
                # shuffle
                sort_order=[shuffle],
            )
        logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))
        self.datasets[split] = dataset
        return self.datasets[split]
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a source/tag dataset pair for ``split`` and register it.

        The shard is picked from ``self.cfg.data`` by ``epoch``; splits other
        than ``self.cfg.train_subset`` always use the first shard. Tag ids
        are shifted so that they start right after the source pad index,
        which is also used as the padding value of the targets. The first
        sample is logged as symbols, and the dataset is stored in
        ``self.datasets[split]``. Nothing is returned.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number used to select the shard.
            combine (bool): accepted but not used here.
            **kwargs: accepted but not used here.
        """
        shards = utils.split_paths(self.cfg.data)
        assert len(shards) > 0
        is_train_split = split == getattr(self.cfg, "train_subset", None)
        if not is_train_split:
            # validation and test data always come from the first shard
            shards = shards[:1]
        data_path = shards[(epoch - 1) % len(shards)]

        # language codes come straight from the config
        src = self.cfg.source_lang
        tgt = self.cfg.target_lang
        prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))

        raw_src = data_utils.load_indexed_dataset(
            prefix + src, self.src_dict, self.cfg.dataset_impl)
        raw_tag = data_utils.load_indexed_dataset(
            prefix + tgt, self.tag_dict, self.cfg.dataset_impl)

        src_dataset = StripTokenDataset(
            raw_src, id_to_strip=self.source_dictionary.eos())
        tag_dataset = StripTokenDataset(
            raw_tag, id_to_strip=self.tag_dictionary.eos())

        tag_pad = self.source_dictionary.pad()
        tag_offset = tag_pad + 1

        fields = {'id': IdDataset()}
        fields['net_input'] = {
            'src_tokens': RightPadDataset(
                src_dataset, pad_idx=self.source_dictionary.pad()),
            'src_lengths': NumelDataset(src_dataset, reduce=False),
        }
        fields['nsentences'] = NumSamplesDataset()
        fields['ntokens'] = NumelDataset(src_dataset, reduce=True)
        fields['target'] = RightPadDataset(
            OffsetTokensDataset(
                tag_dataset,
                offset=-self.tag_dictionary.nspecial + tag_offset,
            ),
            pad_idx=tag_pad,
        )

        dataset = NestedDictionaryDataset(fields, sizes=[src_dataset.sizes])

        # Log the first sample as symbols, undoing the tag shift.
        first_src_symbols = [
            self.src_dict[tok] for tok in dataset[0]['net_input.src_tokens']
        ]
        logger.info(str(first_src_symbols))
        first_tag_symbols = [
            self.tag_dict[tag + self.tag_dictionary.nspecial - tag_offset]
            for tag in dataset[0]['target']
        ]
        logger.info(str(first_tag_symbols))

        self.datasets[split] = dataset
Esempio n. 9
0
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test).

        Reads the single-sentence input from ``<data>/<split>`` and, when
        present, the labels from ``<data>/<split>.label``, builds the nested
        sample dictionary and registers it in ``self.datasets[split]``.

        Args:
            split (str): name of the split to load.
            combine (bool): forwarded to ``data_utils.load_indexed_dataset``.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            The dataset stored in ``self.datasets[split]``.

        Raises:
            AssertionError: if loading the input dataset yields ``None``.
            ValueError: if a regression label line is not a valid float.
        """
        def get_path(split):
            return os.path.join(self.args.data, split)

        def make_dataset(split_path, dictionary):
            # Use the dictionary the caller asked for, so that labels are
            # loaded with the target dictionary rather than the source one.
            return data_utils.load_indexed_dataset(
                split_path,
                dictionary,
                self.args.dataset_impl,
                combine=combine)

        input_path = get_path(split)
        label_path = get_path(split + '.label')

        input0 = make_dataset(input_path, self.source_dictionary)
        assert input0 is not None, 'could not find dataset: {}'.format(
            input_path)

        if self.args.init_token is not None:
            input0 = PrependTokenDataset(input0, self.args.init_token)

        src_tokens = input0

        # The permutation is drawn under a fixed seed, before truncation.
        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_tokens))

        if self.args.truncate_sequence:
            src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

        dataset = {
            'id': IdDataset(),
            'net_input': {
                'src_tokens': RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad()),
                'src_lengths': NumelDataset(src_tokens, reduce=False)},
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(src_tokens, reduce=True),
        }

        if self.args.add_prev_output_tokens:
            prev_tokens_dataset = RightPadDataset(
                RollDataset(src_tokens, 1),
                pad_idx=self.dictionary.pad())
            dataset['net_input'].update(
                prev_output_tokens=prev_tokens_dataset)

        if not self.args.regression_target:
            label_dataset = make_dataset(label_path, self.target_dictionary)
            if label_dataset is not None:
                dataset.update(
                    target=OffsetTokensDataset(
                        StripTokenDataset(
                            label_dataset,
                            id_to_strip=self.target_dictionary.eos()),
                        offset=-self.target_dictionary.nspecial,
                    )
                )
        else:
            # Regression: one float per line in "<data>/<split>.label".
            if os.path.exists(label_path):
                # Context manager so the label file is closed after reading.
                with open(label_path) as h:
                    dataset.update(target=RawLabelDataset(
                        [float(x.strip()) for x in h]))

        nested_dataset = NestedDictionaryDataset(dataset, sizes=[src_tokens.sizes])

        if self.args.no_shuffle:
            dataset = nested_dataset
        else:
            dataset = SortDataset(
                nested_dataset,
                sort_order=[shuffle])  # shuffle

        print("| Loaded {0} with #samples: {1}".format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]
Esempio n. 10
0
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test).

        Following the convention used in this method, ``args.input0`` names
        the source data, ``args.input1`` the synthetic target and
        ``args.input2`` the (optional) reference.  The assembled dataset is
        stored in ``self.datasets[split]``.

        Args:
            split: name of the split; it is also the last path component
                looked up under ``<args.data>/<subdir>/``.
            combine: forwarded to ``data_utils.load_indexed_dataset``.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            The dataset stored in ``self.datasets[split]``.

        Raises:
            AssertionError: if the ``input0`` dataset could not be loaded, or
                if a regression label line does not contain exactly
                ``args.num_classes`` values.
        """
        def get_path(subdir, split):
            return os.path.join(self.args.data, subdir, split)

        def make_dataset(subdir, dictionary):
            return data_utils.load_indexed_dataset(
                get_path(subdir, split),
                dictionary,
                self.args.dataset_impl,
                combine=combine,
            )

        # input0 is source, input1 is synthetic target, input2 is reference
        input0 = make_dataset(self.args.input0, self.source_dictionary)
        # The message must name the path that was actually probed; passing the
        # builtin ``type`` here would make os.path.join raise a TypeError and
        # hide the real problem.
        assert input0 is not None, 'could not find dataset: {}'.format(
            get_path(self.args.input0, split))
        input1 = make_dataset(self.args.input1, self.source_dictionary)

        if self.args.init_token is not None:
            input0 = PrependTokenDataset(input0, self.args.init_token)

        has_reference = self.args.input2 is not None
        input2 = None
        if has_reference:
            input2 = make_dataset(self.args.input2, self.source_dictionary)

        # The reference is only mixed into the model input outside of 'valid'.
        use_reference = (has_reference and self.add_ref_prob > 0
                         and split != 'valid')
        if use_reference:
            input3 = PrependTokenDataset(input2, self.args.separator_token)
        else:
            input3 = None

        if input1 is None:
            src_tokens = input0
        else:
            if self.args.separator_token is not None:
                input1 = PrependTokenDataset(input1, self.args.separator_token)

            if use_reference:
                src_tokens = ConcatSentencesDataset(
                    input0,
                    input3,
                    input1,
                    add_ref_prob=self.add_ref_prob,
                    drop_ref_rate=self.args.dropout_ref,
                    pad_idx=self.source_dictionary.pad(),
                    eos_idx=self.source_dictionary.eos(),
                    bos_idx=self.source_dictionary.bos())
            else:
                src_tokens = ConcatSentencesDataset(input0, input1)

        # Draw the permutation under the configured seed; it is used below as
        # the sort order of the final dataset.
        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(src_tokens))

        if self.args.truncate_sequence:
            src_tokens = TruncateDataset(src_tokens, self.args.max_positions)

        use_tran_loss = has_reference and self.args.add_tran_loss
        if use_tran_loss:
            # create masked input and targets
            mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
                if self.args.mask_whole_words else None
            ref_dataset, ref_target_dataset = MaskTokensDataset.apply_mask(
                input2,
                self.source_dictionary,
                pad_idx=self.source_dictionary.pad(),
                mask_idx=self.mask_idx,
                seed=self.args.seed,
                mask_prob=self.args.mask_prob,
                leave_unmasked_prob=self.args.leave_unmasked_prob,
                random_token_prob=self.args.random_token_prob,
                freq_weighted_replacement=self.args.freq_weighted_replacement,
                mask_whole_words=mask_whole_words,
            )

            # NOTE(review): without a separator token the *unmasked* input2
            # (not ref_dataset) is concatenated below -- confirm this is the
            # intended behaviour.
            if self.args.separator_token is not None:
                input2 = PrependTokenDataset(ref_dataset,
                                             self.args.separator_token)
            parallel_src_tokens = ConcatSentencesDataset(input0, input2)
            if self.args.truncate_sequence:
                parallel_src_tokens = TruncateDataset(parallel_src_tokens,
                                                      self.args.max_positions)

        dataset = {
            'id': IdDataset(),
            'net_input': {
                'src_tokens':
                RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad(),
                ),
                'src_lengths':
                NumelDataset(src_tokens, reduce=False),
            },
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(src_tokens, reduce=True),
        }

        if use_tran_loss:
            dataset['net_input']['parallel_src_tokens'] = RightPadDataset(
                parallel_src_tokens,
                pad_idx=self.source_dictionary.pad(),
            )

        if self.args.add_prev_output_tokens:
            prev_tokens_dataset = RightPadDataset(
                RollDataset(src_tokens, 1),
                pad_idx=self.dictionary.pad(),
            )
            dataset['net_input'].update(
                prev_output_tokens=prev_tokens_dataset, )

        if not self.args.regression_target:
            label_dataset = make_dataset('label', self.label_dictionary)
            if label_dataset is not None:
                dataset.update(target=OffsetTokensDataset(
                    StripTokenDataset(
                        label_dataset,
                        id_to_strip=self.label_dictionary.eos(),
                    ),
                    offset=-self.label_dictionary.nspecial,
                ))
            if use_tran_loss:
                # used as translation target when calculating loss
                dataset.update(parallel_target=RightPadDataset(
                    ref_target_dataset,
                    pad_idx=self.source_dictionary.pad(),
                ))
        else:
            label_path = "{0}.label".format(get_path('label', split))
            if os.path.exists(label_path):

                def parse_regression_target(i, line):
                    values = line.split()
                    assert len(values) == self.args.num_classes, \
                        f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"'
                    return [float(x) for x in values]

                # Read inside a context manager so the file handle is closed
                # even if a line fails to parse.
                with open(label_path) as label_file:
                    regression_targets = [
                        parse_regression_target(i, line.strip())
                        for i, line in enumerate(label_file)
                    ]
                dataset.update(target=RawLabelDataset(regression_targets))

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[src_tokens.sizes],
            all_sizes=src_tokens.all_sizes
            if self.args.add_target_num_tokens else None,
            padding_idx=self.source_dictionary.pad(),
            add_ref_prob=self.add_ref_prob if split != 'valid' else 0.,
        )

        if self.args.no_shuffle:
            dataset = nested_dataset
        else:
            dataset = SortDataset(
                nested_dataset,
                # shuffle
                sort_order=[shuffle],
            )

        logger.info("Loaded {0} with #samples: {1}".format(
            split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]
Esempio n. 11
0
    def load_dataset(self, split, combine=False, **kwargs):
        """Build the dataset for ``split`` and register it in ``self.datasets``.

        Tokens are read from ``<args.data>/data/<split>`` and labels from
        ``<args.data>/label/<split>``; both are wrapped in a
        ``TruncateDataset`` with ``args.max_positions``, and the labels are
        additionally offset by ``-target_dictionary.nspecial``.

        Args:
            split: name of the split (e.g., train, valid, test).
            combine: forwarded to ``data_utils.load_indexed_dataset``.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            The dataset stored in ``self.datasets[split]``.

        Raises:
            AssertionError: if either the token or the label dataset could
                not be loaded.
        """
        def _split_path(subdir):
            # Location of this split below the given sub-directory.
            return os.path.join(self.args.data, subdir, split)

        def _load_indexed(subdir, vocab):
            # Thin wrapper so both loads share the same implementation/combine
            # settings.
            return data_utils.load_indexed_dataset(
                _split_path(subdir),
                vocab,
                self.args.dataset_impl,
                combine=combine,
            )

        tokens = _load_indexed('data', self.source_dictionary)
        assert tokens is not None, 'could not find dataset: {}'.format(
            _split_path('data'))
        tokens = TruncateDataset(tokens, self.args.max_positions)

        # Permutation drawn under the configured seed; used as the sort order
        # (i.e. the shuffle) further down.
        with data_utils.numpy_seed(self.args.seed):
            order = np.random.permutation(len(tokens))

        fields = {}
        fields['id'] = IdDataset()
        fields['net_input'] = {
            'src_tokens': tokens,
            'src_lengths': NumelDataset(tokens, reduce=False),
        }
        fields['nsentences'] = NumSamplesDataset()
        fields['ntokens'] = NumelDataset(tokens, reduce=True)

        labels = _load_indexed('label', self.target_dictionary)
        assert labels is not None, 'could not find dataset: {}'.format(
            _split_path('label'))
        labels = TruncateDataset(labels, self.args.max_positions)
        fields['target'] = OffsetTokensDataset(
            labels,
            offset=-self.target_dictionary.nspecial,
        )

        nested = NestedDictionaryDataset(fields, sizes=[tokens.sizes])

        if not self.args.no_shuffle:
            loaded = SortDataset(nested, sort_order=[order])
        else:
            loaded = nested

        print("| Loaded {0} with #samples: {1}".format(split, len(loaded)))

        self.datasets[split] = loaded
        return self.datasets[split]