def process(self, data_bundle: DataBundle) -> DataBundle:
    """Populate the model input field from the raw-word field and mark inputs/targets.

    NOTE(review): per the original comment, the input here is the raw_words
    text rather than the usual ``Const.INPUT`` (= ``words``) field.

    :param data_bundle: bundle whose datasets carry a raw-word field.
    :return: the same bundle, with input/seq-len/target fields configured.
    """
    # Seed the input field from the raw words wherever that field exists.
    data_bundle.copy_field(field_name=C.RAW_WORD, new_field_name=C.INPUT,
                           ignore_miss_dataset=True)
    for _, ds in data_bundle.datasets.items():
        # Overwrite the input field via the configured copy function,
        # then record each instance's sequence length.
        ds.apply_field(self.copy_func, field_name=C.RAW_WORD, new_field_name=C.INPUT)
        ds.add_seq_len(C.INPUT)
    data_bundle.set_input(C.INPUT, C.INPUT_LEN)
    data_bundle.set_target(C.TARGET)  # Const.TARGET, i.e. 'target'
    return data_bundle
def process(self, data_bundle: DataBundle) -> DataBundle:
    """
    The DataSets to be processed must contain a ``raw_words`` column, e.g.

    .. csv-table::
        :header: "raw_words"

        "上海 浦东 开发 与 法制 建设 同步"
        "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )"
        "..."

    :param data_bundle:
    :return:
    """
    # Work on a character-level copy of the raw words.
    data_bundle.copy_field(Const.RAW_WORD, Const.CHAR_INPUT)
    if self.replace_num_alpha:
        # Normalize alphabetic and numeric spans in place (helpers defined elsewhere in this module).
        data_bundle.apply_field(_find_and_replace_alpha_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)
        data_bundle.apply_field(_find_and_replace_digit_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)
    self._tokenize(data_bundle)
    input_field_names = [Const.CHAR_INPUT]
    target_field_names = []
    for name, dataset in data_bundle.datasets.items():
        # Derive per-character supervision from the word lengths
        # (each field name here is produced by the matching helper; exact
        # semantics of relay/seg-mask encodings live in those helpers).
        dataset.apply_field(
            lambda chars: _word_lens_to_relay(map(len, chars)),
            field_name=Const.CHAR_INPUT,
            new_field_name=Const.TARGET)
        dataset.apply_field(
            lambda chars: _word_lens_to_start_seg_mask(map(len, chars)),
            field_name=Const.CHAR_INPUT,
            new_field_name='start_seg_mask')
        dataset.apply_field(
            lambda chars: _word_lens_to_end_seg_mask(map(len, chars)),
            field_name=Const.CHAR_INPUT,
            new_field_name='end_seg_mask')
        # Flatten the list of words into a flat list of characters.
        dataset.apply_field(lambda chars: list(chain(*chars)), field_name=Const.CHAR_INPUT,
                            new_field_name=Const.CHAR_INPUT)
    # NOTE(review): start_seg_mask goes to targets while end_seg_mask goes to
    # inputs — looks asymmetric; confirm against the model's loss/forward signature.
    target_field_names.append('start_seg_mask')
    input_field_names.append('end_seg_mask')
    if self.bigrams:
        for name, dataset in data_bundle.datasets.items():
            # Character bigrams: each char paired with its successor, '<eos>' padding the tail.
            dataset.apply_field(
                lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
                field_name=Const.CHAR_INPUT, new_field_name='bigrams')
        input_field_names.append('bigrams')
    # NOTE(review): 'bigrams' is always passed to _indexize even when
    # self.bigrams is False — verify _indexize tolerates a missing field.
    _indexize(data_bundle, ['chars', 'bigrams'], [])
    # Clip targets to length L; _clip_target returns (relay_target, relay_mask) pairs.
    func = partial(_clip_target, L=self.L)
    for name, dataset in data_bundle.datasets.items():
        # apply_field with no new_field_name returns the per-instance results.
        res = dataset.apply_field(func, field_name='target')
        relay_target = [res_i[0] for res_i in res]
        relay_mask = [res_i[1] for res_i in res]
        dataset.add_field('relay_target', relay_target, is_input=True, is_target=False,
                          ignore_type=False)
        dataset.add_field('relay_mask', relay_mask, is_input=True, is_target=False,
                          ignore_type=False)
    input_field_names.append('relay_target')
    input_field_names.append('relay_mask')
    # Const.TARGET and the seq-len appear in BOTH input and target lists —
    # presumably the model's forward needs the target for the relay loss.
    input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
    target_fields = [Const.TARGET, Const.INPUT_LEN] + target_field_names
    for name, dataset in data_bundle.datasets.items():
        dataset.add_seq_len(Const.CHAR_INPUT)
    data_bundle.set_input(*input_fields)
    data_bundle.set_target(*target_fields)
    return data_bundle