    def _deal_qa_line(self, index, line, oov_token2id):
        elems = line.strip().split('\t')

        text = elems[1].strip()
        seq, seq_exp = self._line2ids(text, self.src_len, oov_token2id)
        src = helper_fn.pad_with_start_end(seq, self.max_src_len,
                                           self.start_id, self.end_id,
                                           self.pad_id)
        src_exp = helper_fn.pad_with_start_end(seq_exp, self.max_src_len,
                                               self.start_id, self.end_id,
                                               self.pad_id)

        # Used for multi-task training. If there is no fact, use src as the answer.
        seq, seq_exp = self._line2ids(text, self.tar_len, oov_token2id)
        src_tar = helper_fn.pad_with_start(seq, self.max_tar_len,
                                           self.start_id, self.pad_id)
        src_tar_loss = helper_fn.pad_with_end(seq, self.max_tar_len,
                                              self.end_id, self.pad_id)

        text = elems[2].strip()
        seq, seq_exp = self._line2ids(text, self.tar_len, oov_token2id)
        tar = helper_fn.pad_with_start(seq, self.max_tar_len, self.start_id,
                                       self.pad_id)
        tar_exp = helper_fn.pad_with_start(seq_exp, self.max_tar_len,
                                           self.start_id, self.pad_id)

        tar_loss = helper_fn.pad_with_end(seq, self.max_tar_len, self.end_id,
                                          self.pad_id)
        tar_loss_exp = helper_fn.pad_with_end(seq_exp, self.max_tar_len,
                                              self.end_id, self.pad_id)

        return src, tar, tar_loss, src_exp, tar_exp, tar_loss_exp, src_tar, src_tar_loss
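
    # The helper_fn padding utilities are not shown here; the sketch below is a
    # hypothetical reconstruction inferred from the call sites above, not the
    # actual implementation: wrap the sequence with START/END ids, then
    # right-pad with pad_id up to max_len.
    #
    # def pad_with_start_end(seq, max_len, start_id, end_id, pad_id):
    #     seq = [start_id] + seq + [end_id]
    #     return seq + [pad_id] * (max_len - len(seq))
    #
    # def pad_with_start(seq, max_len, start_id, pad_id):
    #     seq = [start_id] + seq
    #     return seq + [pad_id] * (max_len - len(seq))
    #
    # def pad_with_end(seq, max_len, end_id, pad_id):
    #     seq = seq + [end_id]
    #     return seq + [pad_id] * (max_len - len(seq))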
    def _deal_fact_line(self, index, line, oov_token2id):
        line = line.strip()
        cur_fact_ids = []
        cur_fact_ids_exp = []
        fact_tar = None
        fact_tar_loss = None
        elems = line.split('\t')
        no_fact_tag = False
        # if there is no fact, add pad sequence
        if elems[1] == config.NO_FACT:
            cur_fact_ids.append(self.pad_fact_seqs)
            cur_fact_ids_exp.append(self.pad_fact_seqs)
            no_fact_tag = True
        else:
            for fact_index, text in enumerate(elems[1:]):
                seq, seq_exp = self._line2ids(text, self.max_fact_len, oov_token2id)
                new_seq = helper_fn.pad_with_pad(seq, self.max_fact_len, self.pad_id)
                cur_fact_ids.append(new_seq)
                new_seq_exp = helper_fn.pad_with_pad(seq_exp, self.max_fact_len, self.pad_id)
                cur_fact_ids_exp.append(new_seq_exp)
                # the first fact doubles as the multi-task target
                if fact_index == 0:
                    seq, seq_exp = self._line2ids(text, self.tar_len, oov_token2id)
                    fact_tar = helper_fn.pad_with_start(seq, self.max_tar_len, self.start_id, self.pad_id)
                    fact_tar_loss = helper_fn.pad_with_end(seq, self.max_tar_len, self.end_id, self.pad_id)

        # truncate/pad the fact list to args.fact_number entries
        cur_fact_ids = cur_fact_ids[:self.args.fact_number]
        cur_fact_ids_exp = cur_fact_ids_exp[:self.args.fact_number]

        cur_fact_ids = cur_fact_ids + [self.pad_fact_seqs] * (self.args.fact_number - len(cur_fact_ids))
        cur_fact_ids_exp = cur_fact_ids_exp + [self.pad_fact_seqs] * (self.args.fact_number - len(cur_fact_ids_exp))

        return no_fact_tag, cur_fact_ids, cur_fact_ids_exp, fact_tar, fact_tar_loss
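
    # _line2ids is also not shown; judging from the call sites, it appears to
    # return two id sequences for a line of text: `seq` maps out-of-vocabulary
    # tokens to unk_id, while `seq_exp` maps them to extended ids from
    # oov_token2id (the usual copy/pointer-mechanism bookkeeping). A
    # hypothetical sketch under that assumption:
    #
    # def _line2ids(self, text, max_len, oov_token2id):
    #     seq, seq_exp = [], []
    #     for token in text.split()[:max_len]:
    #         token_id = self.src_token_ids.get(token, self.unk_id)
    #         seq.append(token_id)
    #         if token_id == self.unk_id and token in oov_token2id:
    #             seq_exp.append(oov_token2id[token])
    #         else:
    #             seq_exp.append(token_id)
    #     return seq, seq_exp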
        def _deal_qa(f):
            source_ids = []
            target_ids = []
            # Used to calculate the loss: only the END token, no START token.
            target_loss_ids = []
            for index, line in enumerate(f):
                elems = line.strip().split('\t')
                text = elems[1].strip()
                seq = [
                    self.src_token_ids.get(token, self.unk_id)
                    for token in text.split()
                ]
                #seq = [self.src_token_ids.get(token, self.pad_id) for token in text.split()]
                seq = seq[:src_len]
                new_seq = helper_fn.pad_with_start_end(seq, max_src_len,
                                                       self.start_id,
                                                       self.end_id,
                                                       self.pad_id)
                source_ids.append(new_seq)

                text = elems[2].strip()
                seq = [
                    self.tar_token_ids.get(token, self.unk_id)
                    for token in text.split()
                ]
                #seq = [self.tar_token_ids.get(token, self.pad_id) for token in text.split()]
                seq = seq[:tar_len]
                new_seq = helper_fn.pad_with_start(seq, max_tar_len,
                                                   self.start_id, self.pad_id)
                target_ids.append(new_seq)
                new_seq = helper_fn.pad_with_end(seq, max_tar_len, self.end_id,
                                                 self.pad_id)
                target_loss_ids.append(new_seq)

                if (index + 1) % self.args.batch_size == 0:
                    res1 = np.asarray(source_ids)
                    res2 = np.asarray(target_ids)
                    res3 = np.asarray(target_loss_ids)
                    res3 = np.reshape(res3, (res3.shape[0], res3.shape[1], 1))
                    source_ids = []
                    target_ids = []
                    target_loss_ids = []
                    yield res1, res2, res3
            if len(source_ids) != 0:
                res1 = np.asarray(source_ids)
                res2 = np.asarray(target_ids)
                res3 = np.asarray(target_loss_ids)
                res3 = np.reshape(res3, (res3.shape[0], res3.shape[1], 1))
                source_ids = []
                target_ids = []
                target_loss_ids = []
                yield res1, res2, res3
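
        # Hypothetical usage sketch (names below are illustrative, not from
        # the source): the generator yields numpy batches of
        # (source, target, target_loss), e.g. for a Keras-style loop:
        #
        # with open(qa_path) as f:
        #     for src_batch, tar_batch, tar_loss_batch in _deal_qa(f):
        #         model.train_on_batch([src_batch, tar_batch], tar_loss_batch)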
    def read_file(self,
                  file_type,
                  max_src_len,
                  max_tar_len,
                  max_fact_len=30,
                  max_conv_len=30,
                  get_fact=False,
                  get_conv=False,
                  get_one_hot=False):
        '''
        :param file_type: one of 'train', 'valid', or 'test'
        :param max_src_len: maximum source (question) length
        :param max_tar_len: maximum target (answer) length
        :param max_fact_len: maximum fact (external knowledge) length; should match the source length
        :param max_conv_len: maximum conversation (context) length
        :param get_fact: boolean indicating whether to read the fact file
        :param get_conv: boolean indicating whether to read the conv file
        :param get_one_hot: boolean indicating whether to convert target_loss_ids into one-hot labels
        '''

        assert (max_src_len > 0)
        assert (max_tar_len > 0)
        assert (max_fact_len > 0)
        assert (max_conv_len > 0)
        assert file_type in ('train', 'valid', 'test')
        print('current file type: %s' % file_type)

        src_len = max_src_len - config.src_reserved_pos
        tar_len = max_tar_len - config.tar_reserved_pos

        if file_type == 'train':
            qa_path = self.train_set_path
            conv_path = self.train_conv_path
            fact_path = self.train_sent_fact_path
        elif file_type == 'valid':
            qa_path = self.valid_set_path
            conv_path = self.valid_conv_path
            fact_path = self.valid_sent_fact_path
        elif file_type == 'test':
            qa_path = self.test_set_path
            conv_path = self.test_conv_path
            fact_path = self.test_sent_fact_path

        # read source and target
        print(qa_path)
        f = open(qa_path)
        indexes = []
        source_ids = []
        target_ids = []
        # Used to calculate the loss: only the END token, no START token.
        target_loss_ids = []
        for line in f:
            elems = line.strip().split('\t')
            if len(elems) < 3:
                raise ValueError(
                    'Expected each line to have 3 tab-separated fields, but received %d' %
                    len(elems))

            indexes.append(int(elems[0].strip()))
            text = elems[1].strip()
            seq = [
                self.src_token_ids.get(token, self.unk_id)
                for token in text.split()
            ]
            seq = seq[:src_len]
            new_seq = helper_fn.pad_with_start_end(seq, max_src_len,
                                                   self.start_id, self.end_id,
                                                   self.pad_id)
            source_ids.append(new_seq)

            text = elems[2].strip()
            seq = [
                self.tar_token_ids.get(token, self.unk_id)
                for token in text.split()
            ]
            seq = seq[:tar_len]
            new_seq = helper_fn.pad_with_start(seq, max_tar_len, self.start_id,
                                               self.pad_id)
            target_ids.append(new_seq)
            new_seq = helper_fn.pad_with_end(seq, max_tar_len, self.end_id,
                                             self.pad_id)
            target_loss_ids.append(new_seq)
        f.close()
        if get_one_hot:
            target_one_hot = np.zeros(
                (len(target_ids), len(target_ids[0]), self.vocab_size),
                dtype='int32')
            for i, target in enumerate(target_ids):
                for t, term_idx in enumerate(target):
                    if t > 0:
                        target_one_hot[i, t - 1, term_idx] = 1
            target_loss_ids = target_one_hot
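            # The t - 1 shift aligns labels with decoder outputs: target_ids
            # is START-prefixed, so the token at decoder input step t is the
            # label for output step t - 1 (teacher forcing). With hypothetical
            # ids, target_ids [START, 7, 9, PAD] yields one-hot rows for
            # [7, 9, PAD] at output steps 0..2.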

        pad_seqs = helper_fn.pad_with_pad([self.pad_id], max_fact_len,
                                          self.pad_id)
        facts_ids = []
        if get_fact:
            print(fact_path)
            with open(fact_path) as f:
                for index, line in enumerate(f):
                    line = line.strip()
                    fact_ids = []
                    elems = line.split('\t')
                    # if there is no fact, add pad sequence
                    if elems[1] == config.NO_FACT:
                        fact_ids.append(pad_seqs)
                    else:
                        for text in elems[1:]:
                            seq = [
                                self.src_token_ids.get(token, self.unk_id)
                                for token in text.split()
                            ]
                            seq = seq[:max_fact_len]
                            new_seq = helper_fn.pad_with_pad(
                                seq, max_fact_len, self.pad_id)
                            fact_ids.append(new_seq)
                    facts_ids.append(fact_ids)
            # Keep the number of facts fixed; if there are not enough facts, pad with pad_id sequences.
            facts_ids_tmp = []
            for facts in facts_ids:
                facts = facts[:self.args.fact_number]
                facts = facts + [pad_seqs
                                 ] * (self.args.fact_number - len(facts))
                facts_ids_tmp.append(facts)
            facts_ids = facts_ids_tmp
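            # facts_ids now has the fixed shape
            # [num_examples][args.fact_number][max_fact_len]; e.g. with
            # fact_number=2 and a single real fact, the second slot is pad_seqs.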

        #pad_convs = [self.pad_id] * max_conv_len
        pad_seqs = helper_fn.pad_with_pad([self.pad_id], max_conv_len,
                                          self.pad_id)
        convs_ids = []
        if get_conv:
            print(conv_path)
            with open(conv_path) as f:
                for index, line in enumerate(f):
                    line = line.strip()
                    conv_ids = []
                    elems = line.split('\t')
                    # if there is no context, add pad sequence
                    if elems[1] == config.NO_CONTEXT:
                        conv_ids.append(pad_seqs)
                    else:
                        for text in elems[1:]:
                            seq = [
                                self.src_token_ids.get(token, self.unk_id)
                                for token in text.split()
                            ]
                            seq = seq[:max_conv_len]
                            new_seq = helper_fn.pad_with_pad(
                                seq, max_conv_len, self.pad_id)
                            conv_ids.append(new_seq)
                    convs_ids.append(conv_ids)
            # Keep the number of conversation turns fixed; if there are not enough, pad with pad_id sequences.
            convs_ids_tmp = []
            for convs in convs_ids:
                convs = convs[:self.args.conv_number]
                convs = convs + [pad_seqs
                                 ] * (self.args.conv_number - len(convs))
                convs_ids_tmp.append(convs)
            convs_ids = convs_ids_tmp

        assert (len(source_ids) == len(indexes))
        assert (len(source_ids) == len(target_ids))
        if get_fact:
            assert (len(source_ids) == len(facts_ids))
        if get_conv:
            assert (len(source_ids) == len(convs_ids))

        ## [[[ Disabled so that Zeyang gets ordered indexes in the output (no shuffling).
        #if get_fact == True and get_conv == True:
        #    indexes, source_ids, target_ids, convs_ids, facts_ids = shuffle(indexes, source_ids, target_ids, convs_ids, facts_ids)
        #elif get_fact == True:
        #    indexes, source_ids, target_ids, facts_ids = shuffle(indexes, source_ids, target_ids, facts_ids)
        #else:
        #    indexes, source_ids, target_ids = shuffle(indexes, source_ids, target_ids)
        ## ]]]

        return indexes, source_ids, target_ids, target_loss_ids, convs_ids, facts_ids
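
    # Hypothetical call sketch (the reader instance and argument values are
    # illustrative, not from the source):
    #
    # indexes, src, tar, tar_loss, convs, facts = reader.read_file(
    #     'train', max_src_len=30, max_tar_len=30,
    #     max_fact_len=30, max_conv_len=30,
    #     get_fact=True, get_conv=True, get_one_hot=False)
    #
    # len(src) == len(tar) == len(indexes); each row of src is
    # START + token ids + END, right-padded with pad_id to max_src_len.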