Example #1
    def encode_data_multi_processor(self,
                                    data_generator,
                                    cpu_num_workers,
                                    file_columns,
                                    input_types,
                                    object_inputs,
                                    answer_column_name,
                                    min_sentence_len,
                                    extra_feature,
                                    max_lengths=None,
                                    fixed_lengths=None,
                                    file_format="tsv",
                                    bpe_encoder=None):

        for data in data_generator:
            # encode each chunk of the stream in parallel, one scheduler per chunk
            scheduler = ProcessorsScheduler(cpu_num_workers)
            func_args = (data, file_columns, input_types, object_inputs,
                         answer_column_name, min_sentence_len, extra_feature,
                         max_lengths, fixed_lengths, file_format, bpe_encoder)
            res = scheduler.run_data_parallel(self.encode_data_list, func_args)

            # aggregate per-worker results
            output_data, lengths, target = dict(), dict(), dict()
            cnt_legal, cnt_illegal = 0, 0
            for (index, j) in res:
                # logging.info("collect processor %d result" % index)
                tmp_data, tmp_lengths, tmp_target, tmp_cnt_legal, tmp_cnt_illegal = j.get()
                output_data = self._merge_encode_data(output_data, tmp_data)
                lengths = self._merge_encode_lengths(lengths, tmp_lengths)
                target = self._merge_target(target, tmp_target)
                cnt_legal += tmp_cnt_legal
                cnt_illegal += tmp_cnt_illegal
            yield output_data, lengths, target, cnt_legal, cnt_illegal
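
Example #1 delegates the merging to self._merge_encode_data, self._merge_encode_lengths and self._merge_target, which are not shown here. Below is a minimal sketch of what _merge_encode_data could look like, inferred from the inline merge loop in Example #3; the nested data[branch][input_type] layout is an assumption taken from that example, not a confirmed implementation.

    def _merge_encode_data(self, output_data, tmp_data):
        # Sketch only: assumes the nested layout data[branch][input_type] -> list,
        # mirroring the inline merge in Example #3.
        if len(output_data) == 0:
            return tmp_data
        for branch in tmp_data:
            for input_type in tmp_data[branch]:
                output_data[branch][input_type].extend(tmp_data[branch][input_type])
        return output_data
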
Example #2
    def build_training_multi_processor(self,
                                       training_data_generator,
                                       cpu_num_workers,
                                       file_columns,
                                       input_types,
                                       answer_column_name,
                                       bpe_encoder=None):
        for data in training_data_generator:
            # multi-processing
            scheduler = ProcessorsScheduler(cpu_num_workers)
            func_args = (data, file_columns, input_types, answer_column_name, bpe_encoder)
            res = scheduler.run_data_parallel(self.build_training_data_list, func_args)
            # aggregate
            docs = dict()           # docs of each type of input
            target_docs = dict()    # target docs of each answer type
            cnt_legal = 0
            cnt_illegal = 0
            for (index, j) in res:
                #logging.info("collect processor %d result" % index)
                tmp_docs, tmp_target_docs, tmp_cnt_legal, tmp_cnt_illegal = j.get()
                if len(docs) == 0:
                    docs = tmp_docs
                else:
                    for key, value in tmp_docs.items():
                        docs[key].extend(value)
                if len(target_docs) == 0:
                    target_docs = tmp_target_docs
                else:
                    for single_type in tmp_target_docs:
                        target_docs[single_type].extend(tmp_target_docs[single_type])
                cnt_legal += tmp_cnt_legal
                cnt_illegal += tmp_cnt_illegal

            yield docs, target_docs, cnt_legal, cnt_illegal
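
All three examples lean on ProcessorsScheduler.run_data_parallel returning (index, async_result) pairs whose .get() yields one worker's output tuple. Below is a minimal sketch of a scheduler consistent with that calling convention, built on multiprocessing.Pool; the even chunking of func_args[0] and all other internals are assumptions, not the actual class.

    import math
    from multiprocessing import Pool

    class ProcessorsScheduler(object):
        # Sketch only: splits the first element of func_args (the data list)
        # into roughly equal chunks and runs func on each chunk asynchronously.
        def __init__(self, cpu_num_workers):
            self.num_workers = cpu_num_workers if cpu_num_workers > 0 else 1

        def run_data_parallel(self, func, func_args):
            data, rest_args = func_args[0], func_args[1:]
            chunk_size = int(math.ceil(len(data) / float(self.num_workers)))
            pool = Pool(self.num_workers)
            res = []
            for i in range(self.num_workers):
                chunk = data[i * chunk_size:(i + 1) * chunk_size]
                # (worker index, AsyncResult); the caller blocks on .get()
                res.append((i, pool.apply_async(func, (chunk,) + tuple(rest_args))))
            pool.close()
            return res

The pool is closed but not joined, so each AsyncResult stays valid and .get() blocks until that worker's chunk is done, matching how the examples consume res.
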
Example #3
    def encode_data_multi_processor(self,
                                    data_list,
                                    cpu_num_workers,
                                    file_columns,
                                    input_types,
                                    object_inputs,
                                    answer_column_name,
                                    min_sentence_len,
                                    extra_feature,
                                    max_lengths=None,
                                    fixed_lengths=None,
                                    file_format="tsv",
                                    bpe_encoder=None):
        def judge_dict(obj):
            # True when a lengths entry is a nested dict (per-type lengths)
            return isinstance(obj, dict)

        scheduler = ProcessorsScheduler(cpu_num_workers)
        func_args = (data_list, file_columns, input_types, object_inputs,
                     answer_column_name, min_sentence_len, extra_feature,
                     max_lengths, fixed_lengths, file_format, bpe_encoder)
        res = scheduler.run_data_parallel(self.encode_data_list, func_args)

        data = dict()
        lengths = dict()
        target = dict()
        cnt_legal = 0
        cnt_illegal = 0

        for (index, j) in res:
            # logging.info("collect processor %d result" % index)
            tmp_data, tmp_lengths, tmp_target, tmp_cnt_legal, tmp_cnt_illegal = j.get()

            # merge this worker's encoded data into the running result
            if len(data) == 0:
                data = tmp_data
            else:
                for branch in tmp_data:
                    for input_type in data[branch]:
                        data[branch][input_type].extend(
                            tmp_data[branch][input_type])
            if len(lengths) == 0:
                lengths = tmp_lengths
            else:
                for branch in tmp_lengths:
                    if judge_dict(tmp_lengths[branch]):
                        for type_branch in tmp_lengths[branch]:
                            lengths[branch][type_branch].extend(
                                tmp_lengths[branch][type_branch])
                    else:
                        lengths[branch].extend(tmp_lengths[branch])
            if not tmp_target:
                # no targets in this chunk (e.g. prediction data); drop targets altogether
                target = None
            else:
                if len(target) == 0:
                    target = tmp_target
                else:
                    for single_type in tmp_target:
                        target[single_type].extend(tmp_target[single_type])
            cnt_legal += tmp_cnt_legal
            cnt_illegal += tmp_cnt_illegal

        return data, lengths, target, cnt_legal, cnt_illegal
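
For reference, a hypothetical call site contrasting the streaming variant (Example #1) with the in-memory variant (Example #3). Every name below is a placeholder: the column layout, the type config and the problem object are invented for illustration, and the two positional arguments after answer_column_name stand for min_sentence_len and extra_feature.

    # Hypothetical driver; all names and values are placeholders.
    file_columns = {"query": 0, "passage": 1, "label": 2}
    input_types = {"word": {"cols": ["query", "passage"]}}
    object_inputs = {"query": ["query"], "passage": ["passage"]}
    answer_column_name = ["label"]
    data_list = ["how are you\tI am fine\t1"]        # one tab-separated line
    data_generator = iter([data_list])               # stream of such chunks

    # Example #3 style: one in-memory list, one blocking call.
    data, lengths, target, cnt_legal, cnt_illegal = problem.encode_data_multi_processor(
        data_list, 4, file_columns, input_types, object_inputs,
        answer_column_name, 1, False)

    # Example #1 style: consume one encoded chunk per yield.
    for data, lengths, target, cnt_legal, cnt_illegal in problem.encode_data_multi_processor(
            data_generator, 4, file_columns, input_types, object_inputs,
            answer_column_name, 1, False):
        pass  # use the encoded chunk here
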