def _deal_qa_line(self, index, line, oov_token2id):
    """Convert one TSV QA line into padded id sequences.

    The line layout is ``index \t question \t answer``. Both the question
    and the answer are converted twice via ``self._line2ids`` (plain ids and
    the OOV-expanded variant) and padded with the configured special ids.

    :param index: sample index (unused here, kept for a uniform signature)
    :param line: raw TSV line
    :param oov_token2id: per-sample OOV token -> id mapping
    :return: (src, tar, tar_loss, src_exp, tar_exp, tar_loss_exp,
              src_tar, src_tar_loss)
    """
    fields = line.strip().split('\t')
    question = fields[1].strip()
    answer = fields[2].strip()

    # Padding shorthands — same helper_fn calls as before, hoisted for reuse.
    with_start_end = lambda ids: helper_fn.pad_with_start_end(
        ids, self.max_src_len, self.start_id, self.end_id, self.pad_id)
    with_start = lambda ids: helper_fn.pad_with_start(
        ids, self.max_tar_len, self.start_id, self.pad_id)
    with_end = lambda ids: helper_fn.pad_with_end(
        ids, self.max_tar_len, self.end_id, self.pad_id)

    # Source side: question truncated to src_len, padded with START/END.
    q_ids, q_ids_exp = self._line2ids(question, self.src_len, oov_token2id)
    src = with_start_end(q_ids)
    src_exp = with_start_end(q_ids_exp)

    # Used for multi-task: if there is no fact, the question itself
    # (re-truncated to target length) serves as the answer.
    q_tar_ids, _ = self._line2ids(question, self.tar_len, oov_token2id)
    src_tar = with_start(q_tar_ids)
    src_tar_loss = with_end(q_tar_ids)

    # Target side: answer, both decoder input (START) and loss (END) forms.
    a_ids, a_ids_exp = self._line2ids(answer, self.tar_len, oov_token2id)
    tar = with_start(a_ids)
    tar_exp = with_start(a_ids_exp)
    tar_loss = with_end(a_ids)
    tar_loss_exp = with_end(a_ids_exp)

    return src, tar, tar_loss, src_exp, tar_exp, tar_loss_exp, src_tar, src_tar_loss
def _deal_fact_line(self, index, line, oov_token2id):
    """Convert one TSV fact line into padded fact id sequences.

    The line layout is ``index \t fact_1 \t fact_2 ...`` where the fact
    column may be the sentinel ``config.NO_FACT``. Each fact is padded to
    ``self.max_fact_len``; the group is truncated/padded to exactly
    ``self.args.fact_number`` sequences. The first fact additionally serves
    as an auxiliary decoding target (multi-task).

    :param index: sample index (unused here, kept for a uniform signature)
    :param line: raw TSV line
    :param oov_token2id: per-sample OOV token -> id mapping
    :return: (no_fact_tag, cur_fact_ids, cur_fact_ids_exp,
              fact_tar, fact_tar_loss); the last two are None when there
              is no fact.
    """
    line = line.strip()
    elems = line.split('\t')
    cur_fact_ids = []
    cur_fact_ids_exp = []
    fact_tar = None
    fact_tar_loss = None
    no_fact_tag = False
    if elems[1] == config.NO_FACT:
        # No fact available: use a single all-pad sequence.
        cur_fact_ids.append(self.pad_fact_seqs)
        cur_fact_ids_exp.append(self.pad_fact_seqs)
        no_fact_tag = True
    else:
        # NOTE: loop variable renamed from `index` — it shadowed the parameter.
        for fact_idx, text in enumerate(elems[1:]):
            seq, seq_exp = self._line2ids(text, self.max_fact_len, oov_token2id)
            cur_fact_ids.append(
                helper_fn.pad_with_pad(seq, self.max_fact_len, self.pad_id))
            cur_fact_ids_exp.append(
                helper_fn.pad_with_pad(seq_exp, self.max_fact_len, self.pad_id))
            if fact_idx == 0:
                # The first fact doubles as an auxiliary target sequence.
                seq, seq_exp = self._line2ids(text, self.tar_len, oov_token2id)
                fact_tar = helper_fn.pad_with_start(
                    seq, self.max_tar_len, self.start_id, self.pad_id)
                # FIX: pad_with_end takes the END id here; the original
                # passed self.start_id, unlike every other pad_with_end
                # call in this file (cf. _deal_qa_line's tar_loss).
                fact_tar_loss = helper_fn.pad_with_end(
                    seq, self.max_tar_len, self.end_id, self.pad_id)
    # Truncate/pad the group to exactly args.fact_number sequences.
    cur_fact_ids = cur_fact_ids[:self.args.fact_number]
    cur_fact_ids_exp = cur_fact_ids_exp[:self.args.fact_number]
    cur_fact_ids = cur_fact_ids + \
        [self.pad_fact_seqs] * (self.args.fact_number - len(cur_fact_ids))
    cur_fact_ids_exp = cur_fact_ids_exp + \
        [self.pad_fact_seqs] * (self.args.fact_number - len(cur_fact_ids_exp))
    return no_fact_tag, cur_fact_ids, cur_fact_ids_exp, fact_tar, fact_tar_loss
def _deal_qa(f):
    """Generator over an open QA file yielding numpy batches.

    Each TSV line (``index \t question \t answer``) is converted to padded
    id sequences; every ``self.args.batch_size`` lines a batch
    ``(source, target, target_loss)`` is yielded, plus one final partial
    batch if lines remain. ``target_loss`` carries the END sign only (no
    START) and is reshaped to (batch, time, 1) for per-timestep losses.

    Relies on enclosing-scope names: self, src_len, tar_len,
    max_src_len, max_tar_len.
    """
    source_ids = []
    target_ids = []
    target_loss_ids = []  # loss targets: END sign only, no START sign

    def _flush():
        # Package the accumulated lists into numpy arrays (was duplicated
        # verbatim for the full-batch and remainder cases).
        src = np.asarray(source_ids)
        tar = np.asarray(target_ids)
        tar_loss = np.asarray(target_loss_ids)
        tar_loss = np.reshape(tar_loss, (tar_loss.shape[0], tar_loss.shape[1], 1))
        return src, tar, tar_loss

    for index, line in enumerate(f):
        elems = line.strip().split('\t')
        # Question -> source ids, truncated then START/END padded.
        text = elems[1].strip()
        seq = [self.src_token_ids.get(token, self.unk_id)
               for token in text.split()]
        seq = seq[:src_len]
        source_ids.append(helper_fn.pad_with_start_end(
            seq, max_src_len, self.start_id, self.end_id, self.pad_id))
        # Answer -> decoder input (START) and loss target (END).
        text = elems[2].strip()
        seq = [self.tar_token_ids.get(token, self.unk_id)
               for token in text.split()]
        seq = seq[:tar_len]
        target_ids.append(helper_fn.pad_with_start(
            seq, max_tar_len, self.start_id, self.pad_id))
        target_loss_ids.append(helper_fn.pad_with_end(
            seq, max_tar_len, self.end_id, self.pad_id))
        if (index + 1) % self.args.batch_size == 0:
            batch = _flush()
            source_ids, target_ids, target_loss_ids = [], [], []
            yield batch
    # Remainder: yield the final, possibly smaller, batch.
    if source_ids:
        yield _flush()
def read_file(self, file_type, max_src_len, max_tar_len, max_fact_len=30,
              max_conv_len=30, get_fact=False, get_conv=False,
              get_one_hot=False):
    '''Read one dataset split and convert everything to padded id lists.

    :param file_type: This is supposed to be: train, valid, or test
    :param max_src_len: maximum source (question) length
    :param max_tar_len: maximum target (answer) length
    :param max_fact_len: maximum fact (external knowledge) length, should
        be the same as source
    :param max_conv_len: maximum conversation (context) length
    :param get_fact: whether to read the fact file
    :param get_conv: whether to read the conversation file
    :param get_one_hot: whether to expand target_loss_ids into one-hot
        vectors of size self.vocab_size
    :return: (indexes, source_ids, target_ids, target_loss_ids,
              convs_ids, facts_ids); convs_ids/facts_ids are empty lists
              unless get_conv/get_fact is set.
    '''
    assert (max_src_len > 0)
    assert (max_tar_len > 0)
    assert (max_fact_len > 0)
    assert (max_conv_len > 0)
    assert file_type == 'train' or file_type == 'valid' or file_type == 'test'
    print('current file type: %s' % file_type)
    # Reserve room for the START/END tokens added while padding.
    src_len = max_src_len - config.src_reserved_pos
    tar_len = max_tar_len - config.tar_reserved_pos
    if file_type == 'train':
        qa_path = self.train_set_path
        conv_path = self.train_conv_path
        fact_path = self.train_sent_fact_path
    elif file_type == 'valid':
        qa_path = self.valid_set_path
        conv_path = self.valid_conv_path
        fact_path = self.valid_sent_fact_path
    elif file_type == 'test':
        qa_path = self.test_set_path
        conv_path = self.test_conv_path
        fact_path = self.test_sent_fact_path

    # --- read source and target -------------------------------------
    print(qa_path)
    indexes = []
    source_ids = []
    target_ids = []
    target_loss_ids = []  # used for loss: only END sign, no START sign
    # FIX: use a context manager — the original bare open() leaked the
    # handle when the ValueError below was raised mid-loop.
    with open(qa_path) as f:
        for line in f:
            elems = line.strip().split('\t')
            if len(elems) < 3:
                # FIX: corrected typo in the error message ("Exceptd").
                raise ValueError(
                    'Expected input to be 3 dimension, but received %d'
                    % len(elems))
            indexes.append(int(elems[0].strip()))
            text = elems[1].strip()
            seq = [self.src_token_ids.get(token, self.unk_id)
                   for token in text.split()]
            seq = seq[:src_len]
            source_ids.append(helper_fn.pad_with_start_end(
                seq, max_src_len, self.start_id, self.end_id, self.pad_id))
            text = elems[2].strip()
            seq = [self.tar_token_ids.get(token, self.unk_id)
                   for token in text.split()]
            seq = seq[:tar_len]
            target_ids.append(helper_fn.pad_with_start(
                seq, max_tar_len, self.start_id, self.pad_id))
            target_loss_ids.append(helper_fn.pad_with_end(
                seq, max_tar_len, self.end_id, self.pad_id))

    # Guarded against an empty split (the original indexed target_ids[0]
    # unconditionally); also dropped a stray dead statement (`intaa = 0`).
    if get_one_hot == True and target_ids:
        target_one_hot = np.zeros(
            (len(target_ids), len(target_ids[0]), self.vocab_size),
            dtype='int32')
        for i, target in enumerate(target_ids):
            for t, term_idx in enumerate(target):
                if t > 0:
                    # Shift left by one step: position t-1 predicts token t.
                    target_one_hot[i, t - 1, term_idx] = 1
        target_loss_ids = target_one_hot

    def _read_group_file(path, no_entry_tag, seq_len, group_size):
        # Shared reader for the fact and conv files (the two loops were
        # verbatim duplicates). Each TSV line holds a variable number of
        # token sequences; each sequence is padded to seq_len and each
        # group truncated/padded to exactly group_size sequences. If a
        # line carries the no-entry sentinel, a single all-pad sequence
        # is used instead.
        pad_seqs = helper_fn.pad_with_pad([self.pad_id], seq_len, self.pad_id)
        groups = []
        print(path)
        with open(path) as fh:
            for raw in fh:
                elems = raw.strip().split('\t')
                if elems[1] == no_entry_tag:
                    ids = [pad_seqs]
                else:
                    ids = []
                    for text in elems[1:]:
                        seq = [self.src_token_ids.get(token, self.unk_id)
                               for token in text.split()]
                        seq = seq[:seq_len]
                        ids.append(helper_fn.pad_with_pad(
                            seq, seq_len, self.pad_id))
                # Keep groups the same size; pad with all-pad sequences.
                ids = ids[:group_size]
                ids = ids + [pad_seqs] * (group_size - len(ids))
                groups.append(ids)
        return groups

    facts_ids = []
    if get_fact == True:
        facts_ids = _read_group_file(
            fact_path, config.NO_FACT, max_fact_len, self.args.fact_number)

    convs_ids = []
    if get_conv == True:
        convs_ids = _read_group_file(
            conv_path, config.NO_CONTEXT, max_conv_len, self.args.conv_number)

    assert (len(source_ids) == len(indexes))
    assert (len(source_ids) == len(target_ids))
    if get_fact == True:
        assert (len(source_ids) == len(facts_ids))
    if get_conv == True:
        assert (len(source_ids) == len(convs_ids))
    # Shuffling intentionally disabled: output keeps the input file order.
    return indexes, source_ids, target_ids, target_loss_ids, convs_ids, facts_ids