def __init__(
        self,
        df: pandas.core.frame.DataFrame,
        labels: pandas.core.frame.DataFrame,
        tokenizer: BertTokenizerFast,
    ):
        """[Dataloader to generate shuffle batch of data]

        Args:
            df (pandas.core.frame.DataFrame): [Train text]
            labels (pandas.core.frame.DataFrame): [Train labels (jobs)]
            tokenizer (BertTokenizerFast): [Text tokenizer for bert]
        """

        # .values.tolist() works for both a Series and a single-column DataFrame
        self.labels = labels.values.tolist()
        # Tokenize every description once (truncated to 170 tokens, padded to
        # the longest sequence in the batch)
        self.encodings = tokenizer.batch_encode_plus(
            df["description"].to_list(),
            max_length=170,
            padding=True,
            truncation=True)
        self.Id = df["Id"].to_list()
        self.gender = df["gender"].to_list()
Example 2
# Assumed imports for this example (not shown in the original snippet):
import json
import os
import pickle
from itertools import chain

import torch
from jamo import h2j, j2hcj           # Hangul-to-jamo decomposition
from konlpy.tag import Mecab          # Korean morphological analyzer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import BertTokenizerFast


class KorquadDataset(Dataset):
    def __init__(self, train=True):
        if train:
            path = "/data/KorQuAD_v1.0_train.json"
            db_name = "korquad_train.qas"
        else:
            path = "/data/KorQuAD_v1.0_dev.json"
            db_name = "korquad_dev.qas"
        self.tokenizer = BertTokenizerFast("wiki-vocab.txt")

        with open(path, encoding="utf-8") as f:
            data = json.load(f)["data"]

        self.qas = []
        if not os.path.exists(db_name):
            with open(db_name, "wb") as f:
                self.mecab = Mecab()
                ignored_cnt = 0
                for paragraphs in tqdm(data):
                    paragraphs = paragraphs["paragraphs"]
                    for paragraph in paragraphs:
                        _context = paragraph["context"]
                        for qa in paragraph["qas"]:
                            question = qa["question"]
                            answer = qa["answers"][0]["text"]
                            (
                                input_ids,
                                token_type_ids,
                                start_token_pos,
                                end_token_pos,
                            ) = self.extract_features(
                                _context,
                                question,
                                answer,
                                qa["answers"][0]["answer_start"],
                            )
                            if len(input_ids) > 512:
                                # Over-length examples are skipped for training
                                # but still written out for evaluation.
                                if train:
                                    ignored_cnt += 1
                                else:
                                    pickle.dump(
                                        (
                                            input_ids,
                                            token_type_ids,
                                            start_token_pos,
                                            end_token_pos,
                                        ),
                                        f,
                                    )
                            else:
                                pickle.dump(
                                    (
                                        input_ids,
                                        token_type_ids,
                                        start_token_pos,
                                        end_token_pos,
                                    ),
                                    f,
                                )

        with open(db_name, "rb") as f:
            while True:
                try:
                    data = pickle.load(f)
                    self.qas.append(data)
                except EOFError:
                    break
            print(len(self.qas))

    @property
    def token_num(self):
        return self.tokenizer.vocab_size

    def __len__(self):
        return len(self.qas)

    def encode(self, line):
        # Encode each word separately, drop each word's leading/trailing
        # special tokens, and wrap the flattened sequence with a single
        # [CLS] ... [SEP] pair (ids 2 and 3 in this vocabulary).
        converted_results = map(
            lambda x: x[1:-1],
            self.tokenizer.batch_encode_plus(line)["input_ids"])
        return [2, *chain.from_iterable(converted_results), 3]

    def decode(self, token_ids):
        decode_str = self.tokenizer.decode(token_ids, skip_special_tokens=True)
        return decode_str

    def __getitem__(self, idx):
        return self.qas[idx]

    def extract_features(self, context, question, answer, start_char_pos):
        if answer is None:
            # use encode_plus function in tokenizer
            tokenized_q = self.tokenize(question)
            tokenized_c = self.tokenize(context)
            input_ids = [*tokenized_q, *tokenized_c[1:]]
            token_type_ids = [
                *[0 for _ in tokenized_q], *[1 for _ in tokenized_c[1:]]
            ]
            start_token_pos = None  # no labelled answer span in this branch
            end_token_pos = None
        else:
            # Split sentences using len(answer) and start_char_pos
            context_front = context[:start_char_pos]
            context_back = context[start_char_pos + len(answer):]
            q_ids = self.tokenize(question)
            f_ids = self.tokenize(context_front)
            a_ids = self.tokenize(answer)
            b_ids = self.tokenize(context_back)

            # For processing subwords. Note: as written, both branches rebuild
            # the same lists, so a_ids and b_ids are left unchanged.
            if context_front != "" and context_front[-1] != " ":
                a_ids = [a_ids[0], a_ids[1], *a_ids[2:]]
            if context_back != "" and context_back[0] != " ":
                b_ids = [b_ids[0], b_ids[1], *b_ids[2:]]

            # Manually generate input_ids, token_type_ids and start/end_token_pos (carefully remove [CLS] and [SEP])
            input_ids = [*q_ids, *f_ids[1:-1], *a_ids[1:-1], *b_ids[1:]]
            token_type_ids = [
                *[0 for _ in q_ids],
                *[1 for _ in f_ids[1:-1]],
                *[1 for _ in a_ids[1:-1]],
                *[1 for _ in b_ids[1:]],
            ]
            start_token_pos = len(q_ids) + (len(f_ids) - 2)
            end_token_pos = start_token_pos + (len(a_ids) - 2) - 1

        return input_ids, token_type_ids, start_token_pos, end_token_pos

    def tokenize(self, sentence):
        # An empty string maps to just [CLS], [SEP] (ids 2 and 3).
        if len(sentence) == 0:
            return [2, 3]
        # Split into morphemes with Mecab and decompose Hangul syllables into
        # jamo before WordPiece encoding.
        return self.encode(
            [j2hcj(h2j(word)) for word in self.mecab.morphs(sentence)])

    def collate_fn(self, samples):
        # Pad variable-length examples into rectangular batch tensors.
        input_ids, token_type_ids, start_pos, end_pos = zip(*samples)
        attention_mask = [[1] * len(input_id) for input_id in input_ids]

        input_ids = pad_sequence(
            [torch.Tensor(input_id).to(torch.long) for input_id in input_ids],
            padding_value=0,
            batch_first=True,
        )
        token_type_ids = pad_sequence(
            [
                torch.Tensor(token_type_id).to(torch.long)
                for token_type_id in token_type_ids
            ],
            padding_value=1,  # padded positions get the segment-B token type
            batch_first=True,
        )
        attention_mask = pad_sequence(
            [torch.Tensor(mask).to(torch.long) for mask in attention_mask],
            padding_value=0,
            batch_first=True,
        )

        start_pos = torch.Tensor(start_pos).to(torch.long)
        end_pos = torch.Tensor(end_pos).to(torch.long)

        return input_ids, attention_mask, token_type_ids, start_pos, end_pos
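
A brief usage sketch for this dataset, assuming the KorQuAD JSON files under /data and wiki-vocab.txt exist; the batch size and the loop body are illustrative only.

# Illustrative usage; requires /data/KorQuAD_v1.0_train.json and wiki-vocab.txt.
from torch.utils.data import DataLoader

dataset = KorquadDataset(train=True)
loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=dataset.collate_fn,  # pads variable-length examples into tensors
)

for input_ids, attention_mask, token_type_ids, start_pos, end_pos in loader:
    # Feed the padded batch into a span-prediction (QA) model here.
    break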
Example 3
# Assumed imports: as in the previous example (torch, tqdm, chain, Dataset,
# BertTokenizerFast), plus `import ast` for ast.literal_eval below.
class WikiDataset(Dataset):
    def __init__(self, train=True):
        if train:
            path = ("/data/data_train.txt", "/data/pos_train.txt")
        else:
            path = ("/data/data_val.txt", "/data/pos_val.txt")
        self.tokenizer = BertTokenizerFast("wiki-vocab.txt")
        self.paragraphs = [[]]   # each paragraph is a list of (words, POS tags)
        self.pos_labels = set()  # all POS tags seen in this corpus

        valid = True
        with open(path[0], encoding="utf-8") as f_data:
            with open(path[1], encoding="utf-8") as f_pos:
                for d, p in tqdm(zip(f_data, f_pos), desc="load_data"):
                    if len(d.strip()) == 0:
                        # Blank line: close the current paragraph, or re-enable
                        # loading if the previous paragraph was discarded.
                        if len(self.paragraphs[-1]) > 0:
                            self.paragraphs.append([])
                        else:
                            valid = True
                    elif valid:
                        _d, _p = d.strip().split(), p.strip().split()
                        if len(_d) != len(_p) or len(_p) > 256:
                            # Token/POS count mismatch or over-long sentence:
                            # drop the whole paragraph.
                            valid = False
                            self.paragraphs[-1] = []
                        else:
                            self.paragraphs[-1].append((_d, _p))
                            self.pos_labels |= set(_p)

        print(len(self.paragraphs))

        if train:
            # Map each POS tag to an id, reserving 0 (ids start at 1).
            self.pos_labels_to_ids = {}
            for i, pos_label in enumerate(sorted(self.pos_labels)):
                self.pos_labels_to_ids[pos_label] = i + 1
        else:
            # Reuse the mapping saved during pre-training and append any tags
            # that only appear in the validation corpus.
            with open('./pretrain.wiki.dict') as f:
                # Equivalent to eval() for a plain dict literal, but safer.
                self.pos_labels_to_ids = ast.literal_eval(f.read())

            # Continue numbering after the highest existing id so new tags do
            # not collide with ids already in the saved mapping.
            i = max(self.pos_labels_to_ids.values(), default=0) + 1
            for pos_label in sorted(self.pos_labels):
                if pos_label not in self.pos_labels_to_ids:
                    self.pos_labels_to_ids[pos_label] = i
                    i += 1

    @property
    def token_num(self):
        return self.tokenizer.vocab_size

    @property
    def pos_num(self):
        # +1 accounts for the reserved id 0 (POS ids are numbered from 1).
        return len(self.pos_labels) + 1

    def __len__(self):
        return len(self.paragraphs) - 1

    def encode_line(self, d, p):
        # Encode each word, strip its leading/trailing special tokens, and
        # repeat the word's POS id once per remaining subword token; both
        # sequences are truncated to 500 ids.
        converted_results = map(
            lambda x: (x[0][1:-1], [self.pos_labels_to_ids[x[1]]] *
                       (len(x[0]) - 2)),
            zip(self.tokenizer.batch_encode_plus(d)["input_ids"], p),
        )
        token_ids, pos_ids = zip(*converted_results)
        return (
            list(chain.from_iterable(token_ids))[:500],
            list(chain.from_iterable(pos_ids))[:500],
        )

    def __getitem__(self, idx):
        # One paragraph: per-sentence token-id sequences and POS-id sequences.
        token_ids, pos_ids = zip(
            *list(map(lambda x: self.encode_line(*x), self.paragraphs[idx])))
        return token_ids, pos_ids
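
As above, a small usage sketch, assuming the /data corpus files and wiki-vocab.txt exist; it only inspects one paragraph to show the shape of what __getitem__ returns.

# Illustrative usage; requires /data/data_train.txt, /data/pos_train.txt
# and wiki-vocab.txt.
dataset = WikiDataset(train=True)
print(dataset.token_num, dataset.pos_num)

token_ids, pos_ids = dataset[0]  # one paragraph: one tuple entry per sentence
for sent_tokens, sent_pos in zip(token_ids, pos_ids):
    assert len(sent_tokens) == len(sent_pos)  # one POS id per subword token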