Example #1
    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path):
        assert os.path.isfile(file_path)
        logger.info("Creating features from dataset file at %s", file_path)

        cache_fn = f'{file_path}.cache'
        if args.cache_data and os.path.isfile(
                cache_fn) and not args.overwrite_cache:
            logger.info("Loading cached data from %s", cache_fn)
            self.examples = torch.load(cache_fn)
        else:
            self.examples = []
            with open(file_path, encoding="utf-8") as f:
                for line in f.readlines():
                    # Keep only non-blank lines of the form "source ||| target".
                    if (len(line) > 0 and not line.isspace()
                            and len(line.split(' ||| ')) == 2):
                        try:
                            src, tgt = line.split(' ||| ')
                            # Skip pairs where either side is empty.
                            if src.rstrip() == '' or tgt.rstrip() == '':
                                continue
                        except ValueError:
                            logger.info("Skipping instance %s", line)
                            continue
                        # Whitespace-split each side, then sub-tokenize every word
                        # and map the subwords to vocabulary ids.
                        sent_src, sent_tgt = src.strip().split(), tgt.strip().split()
                        token_src = [tokenizer.tokenize(word) for word in sent_src]
                        token_tgt = [tokenizer.tokenize(word) for word in sent_tgt]
                        wid_src = [tokenizer.convert_tokens_to_ids(x) for x in token_src]
                        wid_tgt = [tokenizer.convert_tokens_to_ids(x) for x in token_tgt]

                        # Flatten the per-word id lists, add special tokens and truncate.
                        ids_src = tokenizer.prepare_for_model(
                            list(itertools.chain(*wid_src)),
                            return_tensors='pt',
                            max_length=tokenizer.max_len)['input_ids']
                        ids_tgt = tokenizer.prepare_for_model(
                            list(itertools.chain(*wid_tgt)),
                            return_tensors='pt',
                            max_length=tokenizer.max_len)['input_ids']

                        # bpe2word_map[j] = index of the word that subword j belongs to.
                        bpe2word_map_src = []
                        for i, word_list in enumerate(token_src):
                            bpe2word_map_src += [i] * len(word_list)
                        bpe2word_map_tgt = []
                        for i, word_list in enumerate(token_tgt):
                            bpe2word_map_tgt += [i] * len(word_list)

                        self.examples.append(
                            (ids_src, ids_tgt, bpe2word_map_src,
                             bpe2word_map_tgt))

            if args.cache_data:
                logger.info("Saving cached data to %s", cache_fn)
                torch.save(self.examples, cache_fn)
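
A minimal, self-contained sketch of the bpe-to-word mapping built above. The subword splits are invented for illustration and are not the output of any particular tokenizer.

# Toy illustration of bpe2word_map_src (assumed subword splits, not a real tokenizer).
token_src = [["we"], ["like"], ["mu", "##sic"]]   # tokenizer.tokenize(word) per word
bpe2word_map_src = []
for i, word_list in enumerate(token_src):
    bpe2word_map_src += [i] * len(word_list)
print(bpe2word_map_src)  # [0, 1, 2, 2] -> subword j belongs to word bpe2word_map_src[j]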
Example #2
    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path):
        assert os.path.isfile(file_path)
        print('Loading the dataset...')
        self.examples = []
        with open(file_path, encoding="utf-8") as f:
            for idx, line in enumerate(f.readlines()):
                if (len(line) == 0 or line.isspace()
                        or len(line.split(' ||| ')) != 2):
                    raise ValueError(
                        f'Line {idx+1} is not in the correct format!')

                src, tgt = line.split(' ||| ')
                if src.rstrip() == '' or tgt.rstrip() == '':
                    raise ValueError(
                        f'Line {idx+1} is not in the correct format!')

                # Whitespace-split each side, then sub-tokenize every word
                # and map the subwords to vocabulary ids.
                sent_src, sent_tgt = src.strip().split(), tgt.strip().split()
                token_src = [tokenizer.tokenize(word) for word in sent_src]
                token_tgt = [tokenizer.tokenize(word) for word in sent_tgt]
                wid_src = [tokenizer.convert_tokens_to_ids(x) for x in token_src]
                wid_tgt = [tokenizer.convert_tokens_to_ids(x) for x in token_tgt]

                # Flatten the per-word id lists, add special tokens and truncate.
                ids_src = tokenizer.prepare_for_model(
                    list(itertools.chain(*wid_src)),
                    return_tensors='pt',
                    max_length=tokenizer.max_len)['input_ids']
                ids_tgt = tokenizer.prepare_for_model(
                    list(itertools.chain(*wid_tgt)),
                    return_tensors='pt',
                    max_length=tokenizer.max_len)['input_ids']
                # Length 2 means only the special tokens remain, i.e. an effectively empty sentence.
                if len(ids_src[0]) == 2 or len(ids_tgt[0]) == 2:
                    raise ValueError(
                        f'Line {idx+1} is not in the correct format!')

                # bpe2word_map[j] = index of the word that subword j belongs to.
                bpe2word_map_src = []
                for i, word_list in enumerate(token_src):
                    bpe2word_map_src += [i] * len(word_list)
                bpe2word_map_tgt = []
                for i, word_list in enumerate(token_tgt):
                    bpe2word_map_tgt += [i] * len(word_list)

                self.examples.append((ids_src[0], ids_tgt[0], bpe2word_map_src,
                                      bpe2word_map_tgt))
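
For context, a hedged sketch of how such a dataset might be batched. The class name ParallelDataset and the padding id are assumptions, not part of the examples above; the only contract relied on is that each item is the (ids_src, ids_tgt, bpe2word_map_src, bpe2word_map_tgt) tuple appended to self.examples (with 1-D id tensors, as stored in Example #2).

from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch, pad_id=0):
    # Pad the variable-length id tensors; the bpe-to-word maps stay as Python lists.
    ids_src, ids_tgt, b2w_src, b2w_tgt = zip(*batch)
    ids_src = pad_sequence([x.view(-1) for x in ids_src], batch_first=True, padding_value=pad_id)
    ids_tgt = pad_sequence([x.view(-1) for x in ids_tgt], batch_first=True, padding_value=pad_id)
    return ids_src, ids_tgt, list(b2w_src), list(b2w_tgt)

# Hypothetical usage (ParallelDataset is an assumed name for the class the __init__ above belongs to):
# from torch.utils.data import DataLoader
# dataset = ParallelDataset(tokenizer, args, file_path)
# loader = DataLoader(dataset, batch_size=8,
#                     collate_fn=lambda b: collate_fn(b, tokenizer.pad_token_id))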