Code Example #1
        def wrapper():
            for epoch_index in range(epoch):
                if shuffle:
                    random.shuffle(examples)
                if phase == 'train':
                    self.current_train_epoch = epoch_index
                    features = self.get_features(examples, is_training=True)
                else:
                    features = self.get_features(examples, is_training=False)

                all_dev_batches = []
                for batch_data, total_token_num in batch_reader(
                        features, batch_size, self._in_tokens):
                    batch_data = prepare_batch_data(
                        batch_data,
                        total_token_num,
                        voc_size=-1,
                        pad_id=self.pad_id,
                        cls_id=self.cls_id,
                        sep_id=self.sep_id,
                        mask_id=-1,
                        return_input_mask=True,
                        return_max_len=False,
                        return_num_token=False)
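                    # Buffer one prepared batch per device; flush once
                    # dev_count batches have accumulated. Leftover batches at
                    # the end of an epoch are dropped.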
                    if len(all_dev_batches) < dev_count:
                        all_dev_batches.append(batch_data)

                    if len(all_dev_batches) == dev_count:
                        for batch in all_dev_batches:
                            yield batch
                        all_dev_batches = []
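
Code Examples #1, #3, and #8 call a batch_reader helper that is not shown; Code Examples #5 and #9 define equivalent helpers inline. For reference, the following is a minimal standalone sketch of that batching logic, modeled on those inline definitions (a sketch, not the canonical implementation from any one project):

    def batch_reader(features, batch_size, in_tokens):
        """Group parsed samples into (batch, total_token_num) pairs.

        With in_tokens=True, batch_size is a token budget: a batch is full
        once (len(batch) + 1) * max_len would exceed it. Otherwise it is a
        plain example count. Mirrors the helpers in Code Examples #5 and #9.
        """
        batch, total_token_num, max_len = [], 0, 0
        for parsed_line in features:
            token_ids = parsed_line[0]  # first field: the token id list
            max_len = max(max_len, len(token_ids))
            if in_tokens:
                to_append = (len(batch) + 1) * max_len <= batch_size
            else:
                to_append = len(batch) < batch_size
            if to_append:
                batch.append(parsed_line)
                total_token_num += len(token_ids)
            else:
                yield batch, total_token_num
                batch = [parsed_line]
                total_token_num = len(token_ids)
                max_len = len(token_ids)
        if batch:
            yield batch, total_token_num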
Code Example #2
 def generate_batch_data(self,
                         batch_data,
                         total_token_num,
                         voc_size=-1,
                         mask_id=-1,
                         return_input_mask=True,
                         return_max_len=False,
                         return_num_token=False,
                         few_shot=False,
                         k=0,
                         n=0,
                         batch_label=None):
     # Avoid the mutable-default pitfall: default batch_label to None and
     # create a fresh list per call.
     if batch_label is None:
         batch_label = []
     # Forward the keyword arguments instead of hard-coding them so callers
     # can override voc_size, mask_id, and the return_* flags; few_shot is
     # accepted for interface compatibility but unused here.
     return prepare_batch_data(batch_data,
                               total_token_num,
                               voc_size=voc_size,
                               pad_id=self.vocab["[PAD]"],
                               cls_id=self.vocab["[CLS]"],
                               sep_id=self.vocab["[SEP]"],
                               mask_id=mask_id,
                               return_input_mask=return_input_mask,
                               return_max_len=return_max_len,
                               return_num_token=return_num_token,
                               k=k,
                               n=n,
                               batch_label=batch_label)
Code Example #3
        def wrapper():
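            # Training mode loops over the data indefinitely, reshuffling and
            # re-extracting features after each pass; evaluation stops after
            # a single pass (see the break below).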
            while True:
                for batch_data, total_token_num in batch_reader(
                        self.features, batch_size, self._in_tokens):
                    batch_data = prepare_batch_data(batch_data,
                                                    total_token_num,
                                                    max_seq_len,
                                                    voc_size=-1,
                                                    pad_id=self.pad_id,
                                                    cls_id=self.cls_id,
                                                    sep_id=self.sep_id,
                                                    mask_id=-1,
                                                    return_input_mask=True,
                                                    return_max_len=False,
                                                    return_num_token=False)
                    # Unlike Code Example #1, each prepared batch is yielded
                    # directly instead of being buffered per device.
                    yield batch_data

                if not is_training:
                    break

                random.shuffle(examples)
                self.features = self.get_features(examples, is_training=True)
Code Example #4
File: pretraining.py Project: zw331/DDParser
        def wrapper():
            def reader():
                for epoch in range(self.epoch):
                    self.current_epoch = epoch + 1
                    if self.shuffle_files:
                        np.random.shuffle(files)
                    for index, file in enumerate(files):
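                        # Each manifest line is "path<TAB>mask_word_prob";
                        # sample whether this file gets word- or char-level
                        # masking.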
                        file, mask_word_prob = file.strip().split("\t")
                        mask_word = (np.random.random() <
                                     float(mask_word_prob))
                        self.current_file_index = index + 1
                        self.current_file = file
                        if mask_word:
                            self.mask_type = "mask_word"
                        else:
                            self.mask_type = "mask_char"

                        sample_generator = self.read_file(file)
                        if not self.is_test and self.generate_neg_sample:
                            sample_generator = self.mixin_negtive_samples(
                                sample_generator)
                        for sample in sample_generator:
                            if sample is None:
                                continue
                            sample.append(mask_word)
                            yield sample

            def batch_reader(reader, batch_size):
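                # batch_size is a token budget here: a batch is emitted once
                # (len(batch) + 1) * max_len would exceed it.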
                batch, total_token_num, max_len = [], 0, 0
                for parsed_line in reader():
                    token_ids, sent_ids, pos_ids, label, seg_labels, mask_word = parsed_line
                    max_len = max(max_len, len(token_ids))
                    if (len(batch) + 1) * max_len <= batch_size:
                        batch.append(parsed_line)
                        total_token_num += len(token_ids)
                    else:
                        yield batch, total_token_num
                        batch = [parsed_line]
                        total_token_num = len(token_ids)
                        max_len = len(token_ids)

                if len(batch) > 0:
                    yield batch, total_token_num

            for batch_data, total_token_num in batch_reader(
                    reader, self.batch_size):
                yield prepare_batch_data(batch_data,
                                         total_token_num,
                                         voc_size=self.voc_size,
                                         pad_id=self.pad_id,
                                         cls_id=self.cls_id,
                                         sep_id=self.sep_id,
                                         mask_id=self.mask_id,
                                         return_input_mask=True,
                                         return_max_len=False,
                                         return_num_token=False)
Code Example #5
        def wrapper():
            def reader():
                for epoch in range(self.epoch):
                    self.current_epoch = epoch + 1
                    if self.shuffle_files:
                        np.random.shuffle(files)
                    for index, file in enumerate(files):
                        self.current_file_index = index + 1
                        self.current_file = file
                        sample_generator = self.read_file(file)
                        if not self.is_test and self.generate_neg_sample:
                            sample_generator = self.mixin_negtive_samples(
                                sample_generator)
                        for sample in sample_generator:
                            if sample is None:
                                continue
                            yield sample

            def batch_reader(reader, batch_size, in_tokens):
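                # in_tokens switches batch_size between a per-batch token
                # budget and a plain example count.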
                batch, total_token_num, max_len = [], 0, 0
                for parsed_line in reader():
                    token_ids, sent_ids, pos_ids, label = parsed_line
                    max_len = max(max_len, len(token_ids))
                    if in_tokens:
                        to_append = (len(batch) + 1) * max_len <= batch_size
                    else:
                        to_append = len(batch) < batch_size
                    if to_append:
                        batch.append(parsed_line)
                        total_token_num += len(token_ids)
                    else:
                        yield batch, total_token_num
                        batch = [parsed_line]
                        total_token_num = len(token_ids)
                        max_len = len(token_ids)

                if len(batch) > 0:
                    yield batch, total_token_num

            for batch_data, total_token_num in batch_reader(
                    reader, self.batch_size, self.in_tokens):
                yield prepare_batch_data(
                    batch_data,
                    total_token_num,
                    voc_size=self.voc_size,
                    pad_id=self.pad_id,
                    cls_id=self.cls_id,
                    sep_id=self.sep_id,
                    mask_id=self.mask_id,
                    return_input_mask=True,
                    return_max_len=False,
                    return_num_token=False)
Code Example #6
File: cls.py Project: zhyq/LARK
 def generate_batch_data(self,
                         batch_data,
                         total_token_num,
                         voc_size=-1,
                         mask_id=-1,
                         return_attn_bias=True,
                         return_max_len=False,
                         return_num_token=False):
     # Forward the keyword arguments instead of hard-coding them so the
     # parameters in the signature actually take effect.
     return prepare_batch_data(
         batch_data,
         total_token_num,
         voc_size=voc_size,
         pad_id=self.vocab["[PAD]"],
         cls_id=self.vocab["[CLS]"],
         sep_id=self.vocab["[SEP]"],
         mask_id=mask_id,
         return_attn_bias=return_attn_bias,
         return_max_len=return_max_len,
         return_num_token=return_num_token)
Code Example #7
 def generate_batch_data(self,
                         batch_data,
                         max_len,
                         total_token_num,
                         voc_size=-1,
                         mask_id=-1,
                         return_input_mask=True,
                         return_max_len=False,
                         return_num_token=False):
     """Pad and mask one batch, forwarding the keyword arguments to prepare_batch_data."""
     return prepare_batch_data(batch_data,
                               max_len,
                               total_token_num,
                               voc_size=voc_size,
                               pad_id=self.vocab["[PAD]"],
                               cls_id=self.vocab["[CLS]"],
                               sep_id=self.vocab["[SEP]"],
                               mask_id=mask_id,
                               return_input_mask=return_input_mask,
                               return_max_len=return_max_len,
                               return_num_token=return_num_token)
Code Example #8
        def wrapper():
            for epoch_index in range(epoch):
                if shuffle:
                    random.shuffle(examples)
                if phase == 'train':
                    self.current_train_epoch = epoch_index
                    features = self.get_features(examples, is_training=True)
                else:
                    features = self.get_features(examples, is_training=False)

                for batch_data, total_token_num in batch_reader(
                        features, batch_size, self._in_tokens):
                    yield prepare_batch_data(
                        batch_data,
                        total_token_num,
                        voc_size=-1,
                        pad_id=self.pad_id,
                        cls_id=self.cls_id,
                        sep_id=self.sep_id,
                        mask_id=-1,
                        return_attn_bias=True,
                        return_max_len=False,
                        return_num_token=False)
Code Example #9
        def wrapper():
            def reader():
                for epoch in range(self.epoch):
                    self.current_epoch = epoch + 1
                    files = self.files
                    # During training, the file list is sliced so each trainer
                    # reads a disjoint shard of the data.
                    if self.shuffle_files:
                        start = epoch * self.total_file
                        end = start + self.total_file
                        files = [file_ for index, file_ in enumerate(self.files[start:end]) \
                            if index % self.trainer_nums == self.trainer_id]

                    for index, file_ in enumerate(files):
                        file_, mask_word_prob = file_.strip().split("\t")
                        mask_word = (np.random.random() <
                                     float(mask_word_prob))
                        self.current_file_index = (index +
                                                   1) * self.trainer_nums
                        self.current_file = file_
                        if mask_word:
                            self.mask_type = "mask_word"
                        else:
                            self.mask_type = "mask_char"

                        sample_generator = self.read_file(file_)
                        if not self.is_test:
                            if self.generate_neg_sample:
                                sample_generator = self.mixin_negtive_samples(
                                    sample_generator)
                            else:
                                #shuffle buffered sample
                                sample_generator = self.shuffle_samples(
                                    sample_generator)

                        for sample in sample_generator:
                            if sample is None:
                                continue
                            sample.append(mask_word)
                            yield sample

            def batch_reader(reader, batch_size):
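                # Same batching as Code Example #4, but self.in_tokens chooses
                # between token-budget and example-count batching.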
                batch, total_token_num, max_len = [], 0, 0
                for parsed_line in reader():
                    token_ids, sent_ids, pos_ids, label, seg_labels, mask_word = parsed_line
                    max_len = max(max_len, len(token_ids))
                    if self.in_tokens:
                        to_append = (len(batch) + 1) * max_len <= batch_size
                    else:
                        to_append = len(batch) < batch_size
                    if to_append:
                        batch.append(parsed_line)
                        total_token_num += len(token_ids)
                    else:
                        yield batch, total_token_num
                        batch = [parsed_line]
                        total_token_num = len(token_ids)
                        max_len = len(token_ids)

                if len(batch) > 0:
                    yield batch, total_token_num

            for batch_data, total_token_num in batch_reader(
                    reader, self.batch_size):
                yield prepare_batch_data(batch_data,
                                         total_token_num,
                                         voc_size=self.voc_size,
                                         pad_id=self.pad_id,
                                         cls_id=self.cls_id,
                                         sep_id=self.sep_id,
                                         mask_id=self.mask_id,
                                         return_input_mask=True,
                                         return_max_len=False,
                                         return_num_token=False)
Code Example #10
File: pretraining.py Project: dingsiyu/ernie-ascend
        def wrapper():
            def reader(task_index):
                files = all_files[task_index]
                for epoch in range(self.epoch):
                    if self.shuffle_files:
                        np.random.shuffle(files)
                    for index, file in enumerate(files):
                        file, mask_word_prob = file.strip().split("\t")
                        mask_word = (np.random.random() <
                                     float(mask_word_prob))

                        if mask_word:
                            self.mask_type = "mask_word"
                        else:
                            self.mask_type = "mask_char"

                        sample_generator = self.read_file(file, task_index)
                        if not self.is_test and self.generate_neg_sample:
                            sample_generator = self.mixin_negtive_samples(
                                sample_generator)

                        for sample in sample_generator:
                            self.current_epoch = epoch + 1
                            self.current_file_index = index + 1
                            self.current_file = file
                            self.total_file = len(files)

                            if sample is None:
                                continue
                            sample.append(mask_word)
                            yield sample

            def batch_reader(reader, batch_size):
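                # Multi-task batching: repeatedly sample a task id from
                # task_probs, then fill one batch from that task's reader;
                # stop after 50 consecutive exhausted reads.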
                batch, total_token_num, max_len = [], 0, 0
                dev_count = 1
                buff = []
                readers = []
                for i in range(len(task_probs)):
                    buff.append(None)
                    readers.append(reader(i))
                task_indices = range(len(task_probs))

                end_times = 0
                while end_times < 50:
                    task_index = np.random.choice(task_indices, p=task_probs)
                    # Tasks flagged "constart" get half the token budget.
                    if self.task_group[task_index]["constart"]:
                        batch_size_fact = batch_size // 2
                    else:
                        batch_size_fact = batch_size

                    dev_num = 0
                    cur_reader = readers[task_index]

                    while dev_num < dev_count:
                        if buff[task_index] is not None:
                            cur_len = len(buff[task_index][0])
                            max_len = max(max_len, cur_len)
                            batch.append(buff[task_index])
                            total_token_num += cur_len
                            buff[task_index] = None

                        parsed_line = next(cur_reader, None)
                        if parsed_line is None:
                            end_times += 1
                            dev_num += 1
                            if len(batch) > 0:
                                yield batch, total_token_num, task_index, self.task_group[
                                    task_index]["lm_weight"]
                                batch, total_token_num, max_len = [], 0, 0
                            continue

                        end_times = 0
                        cur_len = len(parsed_line[0])
                        max_len = self.max_seq_len  # budget by the fixed max_seq_len, not the running max
                        if (len(batch) + 1) * max_len > batch_size_fact:
                            yield batch, total_token_num, task_index, self.task_group[
                                task_index]["lm_weight"]
                            batch, total_token_num, max_len = [], 0, 0
                            dev_num += 1
                            buff[task_index] = parsed_line
                        else:
                            batch.append(parsed_line)
                            total_token_num += cur_len

            for batch_data, total_token_num, task_index, lm_weight in batch_reader(
                    reader, self.batch_size):
                yield prepare_batch_data(batch_data,
                                         total_token_num,
                                         task_index,
                                         lm_weight,
                                         self.max_seq_len,
                                         len(self.task_group),
                                         voc_size=self.voc_size,
                                         pad_id=self.pad_id,
                                         cls_id=self.cls_id,
                                         sep_id=self.sep_id,
                                         mask_id=self.mask_id,
                                         return_input_mask=True,
                                         return_max_len=False,
                                         return_num_token=False)