def _convert_example_to_feature(self, example):
    """Convert a single DSTInputExample into an OpenVocabDSTFeature.

    The dialogue history and the current turn are joined with " [SEP] ",
    encoded with the source tokenizer, and truncated from the FRONT so the
    most recent tokens are kept within ``max_seq_length``. For every slot in
    ``slot_meta`` the (possibly "none") value is encoded as a target sequence
    terminated by the tokenizer's sep token, and a gate label is looked up in
    ``gating2id`` (falling back to the "ptr" gate for open-vocabulary values).
    """
    context = " [SEP] ".join(example.context_turns + example.current_turn)
    token_ids = self.src_tokenizer.encode(context, add_special_tokens=False)

    # Reserve two positions for [CLS]/[SEP]; drop the oldest tokens first.
    budget = self.max_seq_length - 2
    overflow = len(token_ids) - budget
    if overflow > 0:
        token_ids = token_ids[overflow:]

    input_id = (
        [self.src_tokenizer.cls_token_id]
        + token_ids
        + [self.src_tokenizer.sep_token_id]
    )
    segment_id = [0] * len(input_id)

    # Normalize a missing label to an empty state before conversion.
    if not example.label:
        example.label = []
    state = convert_state_dict(example.label)

    target_ids = []
    gating_id = []
    for slot in self.slot_meta:
        value = state.get(slot, "none")
        encoded_value = self.trg_tokenizer.encode(value, add_special_tokens=False)
        target_ids.append(encoded_value + [self.trg_tokenizer.sep_token_id])
        gating_id.append(self.gating2id.get(value, self.gating2id["ptr"]))

    target_ids = self.pad_ids(target_ids, self.trg_tokenizer.pad_token_id)
    return OpenVocabDSTFeature(
        example.guid, input_id, segment_id, gating_id, target_ids
    )
def _convert_example_to_feature(self, example):
    """Convert one dialogue (a list of per-turn examples) into an
    OntologyDSTFeature.

    Each turn must contain exactly two utterances; they are encoded as a
    BERT-style sentence pair ``[CLS] u0 [SEP] u1 [SEP]``, truncated via
    ``_truncate_seq_pair`` and right-padded to ``max_seq_length``. Per-slot
    labels are indices into the slot's ontology value list (unknown values
    fall back to the index of "none"). The dialogue itself is padded to
    ``max_turn_length`` with dummy turns whose labels are all -1.
    """
    guid = example[0].guid.rsplit("-", 1)[0]  # dialogue_idx
    pad_id = self.src_tokenizer.pad_token_id
    turns = []
    token_types = []
    labels = []

    for turn in example[: self.max_turn_length]:
        assert len(turn.current_turn) == 2
        uttrs = [
            self.src_tokenizer.encode(uttr, add_special_tokens=False)
            for uttr in turn.current_turn
        ]
        # Truncate the pair in place to leave room for [CLS] + 2x[SEP].
        _truncate_seq_pair(uttrs[0], uttrs[1], self.max_seq_length - 3)
        tokens = (
            [self.src_tokenizer.cls_token_id]
            + uttrs[0]
            + [self.src_tokenizer.sep_token_id]
            + uttrs[1]
            + [self.src_tokenizer.sep_token_id]
        )
        token_type = [0] * (len(uttrs[0]) + 2) + [1] * (len(uttrs[1]) + 1)
        gap = self.max_seq_length - len(tokens)
        if gap > 0:
            tokens.extend([pad_id] * gap)
            token_type.extend([0] * gap)
        turns.append(tokens)
        token_types.append(token_type)

        slot_dict = convert_state_dict(turn.label) if turn.label else {}
        label = []
        for slot_type in self.slot_meta:
            value = slot_dict.get(slot_type, "none")
            # EAFP: single scan instead of `in` check followed by .index().
            try:
                label_idx = self.ontology[slot_type].index(value)
            except ValueError:
                label_idx = self.ontology[slot_type].index("none")
            label.append(label_idx)
        labels.append(label)

    num_turn = len(turns)
    # Pad the dialogue to max_turn_length. Build a DISTINCT dummy list for
    # each append — the original reused one list object for both `turns`
    # and `token_types`, so mutating one would silently corrupt the other.
    for _ in range(self.max_turn_length - num_turn):
        turns.append([pad_id] * self.max_seq_length)
        token_types.append([pad_id] * self.max_seq_length)
        labels.append([-1] * len(self.slot_meta))

    return OntologyDSTFeature(
        guid=guid,
        input_ids=turns,
        segment_ids=token_types,
        num_turn=num_turn,
        target_ids=labels,
    )
def _convert_example_to_feature(
    self, example: DSTInputExample
) -> OpenVocabDSTFeature:
    """Convert a single DSTInputExample into an OpenVocabDSTFeature.

    The dialogue history plus the current turn is flattened into one string.
    The join separator depends on the source tokenizer: when its sep token
    is "</s>" (XLM-RoBERTa family) the turns are joined with " <s> ",
    otherwise with " [SEP] ".
    NOTE(review): joining with " <s> " (rather than "</s>") for the
    "</s>"-sep tokenizers is preserved from the original — confirm intended.

    The encoded context is truncated from the front to fit
    ``max_seq_length`` (most recent tokens kept), then wrapped in CLS/SEP.
    For each slot in ``slot_meta`` a sep-terminated target id sequence and a
    gate id (defaulting to the "ptr" gate) are produced.

    Args:
        example (DSTInputExample): one dialogue example.

    Returns:
        OpenVocabDSTFeature: the encoded feature.
    """
    # Choose the separator by tokenizer family (XLM-RoBERTa uses "</s>").
    if self.src_tokenizer.special_tokens_map["sep_token"] == "</s>":
        joiner = " <s> "
    else:
        joiner = " [SEP] "
    dialogue_context = joiner.join(example.context_turns + example.current_turn)

    ids = self.src_tokenizer.encode(dialogue_context, add_special_tokens=False)
    # Keep only the most recent tokens; reserve two slots for CLS/SEP.
    overflow = len(ids) - (self.max_seq_length - 2)
    if overflow > 0:
        ids = ids[overflow:]
    input_id = (
        [self.src_tokenizer.cls_token_id] + ids + [self.src_tokenizer.sep_token_id]
    )
    segment_id = [0] * len(input_id)

    # Treat a missing label as an empty dialogue state.
    if not example.label:
        example.label = []
    state = convert_state_dict(example.label)

    gating_id = []
    target_ids = []
    for slot in self.slot_meta:
        value = state.get(slot, "none")
        target_ids.append(
            self.trg_tokenizer.encode(value, add_special_tokens=False)
            + [self.trg_tokenizer.sep_token_id]
        )
        gating_id.append(self.gating2id.get(value, self.gating2id["ptr"]))

    target_ids = self.pad_ids(target_ids, self.trg_tokenizer.pad_token_id)
    return OpenVocabDSTFeature(
        example.guid, input_id, segment_id, gating_id, target_ids
    )