Example #1
    def build_instances(self, all_documents):
        # Split the flat token-id stream into fixed-length, non-overlapping chunks.
        instances = []
        instances_num = len(all_documents) // self.seq_length
        for i in range(instances_num):
            src = all_documents[i * self.seq_length:(i + 1) * self.seq_length]
            seg_pos = [len(src)]

            if not self.dynamic_masking:
                # Static masking: mask once now, at build time; otherwise the
                # data loader masks dynamically per batch.
                src, tgt = mask_seq(src, self.tokenizer,
                                    self.whole_word_masking, self.span_masking,
                                    self.span_geo_prob, self.span_max_length)
                instance = (src, tgt, seg_pos)
            else:
                instance = (src, seg_pos)

            instances.append(instance)

        # The leftover tail (possibly empty) becomes one final instance,
        # padded up to seq_length with PAD ids.
        src = all_documents[instances_num * self.seq_length:]
        seg_pos = [len(src)]

        while len(src) != self.seq_length:
            src.append(self.vocab.get(PAD_TOKEN))

        if not self.dynamic_masking:
            src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking,
                                self.span_masking, self.span_geo_prob,
                                self.span_max_length)
            instance = (src, tgt, seg_pos)
        else:
            instance = (src, seg_pos)

        instances.append(instance)
        return instances
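
A minimal standalone sketch of the chunking arithmetic above, with made-up values (seq_length, the PAD id, and the token stream are illustrative):

    PAD_ID = 0
    seq_length = 4
    all_documents = [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]

    instances_num = len(all_documents) // seq_length      # 2 full chunks
    chunks = [all_documents[i * seq_length:(i + 1) * seq_length]
              for i in range(instances_num)]              # [11..14], [15..18]

    tail = all_documents[instances_num * seq_length:]     # [19, 20, 21]
    tail += [PAD_ID] * (seq_length - len(tail))           # [19, 20, 21, 0]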
Example #2
    def create_ins_from_doc(self, document):
        max_num_tokens = self.seq_length - 3  # room for [CLS] and two [SEP]
        target_seq_length = max_num_tokens
        if random.random() < self.short_seq_prob:
            # Sometimes use a shorter target so pretraining also sees short sequences.
            target_seq_length = random.randint(2, max_num_tokens)
        instances = []
        current_chunk = []
        current_length = 0
        i = 0
        while i < len(document):
            segment = document[i]
            current_chunk.append(segment)
            current_length += len(segment)
            if i == len(document) - 1 or current_length >= target_seq_length:
                if current_chunk:
                    # Cut the accumulated segments into two parts at a random point.
                    a_end = 1
                    if len(current_chunk) >= 2:
                        a_end = random.randint(1, len(current_chunk) - 1)

                    tokens_a = []
                    for j in range(a_end):
                        tokens_a.extend(current_chunk[j])

                    tokens_b = []
                    is_wrong_order = 0
                    for j in range(a_end, len(current_chunk)):
                        tokens_b.extend(current_chunk[j])

                    if random.random() < 0.5:
                        # Sentence-order objective: swap the two parts half the time.
                        is_wrong_order = 1
                        tokens_a, tokens_b = tokens_b, tokens_a

                    truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)

                    src = []
                    src.append(self.vocab.get(CLS_TOKEN))
                    src.extend(tokens_a)
                    src.append(self.vocab.get(SEP_TOKEN))
                    seg_pos = [len(src)]
                    src.extend(tokens_b)
                    src.append(self.vocab.get(SEP_TOKEN))
                    seg_pos.append(len(src))

                    while len(src) != self.seq_length:
                        src.append(self.vocab.get(PAD_TOKEN))

                    if not self.dynamic_masking:
                        src, tgt_mlm = mask_seq(src, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                        instance = (src, tgt_mlm, is_wrong_order, seg_pos)
                    else:
                        instance = (src, is_wrong_order, seg_pos)

                    instances.append(instance)
                current_chunk = []
                current_length = 0
            i += 1
        return instances
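
The 50% swap above implements a sentence-order objective: is_wrong_order = 1 marks pairs whose halves were exchanged. In isolation, with made-up token ids:

    import random

    tokens_a = [101, 102]       # first half of the chunk
    tokens_b = [201, 202, 203]  # second half

    is_wrong_order = 0
    if random.random() < 0.5:
        is_wrong_order = 1
        tokens_a, tokens_b = tokens_b, tokens_a  # the model must detect the swap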
Example #3
    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()
            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start:self.start +
                                        self.batch_size]

            self.start += self.batch_size

            src = []
            tgt_in = []
            tgt_out = []
            seg = []

            for ins in instances:
                src_single, pad_num = ins[0]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))
                tgt_single, pad_num = ins[1]
                for _ in range(pad_num):
                    tgt_single.append(self.vocab.get(PAD_TOKEN))

                # Mask the source at batch time. The (position, id) pairs that
                # mask_seq returns are discarded: the full target sequence is
                # already stored in ins[1].
                src_single, _ = mask_seq(src_single, self.tokenizer,
                                         self.whole_word_masking,
                                         self.span_masking, self.span_geo_prob,
                                         self.span_max_length)
                seg_pos = ins[2][0]
                tgt_in.append(tgt_single[:-1])
                tgt_out.append(tgt_single[1:])

                MASK_ID = self.vocab.get(MASK_TOKEN)

                # Collapse each run of consecutive [MASK] ids into one [MASK];
                # seg_pos shrinks by one for every dropped continuation token.
                src_with_span_mask = []
                for token_id in src_single:
                    if token_id == MASK_ID:
                        if src_with_span_mask and src_with_span_mask[-1] == MASK_ID:
                            seg_pos -= 1
                        else:
                            src_with_span_mask.append(MASK_ID)
                    else:
                        src_with_span_mask.append(token_id)

                # Re-pad so the collapsed source keeps its original length.
                while len(src_with_span_mask) < len(src_single):
                    src_with_span_mask.append(self.vocab.get(PAD_TOKEN))

                seg.append([1] * seg_pos + [0] * (len(src_single) - seg_pos))
                src.append(src_with_span_mask)


            yield torch.LongTensor(src), \
                torch.LongTensor(tgt_in), \
                torch.LongTensor(tgt_out), \
                torch.LongTensor(seg)
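
The span-collapsing loop above is self-contained enough to demo with made-up ids; each run of consecutive [MASK] ids becomes a single [MASK], and the sequence is re-padded to its original length:

    MASK_ID, PAD_ID = 4, 0
    src = [7, MASK_ID, MASK_ID, MASK_ID, 8, MASK_ID, 9]

    collapsed = []
    for token_id in src:
        if token_id == MASK_ID and collapsed and collapsed[-1] == MASK_ID:
            continue  # drop the continuation of a masked span
        collapsed.append(token_id)
    collapsed += [PAD_ID] * (len(src) - len(collapsed))
    # collapsed == [7, 4, 8, 4, 9, 0, 0]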
Example #4
    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()
            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start:self.start +
                                        self.batch_size]

            self.start += self.batch_size

            src = []
            tgt_mlm = []
            is_next = []
            seg = []

            masked_words_num = 0

            for ins in instances:
                src_single, pad_num = ins[0]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))

                if len(ins) == 4:
                    # Statically masked record: (src, tgt_mlm, is_next, seg_pos).
                    src.append(src_single)
                    masked_words_num += len(ins[1])
                    # Expand sparse (position, original_id) pairs into a dense row.
                    tgt_mlm.append([0] * len(src_single))
                    for mask in ins[1]:
                        tgt_mlm[-1][mask[0]] = mask[1]
                    is_next.append(ins[2])
                    seg.append([1] * ins[3][0] + [2] *
                               (ins[3][1] - ins[3][0]) + [0] * pad_num)
                else:
                    # Dynamically masked record: (src, is_next, seg_pos); mask now.
                    src_single, tgt_mlm_single = mask_seq(
                        src_single, self.tokenizer, self.whole_word_masking,
                        self.span_masking, self.span_geo_prob,
                        self.span_max_length)
                    masked_words_num += len(tgt_mlm_single)
                    src.append(src_single)
                    tgt_mlm.append([0] * len(src_single))
                    for mask in tgt_mlm_single:
                        tgt_mlm[-1][mask[0]] = mask[1]
                    is_next.append(ins[1])
                    seg.append([1] * ins[2][0] + [2] *
                               (ins[2][1] - ins[2][0]) + [0] * pad_num)

            if masked_words_num == 0:
                # Nothing was masked anywhere in this batch; skip it.
                continue

            yield torch.LongTensor(src), \
                torch.LongTensor(tgt_mlm), \
                torch.LongTensor(is_next), \
                torch.LongTensor(seg)
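
In both branches, the sparse (position, original_id) pairs from mask_seq become a dense label row in which 0 means "not a masked position". With made-up values:

    seq_length = 6
    tgt_mlm_single = [(1, 42), (4, 17)]  # (masked position, original token id)

    row = [0] * seq_length
    for pos, original_id in tgt_mlm_single:
        row[pos] = original_id
    # row == [0, 42, 0, 0, 17, 0]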
Example #5
    def build_instances(self, all_documents):
        instances = []
        instances_num = len(all_documents) // self.seq_length
        for i in range(instances_num):
            src = all_documents[i * self.seq_length:(i + 1) * self.seq_length]
            seg_pos = [len(src)]

            if not self.dynamic_masking:
                src, tgt = mask_seq(src, self.tokenizer,
                                    self.whole_word_masking, self.span_masking,
                                    self.span_geo_prob, self.span_max_length)
                instance = ((src, 0), tgt, seg_pos)
            else:
                instance = ((src, 0), seg_pos)

            instances.append(instance)

        src = all_documents[instances_num * self.seq_length:]

        # Unlike Example #1, an empty tail is skipped entirely.
        if len(src) == 0:
            return instances

        seg_pos = [len(src)]

        # Defer padding: record the pad count instead of appending PAD ids now.
        pad_num = self.seq_length - len(src)

        if not self.dynamic_masking:
            src, tgt = mask_seq(src, self.tokenizer, self.whole_word_masking,
                                self.span_masking, self.span_geo_prob,
                                self.span_max_length)
            instance = ((src, pad_num), tgt, seg_pos)
        else:
            instance = ((src, pad_num), seg_pos)

        instances.append(instance)
        return instances
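
The deferred (src, pad_num) shape is unpacked on the loader side (see Examples #3, #4 and #7). A sketch of that unpacking with made-up values:

    PAD_ID = 0
    instance_src = ([11, 12, 13], 2)   # (token ids, deferred pad count)

    src_single, pad_num = instance_src
    src_single += [PAD_ID] * pad_num   # [11, 12, 13, 0, 0]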
Example #6
    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()
            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start:self.start +
                                        self.batch_size]

            self.start += self.batch_size

            src = []
            tgt = []
            seg = []

            masked_words_num = 0

            for ins in instances:
                if len(ins) == 3:  # statically masked: (src, tgt, seg_pos)
                    src.append(ins[0])
                    masked_words_num += len(ins[1])
                    tgt.append([0] * len(ins[0]))
                    for mask in ins[1]:
                        tgt[-1][mask[0]] = mask[1]
                    seg.append([1] * ins[2][0] + [0] *
                               (len(ins[0]) - ins[2][0]))
                else:  # dynamically masked: (src, seg_pos); mask at batch time
                    src_single, tgt_single = mask_seq(ins[0], self.tokenizer,
                                                      self.whole_word_masking,
                                                      self.span_masking,
                                                      self.span_geo_prob,
                                                      self.span_max_length)
                    masked_words_num += len(tgt_single)
                    src.append(src_single)
                    tgt.append([0] * len(ins[0]))
                    for mask in tgt_single:
                        tgt[-1][mask[0]] = mask[1]
                    seg.append([1] * ins[1][0] + [0] *
                               (len(ins[0]) - ins[1][0]))

            if masked_words_num == 0:
                continue

            yield torch.LongTensor(src), \
                torch.LongTensor(tgt), \
                torch.LongTensor(seg)
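
The segment vector built above marks real tokens with 1 and everything past seg_pos with 0. In isolation:

    seq_length = 6
    seg_pos = 4  # number of real (non-PAD) tokens
    seg = [1] * seg_pos + [0] * (seq_length - seg_pos)
    # seg == [1, 1, 1, 1, 0, 0]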
Example #7
    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()
            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start:self.start +
                                        self.batch_size]

            self.start += self.batch_size

            src = []
            tgt_in = []
            tgt_out = []
            seg = []

            tgt_seq_length = 0

            for ins in instances:
                src_single, pad_num = ins[0]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))

                if len(ins) == 3:
                    tgt_single = ins[1]
                    seg.append([1] * ins[2][0] + [0] * pad_num)
                else:
                    src_single, tgt_single = mask_seq(src_single,
                                                      self.tokenizer,
                                                      self.whole_word_masking,
                                                      self.span_masking,
                                                      self.span_geo_prob,
                                                      self.span_max_length)
                    seg.append([1] * ins[1][0] + [0] * pad_num)

                MASK_ID = self.vocab.get(MASK_TOKEN)
                SENTINEL_ID = self.vocab.get(SENTINEL_TOKEN)
                PAD_ID = self.vocab.get(PAD_TOKEN)

                # mask_seq may leave some target positions unmasked (BERT-style
                # replacement), so force every target position to [MASK] before
                # building the sentinel sequence below.
                for src_index, _ in tgt_single:
                    src_single[src_index] = MASK_ID

                tgt_in_single = [self.vocab.get(CLS_TOKEN)]
                mask_index = 0
                src_with_sentinel = []
                for token_id in src_single:
                    if token_id == MASK_ID:
                        # First [MASK] of a span: emit a fresh sentinel id and
                        # bump it; continuation [MASK]s join the current span.
                        if not (src_with_sentinel
                                and src_with_sentinel[-1] == SENTINEL_ID - 1):
                            src_with_sentinel.append(SENTINEL_ID)
                            tgt_in_single.append(SENTINEL_ID)
                            if SENTINEL_ID < len(self.vocab) - 1:
                                SENTINEL_ID += 1
                        # The decoder target recovers the original token.
                        tgt_in_single.append(tgt_single[mask_index][1])
                        mask_index += 1
                    else:
                        src_with_sentinel.append(token_id)
                tgt_in_single.append(SENTINEL_ID)
                tgt_in_single.append(self.vocab.get(SEP_TOKEN))

                # Re-pad the shortened source back to the original length.
                while len(src_with_sentinel) < len(src_single):
                    src_with_sentinel.append(PAD_ID)

                if len(tgt_in_single) > tgt_seq_length:
                    tgt_seq_length = len(tgt_in_single)

                src.append(src_with_sentinel)
                tgt_in.append(tgt_in_single)
                tgt_out.append(tgt_in[-1][1:] + [PAD_ID])

            # Pad every target in the batch to the longest target length.
            for i in range(len(tgt_in)):
                while len(tgt_in[i]) != tgt_seq_length:
                    tgt_in[i].append(PAD_ID)
                    tgt_out[i].append(PAD_ID)

            yield torch.LongTensor(src), \
                torch.LongTensor(tgt_in), \
                torch.LongTensor(tgt_out), \
                torch.LongTensor(seg)
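
The sentinel construction above is T5-style span corruption: each masked span in src is replaced by one incrementing sentinel id, and the decoder input interleaves sentinels with the original tokens. A compressed standalone sketch with made-up ids (CLS = 1, SEP = 2, vocab-bound check omitted):

    MASK_ID, PAD_ID, CLS_ID, SEP_ID = 4, 0, 1, 2
    SENTINEL_ID = 100
    src = [7, MASK_ID, MASK_ID, 8, MASK_ID, 9]
    originals = [41, 42, 43]  # original ids of the three masked tokens

    tgt_in, src_out, k = [CLS_ID], [], 0
    for token_id in src:
        if token_id == MASK_ID:
            if not (src_out and src_out[-1] == SENTINEL_ID - 1):
                src_out.append(SENTINEL_ID)
                tgt_in.append(SENTINEL_ID)
                SENTINEL_ID += 1
            tgt_in.append(originals[k])
            k += 1
        else:
            src_out.append(token_id)
    tgt_in += [SENTINEL_ID, SEP_ID]
    src_out += [PAD_ID] * (len(src) - len(src_out))
    # src_out == [7, 100, 8, 101, 9, 0]
    # tgt_in  == [1, 100, 41, 42, 101, 43, 102, 2]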
Example #8
    def create_ins_from_doc(self, all_documents, document_index):
        document = all_documents[document_index]
        max_num_tokens = self.seq_length - 3  # room for [CLS] and two [SEP]
        target_seq_length = max_num_tokens
        if random.random() < self.short_seq_prob:
            target_seq_length = random.randint(2, max_num_tokens)
        instances = []
        current_chunk = []
        current_length = 0
        i = 0
        while i < len(document):
            segment = document[i]
            current_chunk.append(segment)
            current_length += len(segment)
            if i == len(document) - 1 or current_length >= target_seq_length:
                if current_chunk:
                    a_end = 1
                    if len(current_chunk) >= 2:
                        a_end = random.randint(1, len(current_chunk) - 1)

                    tokens_a = []
                    for j in range(a_end):
                        tokens_a.extend(current_chunk[j])

                    tokens_b = []
                    is_random_next = 0

                    if len(current_chunk) == 1 or random.random() < 0.5:
                        # NSP negative: pair tokens_a with a segment from another document.
                        is_random_next = 1
                        target_b_length = target_seq_length - len(tokens_a)

                        # Up to 10 attempts to pick a document other than this one.
                        for _ in range(10):
                            random_document_index = random.randint(
                                0, len(all_documents) - 1)
                            if random_document_index != document_index:
                                break

                        random_document = all_documents[random_document_index]
                        random_start = random.randint(0, len(random_document) - 1)
                        for j in range(random_start, len(random_document)):
                            tokens_b.extend(random_document[j])
                            if len(tokens_b) >= target_b_length:
                                break

                        # Segments after a_end were not consumed; rewind i so the
                        # next chunk starts from them.
                        num_unused_segments = len(current_chunk) - a_end
                        i -= num_unused_segments

                    else:
                        is_random_next = 0
                        for j in range(a_end, len(current_chunk)):
                            tokens_b.extend(current_chunk[j])

                    truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)

                    src = []
                    src.append(self.vocab.get(CLS_TOKEN))
                    src.extend(tokens_a)
                    src.append(self.vocab.get(SEP_TOKEN))
                    seg_pos = [len(src)]
                    src.extend(tokens_b)
                    src.append(self.vocab.get(SEP_TOKEN))
                    seg_pos.append(len(src))

                    while len(src) != self.seq_length:
                        src.append(self.vocab.get(PAD_TOKEN))

                    if not self.dynamic_masking:
                        src, tgt_mlm = mask_seq(src, self.tokenizer,
                                                self.whole_word_masking,
                                                self.span_masking,
                                                self.span_geo_prob,
                                                self.span_max_length)
                        instance = (src, tgt_mlm, is_random_next, seg_pos)
                    else:
                        instance = (src, is_random_next, seg_pos)

                    instances.append(instance)
                current_chunk = []
                current_length = 0
            i += 1
        return instances
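
truncate_seq_pair is not shown in these excerpts; a common implementation (BERT's reference version additionally chooses randomly between trimming the front or the back) pops from the longer list until the pair fits:

    # Sketch under the assumption above; not necessarily the exact helper used here.
    def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
        while len(tokens_a) + len(tokens_b) > max_num_tokens:
            longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
            longer.pop()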
Example #9
    def worker(self, proc_id, start, end):
        print("Worker %d is building dataset ... " % proc_id)
        set_seed(self.seed)
        dataset_writer = open("dataset-tmp-" + str(proc_id) + ".pt", "wb")
        pos = 0
        with open(self.corpus_path, mode="r", encoding="utf-8") as f:
            # Skip past the lines that belong to other workers' shards.
            while pos < start:
                f.readline()
                pos += 1
            while True:
                line = f.readline()
                pos += 1

                line = line.strip().split('\t')
                if len(line) == 2:  # Single-text input: label \t text.
                    label = int(line[0])
                    text = line[1]
                    tokens = self.tokenizer.convert_tokens_to_ids(
                        self.tokenizer.tokenize(text))
                    src = [self.vocab.get(CLS_TOKEN)] + tokens \
                        + [self.vocab.get(SEP_TOKEN)]
                    tgt_cls = label
                    seg_pos = [len(src)]
                elif len(line) == 3:  # Sentence-pair input: label \t text_a \t text_b.
                    label = int(line[0])
                    text_a, text_b = line[1], line[2]

                    src_a = self.tokenizer.convert_tokens_to_ids(
                        self.tokenizer.tokenize(text_a))
                    src_a = [self.vocab.get(CLS_TOKEN)] + src_a \
                        + [self.vocab.get(SEP_TOKEN)]
                    src_b = self.tokenizer.convert_tokens_to_ids(
                        self.tokenizer.tokenize(text_b))
                    src_b = src_b + [self.vocab.get(SEP_TOKEN)]

                    src = src_a + src_b
                    tgt_cls = label
                    # For pairs, seg_pos stores the two segment lengths.
                    seg_pos = [len(src_a), len(src_b)]
                else:
                    # Blank or malformed line: stop at the shard end, else skip it.
                    if pos >= end:
                        break
                    continue

                # Truncate overlong inputs; otherwise record the deferred padding.
                if len(src) >= self.seq_length:
                    pad_num = 0
                    src = (src[:self.seq_length], pad_num)
                    if len(seg_pos) == 1:
                        seg_pos = [self.seq_length]
                    else:
                        if len(src_a) >= self.seq_length:
                            seg_pos = [self.seq_length]
                        else:
                            seg_pos = [len(src_a), self.seq_length - len(src_a)]
                else:
                    pad_num = self.seq_length - len(src)
                    src = (src, pad_num)

                if not self.dynamic_masking:
                    src_single, pad_num = src
                    src_single, tgt_mlm = mask_seq(src_single, self.tokenizer,
                                                   self.whole_word_masking,
                                                   self.span_masking,
                                                   self.span_geo_prob,
                                                   self.span_max_length)
                    src = (src_single, pad_num)
                    instance = (src, tgt_mlm, tgt_cls, seg_pos)
                else:
                    instance = (src, tgt_cls, seg_pos)

                pickle.dump(instance, dataset_writer)

                if pos >= end:
                    break

        dataset_writer.close()
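
For the single-text branch, the record layout is [CLS] tokens [SEP] with padding deferred via pad_num, mirroring Example #5. A toy trace with made-up ids and seq_length = 8:

    CLS_ID, SEP_ID = 1, 2
    seq_length = 8
    token_ids = [11, 12, 13]               # tokenized text

    src = [CLS_ID] + token_ids + [SEP_ID]  # [1, 11, 12, 13, 2]
    seg_pos = [len(src)]                   # [5]
    pad_num = seq_length - len(src)        # 3
    instance = ((src, pad_num), 0, seg_pos)  # label 0, dynamic-masking shape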