def get_input_ids(self, examples: List[dict]) -> Dict[str, list]:
    """Tokenize each example's utterances and group all fields by kind.

    Each example's system utterance, user utterance and source text are
    converted to ids with ``str2id`` using ``self.tokenizer`` (presumably a
    BERT tokenizer -- confirm against the class constructor).

    Args:
        examples: a list of example dicts. Each one must provide the keys
            ``sys_utter``, ``usr_utter``, ``source``, ``dial_id``,
            ``turn_id`` and ``label``.

    Returns:
        A ``defaultdict(list)`` mapping a field name to a list with one entry
        per example, in input order. The keys are ``dial_ids``, ``turn_ids``,
        ``input_ids``, ``token_type_ids`` and ``labels``. For example, the
        dialogue ids of all examples are in ``examples_dict['dial_ids']``.
        An empty ``examples`` list yields an empty mapping.
    """
    examples_dict = defaultdict(list)
    for example in examples:
        input_ids, token_type_ids = str2id(
            self.tokenizer,
            example["sys_utter"],
            example["usr_utter"],
            example["source"],
        )
        examples_dict["dial_ids"].append(example["dial_id"])
        examples_dict["turn_ids"].append(example["turn_id"])
        examples_dict["input_ids"].append(input_ids)
        examples_dict["token_type_ids"].append(token_type_ids)
        examples_dict["labels"].append(example["label"])
    return examples_dict
def preprocess(
    self, belief_state: Dict[str, dict], cur_domain: str, history: List[tuple]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build model inputs for a single dialogue turn.

    The latest user utterance, the preceding system utterance and a source
    string derived from the belief state are encoded with ``str2id``. The
    result is then padded as a one-element batch.

    Args:
        belief_state: see `xbot/util/state.py`.
        cur_domain: current domain; passed to ``self.get_source`` together
            with ``belief_state``.
        history: dialogue history as (role, utterance) tuples, e.g.
            [('usr', 'xxx'), ('sys', 'xxx'), ...]. Only the utterance (index 1)
            of the last and second-to-last entries is read.

    Returns:
        (input_ids, token_type_ids, attention_mask) as produced by ``pad``.
    """
    # Placeholder text used when the history does not supply an utterance.
    placeholder = "对话开始"
    n_turns = len(history)
    # Assumes the last history entry is the user's turn -- roles are not checked.
    usr_utter = history[-1][1] if n_turns > 0 else placeholder
    # NOTE(review): the system utterance is only taken when there are MORE
    # than two turns, so a two-entry history keeps the placeholder; confirm
    # this is intended.
    sys_utter = history[-2][1] if n_turns > 2 else placeholder

    source = self.get_source(belief_state, cur_domain)
    ids, segment_ids = str2id(self.tokenizer, sys_utter, usr_utter, source)

    # Wrap the single example in a list so ``pad`` treats it as a batch.
    attention_mask, input_ids, token_type_ids = pad([ids], [segment_ids])
    return input_ids, token_type_ids, attention_mask
def get_input_ids(self, examples: List[dict]) -> Dict[str, list]:
    """Encode every example with ``str2id`` and collect fields column-wise.

    Args:
        examples: a list of dicts, each carrying the keys 'sys_utter',
            'usr_utter', 'source', 'dial_id', 'turn_id' and 'label'.

    Returns:
        A ``defaultdict(list)`` such as {'dial_ids': [1, 2, 3, 4, ...], ...}.
        Its keys are 'dial_ids', 'turn_ids', 'input_ids', 'token_type_ids'
        and 'labels', and each list holds one entry per example in input
        order.
    """
    examples_dict = defaultdict(list)
    for ex in examples:
        ids, segment_ids = str2id(
            self.tokenizer, ex['sys_utter'], ex['usr_utter'], ex['source']
        )
        # One row of output values; dict order fixes the key insertion order.
        row = {
            'dial_ids': ex['dial_id'],
            'turn_ids': ex['turn_id'],
            'input_ids': ids,
            'token_type_ids': segment_ids,
            'labels': ex['label'],
        }
        for field, value in row.items():
            examples_dict[field].append(value)
    return examples_dict