Code example #1
 def setUp(self):
     self.strings = ["ab", "c", "def", "ghij"]
     self.weights = [4.0, 2.0, 7.0, 1.5]
     self.size_ratio = 2
     self.dataset = ListDataset(
         self.strings, np.array([len(s) for s in self.strings])
     )
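Not part of the test above: a minimal sketch, assuming fairseq's standard fairseq.data.ListDataset, of what the wrapper actually does. It stores a plain Python list together with an optional sizes array, returns items unchanged on indexing, and exposes the sizes so batching utilities can sort and bucket by length.

import numpy as np
from fairseq.data import ListDataset

strings = ["ab", "c", "def", "ghij"]
dataset = ListDataset(strings, sizes=np.array([len(s) for s in strings]))

assert len(dataset) == 4
assert dataset[2] == "def"                   # items pass through untouched
assert list(dataset.sizes) == [2, 1, 3, 4]   # per-item sizes used for sorting/bucketing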
Code example #2
    def prepare_tokens(self, tokens: torch.Tensor):
        sizes = [len(seq) for seq in tokens]
        src_tokens = ListDataset(tokens, sizes=sizes)
        src_tokens = RightPadDataset(src_tokens,
                                     pad_idx=self.source_dictionary.pad())

        word_masks_w_bos = WordEndMaskDataset(src_tokens,
                                              self.dictionary,
                                              self.is_word_initial,
                                              bos_value=1,
                                              eos_value=0)

        dataset = {
            "id":
            IdDataset(),
            "net_input": {
                "src_tokens": src_tokens,
                "nsrc_tokens": NumelDataset(src_tokens),
                "word_mask_w_bos": RightPadDataset(word_masks_w_bos,
                                                   pad_idx=0),
            },
            "ntokens":
            NumelDataset(src_tokens, reduce=True),
            "nwords":
            NumWordsDataset(src_tokens, self.dictionary, self.is_word_initial),
            "nsentences":
            NumSamplesDataset(),
        }
        dataset = NestedDictionaryDatasetFix(dataset, sizes=[src_tokens.sizes])
        return dataset
Code example #3
    def build_dataset_for_inference(self,
                                    src_tokens,
                                    src_lengths,
                                    constraints=None):
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the multilingual_translation task is not supported"
            )

        src_data = ListDataset(src_tokens, src_lengths)
        dataset = LanguagePairDataset(src_data, src_lengths,
                                      self.source_dictionary)
        src_langtok_spec, tgt_langtok_spec = self.args.langtoks["main"]
        if self.args.lang_tok_replacing_bos_eos:
            dataset = self.data_manager.alter_dataset_langtok(
                dataset,
                src_eos=self.source_dictionary.eos(),
                src_lang=self.args.source_lang,
                tgt_eos=self.target_dictionary.eos(),
                tgt_lang=self.args.target_lang,
                src_langtok_spec=src_langtok_spec,
                tgt_langtok_spec=tgt_langtok_spec,
            )
        else:
            dataset.src = self.data_manager.src_dataset_tranform_func(
                self.args.source_lang,
                self.args.target_lang,
                dataset=dataset.src,
                spec=src_langtok_spec,
            )
        return dataset
Code example #4
File: test_iterators.py  Project: scheiblr/fairseq
def _get_epoch_batch_itr(ref, bsz, skip_remainder_batch):
    dsz = len(ref)
    indices = range(dsz)
    starts = indices[::bsz]
    batch_sampler = [indices[s:s + bsz] for s in starts]
    dataset = ListDataset(ref)
    itr = iterators.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=batch_sampler,
        skip_remainder_batch=skip_remainder_batch,
    )
    return itr.next_epoch_itr()
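A hypothetical way to drive the helper above, assuming a fairseq version where EpochBatchIterator accepts skip_remainder_batch; each yielded batch is the collated slice of ref, and the trailing partial batch is dropped when the flag is set.

ref = list(range(5))
itr = _get_epoch_batch_itr(ref, bsz=2, skip_remainder_batch=True)
for batch in itr:
    # two full batches of size 2; the final 1-element remainder is skipped
    print(batch)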
Code example #5
    def _set_up_train_dataset(self, split_range) -> torch.utils.data.Dataset:
        new_date_ranges = _date_list_from_arg(self.args.new_data_date_range)
        logger.info(
            f'Setting up training data: {split_range}, {new_date_ranges}')
        new_hive_data = HiveDataset(
            table=self.args.table,
            namespace=self.args.namespace,
            limit=self.args.query_limit,
            date_ranges=new_date_ranges,
            filter_fn=lambda x: _should_include(x[0], split_range),
        )

        desired_total_data_size = self.args.old_to_new_ratio * len(
            new_hive_data)
        desired_old_data_size = (
            1 - (1 / self.args.old_to_new_ratio)) * desired_total_data_size

        old_date_ranges = _date_list_from_arg(self.args.train_date_range)
        old_hive_data = HiveDataset(
            table=self.args.table,
            namespace=self.args.namespace,
            limit=min(self.args.query_limit, desired_old_data_size),
            date_ranges=old_date_ranges,
            filter_fn=lambda x: _should_include(x[0], split_range),
        )

        old_hive_data = old_hive_data[:int(desired_old_data_size)]

        all_data = new_hive_data.data + list(old_hive_data)
        conversations = ConversationDataset(
            dataset=ListDataset(dataset=_shuffle(all_data)),
            dictionary=self.dictionary,
            split_range=split_range,
        )
        logger.info(
            f"Created train dataset of size: {len(conversations)} conversations"
        )

        return conversations
Code example #6
 def build_dataset_for_inference(self, src_tokens, src_lengths):
     src_data = ListDataset(src_tokens, src_lengths)
     dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary)
     src_langtok_spec, tgt_langtok_spec = self.args.langtoks['main']
     if self.args.lang_tok_replacing_bos_eos:
         dataset = self.data_manager.alter_dataset_langtok(
                 dataset,
                 src_eos=self.source_dictionary.eos(),
                 src_lang=self.args.source_lang,
                 tgt_eos=self.target_dictionary.eos(),
                 tgt_lang=self.args.target_lang,
                 src_langtok_spec=src_langtok_spec,
                 tgt_langtok_spec=tgt_langtok_spec,
             )
     else:
         dataset.src = self.data_manager.src_dataset_tranform_func(
             self.args.source_lang,
             self.args.target_lang,
             dataset=dataset.src,
             spec=src_langtok_spec,
             )
     return dataset
Code example #7
File: arc_qa_task.py  Project: mayank97/fairseq
    def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        print("Split type --> " + str(split))

        def binarize(s, append_bos=False):
            if self.bpe is not None:
                s = self.bpe.encode(s)
            tokens = self.vocab.encode_line(
                s, append_eos=True, add_if_not_exist=False,
            ).long()
            if append_bos and self.args.init_token is not None:
                tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
            return tokens

        if data_path is None:
            data_path = os.path.join(self.args.data, split + '.jsonl')
        if not os.path.exists(data_path):
            raise FileNotFoundError('Cannot find data: {}'.format(data_path))

        src_tokens = [[] for i in range(self.args.num_classes)]
        src_lengths = [[] for i in range(self.args.num_classes)]
        labels = []

        with open(data_path) as h:
            for line in h:
                example = json.loads(line.strip())
                if 'answerKey' in example:
                    label = ord(example['answerKey']) - ord('A')
                    labels.append(label)
                question = example['question']['stem']
                if self.args.num_classes != len(example['question']['choices']):
                    print("Class size = " + str(self.args.num_classes) +
                          ". Length of sample size = " + str(len(example['question']['choices'])))
                assert len(example['question']['choices']) == self.args.num_classes
                # format: `<s> Q: Where would I not want a fox? </s> A: hen house </s>`
                question = 'Q: ' + question
                question_toks = binarize(question, append_bos=True)
                for i, choice in enumerate(example['question']['choices']):
                    src = 'A: ' + choice['text']
                    src_bin = torch.cat([question_toks, binarize(src)])
                    src_tokens[i].append(src_bin)
                    src_lengths[i].append(len(src_bin))
        assert all(len(src_tokens[0]) == len(src_tokens[i]) for i in range(self.args.num_classes))
        assert len(src_tokens[0]) == len(src_lengths[0])
        assert len(labels) == 0 or len(labels) == len(src_tokens[0])

        for i in range(self.args.num_classes):
            src_lengths[i] = np.array(src_lengths[i])
            src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i])
            src_lengths[i] = ListDataset(src_lengths[i])

        dataset = {
            'id': IdDataset(),
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(src_tokens[0], reduce=True),
        }

        for i in range(self.args.num_classes):
            dataset.update({
                'net_input{}'.format(i + 1): {
                    'src_tokens': RightPadDataset(
                        src_tokens[i],
                        pad_idx=self.source_dictionary.pad(),
                    ),
                    'src_lengths': src_lengths[i],
                }
            })

        if len(labels) > 0:
            dataset.update({'target': RawLabelDataset(labels)})

        dataset = NestedDictionaryDataset(
            dataset,
            sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
        )

        with data_utils.numpy_seed(self.args.seed):
            dataset = SortDataset(
                dataset,
                # shuffle
                sort_order=[np.random.permutation(len(dataset))],
            )

        print('| Loaded {} with {} samples'.format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]
Code example #8
File: wsc_task.py  Project: zhuchen03/FreeLB
    def load_dataset(self,
                     split,
                     epoch=0,
                     combine=False,
                     data_path=None,
                     return_only=False,
                     **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        def binarize(s: str, append_eos: bool = False):
            if self.tokenizer is not None:
                s = self.tokenizer.encode(s)
            if self.bpe is not None:
                s = self.bpe.encode(s)
            tokens = self.vocab.encode_line(
                s,
                append_eos=append_eos,
                add_if_not_exist=False,
            ).long()
            if self.args.init_token is not None:
                tokens = torch.cat(
                    [tokens.new([self.args.init_token]), tokens])
            return tokens

        if data_path is None:
            data_path = os.path.join(self.args.data, split + '.jsonl')
        if not os.path.exists(data_path):
            raise FileNotFoundError('Cannot find data: {}'.format(data_path))

        query_tokens = []
        query_masks = []
        query_lengths = []
        candidate_tokens = []
        candidate_masks = []
        candidate_lengths = []
        labels = []

        for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(
                data_path):
            prefix = sentence[:pronoun_span.start].text
            suffix = sentence[pronoun_span.end:].text_with_ws

            # spaCy spans include trailing spaces, but we need to know about
            # leading spaces for the GPT-2 BPE
            leading_space = (
                ' ' if sentence[:pronoun_span.start].text_with_ws.endswith(' ')
                else '')
            trailing_space = (
                ' ' if pronoun_span.text_with_ws.endswith(' ') else '')

            # get noun phrases, excluding pronouns and anything overlapping with the query
            cand_spans = wsc_utils.filter_noun_chunks(
                wsc_utils.extended_noun_chunks(sentence),
                exclude_pronouns=True,
                exclude_query=query,
                exact_match=False,
            )

            def binarize_with_mask(txt):
                toks = binarize(
                    prefix + leading_space + txt + trailing_space + suffix,
                    append_eos=True,
                )
                mask = torch.zeros_like(toks, dtype=torch.uint8)
                mask_start = len(binarize(prefix))
                mask_size = len(binarize(leading_space + txt))
                mask[mask_start:mask_start + mask_size] = 1
                return toks, mask

            if query is not None:
                query_toks, query_mask = binarize_with_mask(query)
                query_len = len(query_toks)
            else:
                query_toks, query_mask, query_len = None, None, 0

            query_tokens.append(query_toks)
            query_masks.append(query_mask)
            query_lengths.append(query_len)
            cand_toks, cand_masks = [], []
            for cand_span in cand_spans:
                toks, mask = binarize_with_mask(cand_span.text)
                cand_toks.append(toks)
                cand_masks.append(mask)

            # collate candidates
            cand_toks = data_utils.collate_tokens(cand_toks,
                                                  pad_idx=self.vocab.pad())
            cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
            assert cand_toks.size() == cand_masks.size()

            candidate_tokens.append(cand_toks)
            candidate_masks.append(cand_masks)
            candidate_lengths.append(cand_toks.size(1))

            labels.append(label)

        query_lengths = np.array(query_lengths)
        query_tokens = ListDataset(query_tokens, query_lengths)
        query_masks = ListDataset(query_masks, query_lengths)

        candidate_lengths = np.array(candidate_lengths)
        candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
        candidate_masks = ListDataset(candidate_masks, candidate_lengths)

        labels = ListDataset(labels, [1] * len(labels))

        dataset = {
            'id': IdDataset(),
            'query_tokens': query_tokens,
            'query_masks': query_masks,
            'candidate_tokens': candidate_tokens,
            'candidate_masks': candidate_masks,
            'labels': labels,
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(query_tokens, reduce=True),
        }

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[query_lengths],
        )

        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(query_tokens))
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

        if return_only:
            return dataset

        self.datasets[split] = dataset
        return self.datasets[split]
Code example #9
File: tacred_task.py  Project: yydxhn/KEPLER
    def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        def getIns(bped, bpeTokens, tokens, L, R):
            resL = 0
            tkL = " ".join(tokens[:L])
            bped_tkL = self.bpe.encode(tkL)
            if bped.find(bped_tkL) == 0:
                resL = len(bped_tkL.split())
            else:
                tkL += " "
                bped_tkL = self.bpe.encode(tkL)
                if bped.find(bped_tkL) == 0:
                    resL = len(bped_tkL.split())
            resR = 0
            tkR = " ".join(tokens[R:])
            bped_tkR = self.bpe.encode(tkR)
            if bped.rfind(bped_tkR) + len(bped_tkR) == len(bped):
                resR = len(bpeTokens) - len(bped_tkR.split())
            else:
                tkR = " " + tkR
                bped_tkR = self.bpe.encode(tkR)
                if bped.rfind(bped_tkR) + len(bped_tkR) == len(bped):
                    resR = len(bpeTokens) - len(bped_tkR.split())
            return resL, resR

        def getExample(a, bias):
            s = " ".join(a["token"])
            ss = self.bpe.encode(s)
            sst = ss.split()
            headL = a['h']['pos'][0]
            headR = a['h']['pos'][1]
            hiL, hiR = getIns(ss, sst, a["token"], headL, headR)
            tailL = a['t']['pos'][0]
            tailR = a['t']['pos'][1]
            tiL, tiR = getIns(ss, sst, a["token"], tailL, tailR)
            E1b = '1'
            E1e = '2'
            E2b = '3'
            E2e = '4'
            ins = [(hiL, E1b), (hiR, E1e), (tiL, E2b), (tiR, E2e)]
            ins = sorted(ins)
            pE1 = 0
            pE2 = 0
            pE1_ = 0
            pE2_ = 0
            for i in range(0, 4):
                sst.insert(ins[i][0] + i, ins[i][1])
                if ins[i][1] == E1b:
                    pE1 = ins[i][0] + i
                elif ins[i][1] == E2b:
                    pE2 = ins[i][0] + i
                elif ins[i][1] == E1e:
                    pE1_ = ins[i][0] + i
                else:
                    pE2_ = ins[i][0] + i
            if pE1_ - pE1 == 1 or pE2_ - pE2 == 1:
                return "???", -1, -1
            else:
                return " ".join(sst), pE1 + bias, pE2 + bias

        def get_example_bert(item):
            if 'text' in item:
                sentence = item['text']
                is_token = False
            else:
                sentence = item['token']
                is_token = True
            pos_head = item['h']['pos']
            pos_tail = item['t']['pos']

            pos_min = pos_head
            pos_max = pos_tail
            if pos_head[0] > pos_tail[0]:
                pos_min = pos_tail
                pos_max = pos_head
                rev = True
            else:
                rev = False
            
            if not is_token:
                sent0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
                ent0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
                sent1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
                ent1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
                sent2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
            else:
                sent0 = self.tokenizer.tokenize(' '.join(sentence[:pos_min[0]]))
                ent0 = self.tokenizer.tokenize(' '.join(sentence[pos_min[0]:pos_min[1]]))
                sent1 = self.tokenizer.tokenize(' '.join(sentence[pos_min[1]:pos_max[0]]))
                ent1 = self.tokenizer.tokenize(' '.join(sentence[pos_max[0]:pos_max[1]]))
                sent2 = self.tokenizer.tokenize(' '.join(sentence[pos_max[1]:]))

            ent0 = ['[unused0]'] + ent0 + ['[unused1]'] if not rev else ['[unused2]'] + ent0 + ['[unused3]']
            ent1 = ['[unused2]'] + ent1 + ['[unused3]'] if not rev else ['[unused0]'] + ent1 + ['[unused1]']

            re_tokens = ['[CLS]'] + sent0 + ent0 + sent1 + ent1 + sent2 + ['[SEP]']
            pos1 = 1 + len(sent0) if not rev else 1 + len(sent0 + ent0 + sent1)
            pos2 = 1 + len(sent0 + ent0 + sent1) if not rev else 1 + len(sent0)
            #pos1 = min(self.max_length - 1, pos1)
            #pos2 = min(self.max_length - 1, pos2)
            
            indexed_tokens = self.tokenizer.convert_tokens_to_ids(re_tokens)
            avai_len = len(indexed_tokens)

            # Position
            #pos1 = torch.tensor([[pos1]]).long()
            #pos2 = torch.tensor([[pos2]]).long()

            #indexed_tokens = indexed_tokens[:self.max_length]
            indexed_tokens = torch.tensor(indexed_tokens).long()

            return indexed_tokens, pos1, pos2

 
        def binarize(s, append_bos=False):
            #if self.bpe is not None:
            #    s = self.bpe.encode(s)
            tokens = self.vocab.encode_line(
                s, append_eos=True, add_if_not_exist=False,
            ).long()
            if append_bos and self.args.init_token is not None:
                tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
            return tokens

        if data_path is None:
            data_path = os.path.join(self.args.data, split + '.jsonl')
            rel2id_path=os.path.join(self.args.data, "rel2id.json")
        if not os.path.exists(data_path):
            raise FileNotFoundError('Cannot find data: {}'.format(data_path))
        if not os.path.exists(rel2id_path):
            raise FileNotFoundError('Cannot find rel2id: {}'.format(rel2id_path))
        
        rel2id=json.load(open(rel2id_path,"r"))
        labels = []
        src_tokens = []
        src_lengths = []
        src_idx = []
        with open(data_path) as h:
            for line in h:
                example = json.loads(line.strip())
                if 'relation' in example:
                    label = rel2id[example['relation']]
                    labels.append(label)
                #bped=self.bpe.encode(" ".join(example["token"]))
                if getattr(self.args, 'bert', False):
                    src_bin, pE1, pE2 = get_example_bert(example)
                else:
                    bped, pE1, pE2 = getExample(example,1)
                    if pE1==-1:
                        continue
                    src_bin = binarize(bped, append_bos=True)
                src_tokens.append(src_bin)
                src_lengths.append(len(src_bin))
                #pE1=0
                #pE2=0
                src_idx.append([[pE1 for i in range(0,self.args.encoder_embed_dim)], [pE2 for i in range(0,self.args.encoder_embed_dim)]])

        src_lengths = np.array(src_lengths)
        src_tokens = ListDataset(src_tokens, src_lengths)
        src_lengths = ListDataset(src_lengths)
        
        print("src_len", len(src_lengths))
        print("src_tokens", len(src_tokens))
        

        dataset = {
            'id': IdDataset(),
            'net_input':{
                'src_tokens':RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad()
                ),
                'src_lengths': src_lengths,
            },
            'index': RawLabelDataset(src_idx),
            'target': RawLabelDataset(labels),
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(src_tokens, reduce=True),
        }
        
        
        dataset = NestedDictionaryDataset(
            dataset,
            sizes=src_tokens.sizes,
        )

        with data_utils.numpy_seed(self.args.seed+epoch):
            dataset = SortDataset(
                dataset,
                # shuffle
                sort_order=[np.random.permutation(len(dataset))],
            )

        print('| Loaded {} with {} samples'.format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]
Code example #10
File: wsc_task.py  Project: StatNLP/ada4asr
 def get_pad_dataset_fn(tokens, length, pad_idx):
     return PadDataset(
         ListDataset(tokens, length),
         pad_idx=pad_idx,
         left_pad=False,
     )
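Hypothetical usage of get_pad_dataset_fn, treating it as a standalone helper and assuming fairseq's PadDataset semantics: its collater pads a list of 1-D token tensors to the longest length with pad_idx, on the right because left_pad=False.

import numpy as np
import torch

toks = [torch.LongTensor([5, 6, 7]), torch.LongTensor([8])]
lengths = np.array([3, 1])
pad_dataset = get_pad_dataset_fn(toks, lengths, pad_idx=1)

batch = pad_dataset.collater([pad_dataset[0], pad_dataset[1]])
# batch is a 2x3 LongTensor: [[5, 6, 7], [8, 1, 1]]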
Code example #11
File: wsc_task.py  Project: StatNLP/ada4asr
    def load_dataset(self,
                     split,
                     epoch=1,
                     combine=False,
                     data_path=None,
                     return_only=False,
                     **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if data_path is None:
            data_path = os.path.join(self.args.data, split + ".jsonl")
        if not os.path.exists(data_path):
            raise FileNotFoundError("Cannot find data: {}".format(data_path))

        query_tokens = []
        query_masks = []
        query_lengths = []
        candidate_tokens = []
        candidate_masks = []
        candidate_lengths = []
        labels = []

        for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(
                data_path):
            prefix = sentence[:pronoun_span.start].text
            suffix = sentence[pronoun_span.end:].text_with_ws

            # spaCy spans include trailing spaces, but we need to know about
            # leading spaces for the GPT-2 BPE
            leading_space = (
                " " if sentence[:pronoun_span.start].text_with_ws.endswith(" ")
                else "")
            trailing_space = " " if pronoun_span.text_with_ws.endswith(
                " ") else ""

            # get noun phrases, excluding pronouns and anything overlapping with the query
            cand_spans = wsc_utils.filter_noun_chunks(
                wsc_utils.extended_noun_chunks(sentence),
                exclude_pronouns=True,
                exclude_query=query,
                exact_match=False,
            )

            if query is not None:
                query_toks, query_mask = self.binarize_with_mask(
                    query, prefix, suffix, leading_space, trailing_space)
                query_len = len(query_toks)
            else:
                query_toks, query_mask, query_len = None, None, 0

            query_tokens.append(query_toks)
            query_masks.append(query_mask)
            query_lengths.append(query_len)

            cand_toks, cand_masks = [], []
            for cand_span in cand_spans:
                toks, mask = self.binarize_with_mask(
                    cand_span.text,
                    prefix,
                    suffix,
                    leading_space,
                    trailing_space,
                )
                cand_toks.append(toks)
                cand_masks.append(mask)

            # collate candidates
            cand_toks = data_utils.collate_tokens(cand_toks,
                                                  pad_idx=self.vocab.pad())
            cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
            assert cand_toks.size() == cand_masks.size()

            candidate_tokens.append(cand_toks)
            candidate_masks.append(cand_masks)
            candidate_lengths.append(cand_toks.size(1))

            labels.append(label)

        query_lengths = np.array(query_lengths)
        query_tokens = ListDataset(query_tokens, query_lengths)
        query_masks = ListDataset(query_masks, query_lengths)

        candidate_lengths = np.array(candidate_lengths)
        candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
        candidate_masks = ListDataset(candidate_masks, candidate_lengths)

        labels = ListDataset(labels, [1] * len(labels))

        dataset = {
            "id": IdDataset(),
            "query_tokens": query_tokens,
            "query_masks": query_masks,
            "candidate_tokens": candidate_tokens,
            "candidate_masks": candidate_masks,
            "labels": labels,
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(query_tokens, reduce=True),
        }

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[query_lengths],
        )

        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(query_tokens))
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

        if return_only:
            return dataset

        self.datasets[split] = dataset
        return self.datasets[split]
Code example #12
def lang_pair_dataset(lengths: Sequence[int]) -> LanguagePairDataset:
    tokens = [[i] * l for i, l in enumerate(lengths)]
    return LanguagePairDataset(ListDataset(tokens), lengths, mock_dict())
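A hypothetical check of the test helper above; mock_dict() is assumed to return a dummy fairseq Dictionary, as the surrounding test module presumably defines it. With LanguagePairDataset's default flags the source items come back unmodified.

ds = lang_pair_dataset([2, 3, 5])
assert len(ds) == 3
assert list(ds.src_sizes) == [2, 3, 5]
assert ds[1]["source"] == [1, 1, 1]   # token i repeated to the requested length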
Code example #13
File: ir_prediction.py  Project: Lynxgsm/OSCAR
    def load_dataset(self, split, epoch=0, combine=False, data_selector=None):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        print('Loading dataset')
        
        data_path = os.path.join(self.args.data)
        dataset_inst = data_utils.load_indexed_dataset(
            os.path.join(data_path, 'insts', split),
            self.instruction_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        
        dataset_state = data_utils.load_indexed_dataset(
            os.path.join(data_path, 'states', split),
            self.state_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        
        if dataset_inst is None or dataset_state is None:
            raise FileNotFoundError('Dataset not found: {}'.format(split))
    
        dataset_inst = SeqOfSeqDataset(dataset_inst, self.instruction_dictionary)
        dataset_state = SeqOfSeqDataset(dataset_state, self.state_dictionary)
        dataset_pos = IRPositionDataset(os.path.join(data_path, 'pos', split))
        dataset = IRDataset(dataset_inst, dataset_state, dataset_pos)
        
        block_size = self.args.function_length
    
        dataset = IRPadDataset(
            dataset,
            inst_pad_idx=self.instruction_dictionary.pad(),
            state_pad_idx=self.state_dictionary.pad(),
            inst_mask_idx=self.inst_mask_idx,
            state_mask_idx=self.state_mask_idx,
            inst_cls_idx=self.instruction_dictionary.bos(),
            state_cls_idx=self.state_dictionary.bos(),
            smallbert_insts_per_input=self.args.smallbert_insts_per_group,
            smallbert_states_per_input=self.args.smallbert_insts_per_group,
            max_length=block_size,
            inst_pad_length=32,
            state_pad_length=16,
            pair=True,
        )
        
        labels_str = list(map(json.loads, open(os.path.join(data_path, 'label', split + '.txt'))))
        labels = torch.tensor([x - 1 if isinstance(x, int) else int(x.strip()) - 1 for x in labels_str])
        #function_indices = [torch.tensor(json.loads(x)) for x in open(os.path.join(data_path, 'funcs', split + '.txt'))]
        
        #dataset = IRMultiFunctionDataset(dataset, function_indices, self.args.max_functions_per_program)
    
        print('| loaded {} batches from: {} and {}'.format(len(dataset),
            os.path.join(data_path, 'insts', split), os.path.join(data_path, 'states', split)))

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.labels[split] = SortDataset(RawLabelDataset(labels), sort_order=[shuffle])
        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    'id': IdDataset(),
                    'net_input': {
                        'src': dataset,
                    },
                    'target': RawLabelDataset(labels),
                    'indices': RawLabelDataset(torch.arange(len(dataset))),
                    'subset': ListDataset([split for _ in range(len(dataset))])
                },
                sizes=[dataset.sizes],
            ),
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
Code example #14
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)."""

        ### create the BPE encoder object
        bpe_encoder = MultiprocessingEncoder(self.args.encoder_json,
                                             self.args.vocab_bpe)
        bpe_encoder.initializer()

        ### call preprocess_coqa (build CoQA features)
        examples, features = get_CoQA_features(self.args,
                                               bpe_encoder,
                                               self.args.init_token,
                                               self.args.separator_token,
                                               self.dictionary.pad(),
                                               split=split)

        self.examples[split] = examples
        self.features[split] = features

        qas_idx = []
        src_tokens = []
        src_lengths = []
        padding_mask = []
        start_pos = []
        end_pos = []
        is_unk = []
        is_yes = []
        is_no = []
        number = []
        option = []

        for feature in features:
            src = torch.IntTensor(feature.input_tokens).long()
            p_mask = torch.IntTensor(feature.p_mask).long()

            src_tokens.append(src)
            src_lengths.append(len(src))
            padding_mask.append(p_mask)
            qas_idx.append(feature.qas_id)

            start_pos.append(feature.start_position)
            end_pos.append(feature.end_position)
            is_unk.append(feature.is_unk)
            is_yes.append(feature.is_yes)
            is_no.append(feature.is_no)
            number.append(feature.number)
            option.append(feature.option)

        src_tokens = ListDataset(src_tokens, src_lengths)
        src_lengths = ListDataset(src_lengths)

        dataset = {
            "id": IdDataset(),
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens, reduce=True),
            "qas_id": RawLabelDataset(qas_idx),
            "net_input": {
                "src_tokens":
                RightPadDataset(src_tokens, pad_idx=self.dictionary.pad()),
                "src_lengths":
                src_lengths,
                "start_position":
                RawLabelDataset(start_pos),
                "p_mask":
                RightPadDataset(padding_mask, pad_idx=self.dictionary.pad()),
            },
            "start_position": RawLabelDataset(start_pos),
            "end_position": RawLabelDataset(end_pos),
            "is_unk": RawLabelDataset(is_unk),
            "is_yes": RawLabelDataset(is_yes),
            "is_no": RawLabelDataset(is_no),
            "number": RawLabelDataset(number),
            "option": RawLabelDataset(option),
        }

        dataset = NestedDictionaryDataset(
            dataset,
            sizes=[np.maximum.reduce([src_tokens.sizes])],
        )

        with data_utils.numpy_seed(self.args.seed):
            dataset = SortDataset(
                dataset,
                sort_order=[np.random.permutation(len(dataset))],
            )

        print("| Loaded {} with {} samples".format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]
Code example #15
    def load_dataset(self,
                     split,
                     epoch=0,
                     combine=False,
                     data_path=None,
                     return_only=False,
                     **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        def binarize(s, append_bos=False):
            if self.bpe is not None:
                s = self.bpe.encode(s)
            tokens = self.vocab.encode_line(
                s,
                append_eos=True,
                add_if_not_exist=False,
            ).long()
            if append_bos and self.args.init_token is not None:
                tokens = torch.cat(
                    [tokens.new([self.args.init_token]), tokens])
            return tokens

        # self.data_path_table={'train_input':os.path.join(self.args.data,'Training  Data','subtaskA_data_all.csv'),\
        #                       'train_answer':os.path.join(self.args.data,'Training  Data','subtaskA_answers_all.csv'),\
        #                       'valid_input':os.path.join(self.args.data,'Trial Data','taskA_trial_data.csv'),\
        #                       'valid_answer':os.path.join(self.args.data,'Trial Data','taskA_trial_answer.csv')\
        #                             }
        # self.data_path_table={'train_input':os.path.join(self.args.data,'trainval','subtaskA_data_all.csv'),\
        #                       'train_answer':os.path.join(self.args.data,'trainval','subtaskA_answers_all.csv'),\
        #                       'valid_input':os.path.join(self.args.data,'Dev Data','subtaskA_dev_data.csv'),\
        #                       'valid_answer':os.path.join(self.args.data,'Dev Data','subtaskA_gold_answers.csv')\
        #                             }
        self.data_path_table = {
            'train_input': os.path.join(self.args.data, 'trainvaldev', 'subtaskA_data_all_plusplus.csv'),
            'train_answer': os.path.join(self.args.data, 'trainvaldev', 'subtaskA_answers_all.csv'),
            'valid_input': os.path.join(self.args.data, 'Dev Data', 'subtaskA_dev_data_plusplus.csv'),
            'valid_answer': os.path.join(self.args.data, 'Dev Data', 'subtaskA_gold_answers.csv'),
        }
        # self.data_path_table={'train_input':os.path.join(self.args.data,'subtaskA_data_all.csv'),\
        #                       'train_answer':os.path.join(self.args.data,'subtaskA_answers_all.csv'),\
        #                       'valid_input':os.path.join(self.args.data,'taskA_trial_data.csv'),\
        #                       'valid_answer':os.path.join(self.args.data,'taskA_trial_answer.csv')\
        #                             }
        data_path_input = self.data_path_table[split + '_input']
        data_path_answer = self.data_path_table[split + '_answer']

        if not os.path.exists(data_path_input):
            raise FileNotFoundError(
                'Cannot find data: {}'.format(data_path_input))
        if not os.path.exists(data_path_answer):
            raise FileNotFoundError(
                'Cannot find data: {}'.format(data_path_answer))

        src_tokens = [[] for i in range(self.args.num_classes)]
        src_lengths = [[] for i in range(self.args.num_classes)]
        src_ids = []
        labels = []
        label_ids = []

        with open(data_path_input) as f:
            reader = csv.reader(f)
            for row in islice(reader, 1, None):
                src_ids.append(row[0])
                for i in range(self.args.num_classes):
                    src = row[i + 1]
                    evidence = row[i + 3]
                    if src.isupper():
                        src = src.capitalize()
                    src = src + ' Context: ' + evidence
                    src_bin = binarize(src, append_bos=True)
                    src_tokens[i].append(src_bin)
                    src_lengths[i].append(len(src_bin))

            assert all(
                len(src_tokens[0]) == len(src_tokens[i])
                for i in range(self.args.num_classes))
            assert len(src_tokens[0]) == len(src_lengths[0])

        with open(data_path_answer) as f:
            reader = csv.reader(f)
            for row in reader:
                label_ids.append(row[0])
                label = 1 - int(row[1])
                labels.append(label)

            assert len(labels) == 0 or len(labels) == len(src_tokens[0])
            assert all(src_ids[i] == label_ids[i] for i in range(len(src_ids)))

        for i in range(self.args.num_classes):
            src_lengths[i] = np.array(src_lengths[i])
            src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i])
            src_lengths[i] = ListDataset(src_lengths[i])

        dataset = {
            'id': IdDataset(),
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(src_tokens[0], reduce=True),
        }

        for i in range(self.args.num_classes):
            dataset.update({
                'net_input{}'.format(i + 1): {
                    'src_tokens':
                    RightPadDataset(
                        src_tokens[i],
                        pad_idx=self.source_dictionary.pad(),
                    ),
                    'src_lengths':
                    src_lengths[i],
                }
            })

        if len(labels) > 0:
            dataset.update({'target': RawLabelDataset(labels)})

        dataset = NestedDictionaryDataset(
            dataset,
            sizes=[
                np.maximum.reduce(
                    [src_token.sizes for src_token in src_tokens])
            ],
        )

        with data_utils.numpy_seed(self.args.seed):
            dataset = SortDataset(
                dataset,
                # shuffle
                sort_order=[np.random.permutation(len(dataset))],
            )

        print('| Loaded {} with {} samples'.format(split, len(dataset)))

        self.datasets[split] = dataset
        return self.datasets[split]