Example #1
class SrcTgtCorpus(Corpus):
    def __init__(self,
                 data_dir,
                 data_prefix,
                 min_freq=0,
                 max_vocab_size=None,
                 min_len=0,
                 max_len=100,
                 embed_file=None,
                 share_vocab=False):
        super(SrcTgtCorpus, self).__init__(data_dir=data_dir,
                                           data_prefix=data_prefix,
                                           min_freq=min_freq,
                                           max_vocab_size=max_vocab_size)
        self.min_len = min_len
        self.max_len = max_len
        self.share_vocab = share_vocab

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=embed_file)
        if self.share_vocab:
            self.TGT = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize, embed_file=embed_file)

        self.fields = {'src': self.SRC, 'tgt': self.TGT}

        def src_filter_pred(src):
            return min_len <= len(self.SRC.tokenize_fn(src)) <= max_len

        def tgt_filter_pred(tgt):
            return min_len <= len(self.TGT.tokenize_fn(tgt)) <= max_len

        self.filter_pred = lambda ex: (src_filter_pred(ex['src'])
                                       and tgt_filter_pred(ex['tgt']))

    def read_data(self, data_file, data_type="train"):

        data = []
        filtered = 0
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                src, tgt = line.strip().split('\t')[:2]
                data.append({'src': src, 'tgt': tgt})

        filtered_num = len(data)
        if self.filter_pred is not None:
            data = [ex for ex in data if self.filter_pred(ex)]
        filtered_num -= len(data)
        print("Read {} {} examples ({} filtered)".format(
            len(data), data_type.upper(), filtered_num))
        return data
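
A minimal sketch of the tab-separated line format that read_data above consumes (the sample text is hypothetical, not from the project's data):

# Each data line holds at least a source and a target utterance,
# separated by tabs, exactly as parsed in read_data above.
line = "how are you ?\ti am fine .\n"
src, tgt = line.strip().split('\t')[:2]
assert (src, tgt) == ("how are you ?", "i am fine .")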
Example #2
class SrcTgtCorpus(Corpus):
    """
    SrcTgtCorpus
    """
    def __init__(self, config):
        super(SrcTgtCorpus, self).__init__(config)
        self.min_len = config.min_len
        self.max_len = config.max_len
        self.embed_file = config.embed_file
        self.share_vocab = config.share_vocab

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=self.embed_file)
        if self.share_vocab:
            self.TGT = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize,
                                 embed_file=self.embed_file)
        self.OUTPUT = self.TGT
        self.fields = {'src': self.SRC, 'tgt': self.TGT, 'output': self.OUTPUT}

        def src_filter_pred(src):
            return self.min_len <= len(
                self.SRC.tokenize_fn(src)) <= self.max_len

        def tgt_filter_pred(tgt):
            return self.min_len <= len(
                self.TGT.tokenize_fn(tgt)) <= self.max_len

        # self.filter_pred = lambda ex: src_filter_pred(ex['src']) and tgt_filter_pred(ex['tgt'])
        self.filter_pred = None

    def read_data(self, data_file, data_type="train"):
        data = []

        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                ex = json.loads(line.strip())
                src, tgt = ex['src'], ex['tgt']
                data.append({'src': src, 'tgt': tgt})

        filtered_num = len(data)
        if self.filter_pred is not None:
            data = [ex for ex in data if self.filter_pred(ex)]
        filtered_num -= len(data)
        print("Read {} {} examples ({} filtered)".format(
            len(data), data_type.upper(), filtered_num))
        return data
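
Here read_data expects one JSON object per line with 'src' and 'tgt' keys; a minimal sketch with a hypothetical sample line:

import json

# One JSON object per line, parsed once and unpacked as in read_data above.
line = '{"src": "how are you ?", "tgt": "i am fine ."}'
ex = json.loads(line.strip())
assert (ex['src'], ex['tgt']) == ("how are you ?", "i am fine .")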
Example #3
class TopicGuide2Corpus(Corpus):
    """
    TopicGuide2Corpus
    """
    def __init__(self, config):
        super(TopicGuide2Corpus, self).__init__(config)
        self.min_len = config.min_len
        self.max_len = config.max_len
        self.share_vocab = config.share_vocab
        self.embed_file = config.embed_file
        self.topic_words_num = config.topic_words_num
        self.topic_vocab_file = config.topic_vocab_file

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=self.embed_file)

        self.TGT = TextField(tokenize_fn=tokenize, embed_file=self.embed_file)

        self.LABEL = NumberField(dtype=int)
        self.OUTPUT = TextField(tokenize_fn=tokenize,
                                embed_file=self.embed_file)

        self.TOPIC = FixedField(fix_file=self.topic_vocab_file,
                                embed_file=self.embed_file)

        self.fields = {
            'src': self.SRC,
            'tgt': self.TGT,
            'output': self.OUTPUT,
            'topic_src_label': self.LABEL,
            'topic_tgt_label': self.LABEL,
            'topic': self.TOPIC,
        }

        def src_filter_pred(src):
            return self.min_len <= len(
                self.SRC.tokenize_fn(src)) <= self.max_len

        def tgt_filter_pred(tgt):
            return self.min_len <= len(
                self.TGT.tokenize_fn(tgt)) <= self.max_len

        # self.filter_pred = lambda ex: src_filter_pred(ex['src']) and tgt_filter_pred(ex['tgt'])
        self.filter_pred = None

    def read_data(self, data_file, data_type="train"):
        """
        read_data
        """
        data = []
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                line = json.loads(line.strip())
                src = line['src']
                tgt = line['tgt']
                topic_src_label = line['topic_src_label']
                topic_tgt_label = line['topic_tgt_label']

                data.append({
                    'src': src,
                    'tgt': tgt,
                    'output': tgt,
                    'topic_src_label': topic_src_label,
                    'topic_tgt_label': topic_tgt_label,
                    'topic': src,
                })

        filtered_num = len(data)
        if self.filter_pred is not None:
            data = [ex for ex in data if self.filter_pred(ex)]
        filtered_num -= len(data)
        print("Read {} {} examples ({} filtered)".format(
            len(data), data_type.upper(), filtered_num))
        return data

    def build_vocab(self, data):
        """
        Args
        ----
        data: ``List[Dict]``
        """
        field_data_dict = {}
        for name in data[0].keys():
            field = self.fields.get(name)
            if isinstance(field, TextField):
                xs = [x[name] for x in data]
                if field not in field_data_dict:
                    field_data_dict[field] = xs
                else:
                    field_data_dict[field] += xs

        vocab_dict = {}
        field_dict = {}
        for name, field in self.fields.items():
            if name == 'topic':
                field.build_vocab()
                field_dict[name] = field
                continue

            if field in field_data_dict:
                print("Building vocabulary of field {} ...".format(
                    name.upper()))
                if field.vocab_size == 0:
                    if name != 'output':
                        field.build_vocab(field_data_dict[field],
                                          min_freq=self.min_freq,
                                          max_size=self.max_vocab_size)
                    field_dict[name] = field

        field_dict['output'].add_with_other_field(field_dict['tgt'])
        field_dict['output'].add_with_other_field(field_dict['topic'])
        if self.embed_file is not None:
            field_dict['output'].embeddings = field_dict[
                'output'].build_word_embeddings(self.embed_file)

        for name, field in field_dict.items():
            vocab_dict[name] = field.dump_vocab()

        return vocab_dict

    def load_vocab(self, prepared_vocab_file):
        super().load_vocab(prepared_vocab_file)

        self.topic_bow_vocab_size = self.TOPIC.vocab_size
        self.Dataset = lambda x: WithBowDataset(
            data=x, bow_vocab_size=self.topic_bow_vocab_size)
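
The read_data method above pulls four keys from each JSON line and derives 'output' and 'topic' from 'tgt' and 'src'; a sketch of one hypothetical record:

import json

# The four keys accessed in read_data; the label values are illustrative.
line = json.dumps({
    "src": "i love hiking",
    "tgt": "mountains are great",
    "topic_src_label": 1,
    "topic_tgt_label": 2,
})
ex = json.loads(line.strip())
assert ex["topic_src_label"] == 1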
Example #4
class KnowledgeCorpus(Corpus):
    def __init__(self,
                 data_dir,
                 data_prefix,
                 min_freq=0,
                 max_vocab_size=None,
                 min_len=0,
                 max_len=100,
                 embed_file=None,
                 share_vocab=False,
                 with_label=False):
        super(KnowledgeCorpus, self).__init__(data_dir=data_dir,
                                              data_prefix=data_prefix,
                                              min_freq=min_freq,
                                              max_vocab_size=max_vocab_size)
        self.min_len = min_len
        self.max_len = max_len
        self.share_vocab = share_vocab
        self.with_label = with_label

        self.SRC = TextField(tokenize_fn=tokenize,
                             embed_file=embed_file)
        if self.share_vocab:
            self.TGT = self.SRC
            self.CUE = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize,
                                 embed_file=embed_file)
            self.CUE = TextField(tokenize_fn=tokenize,
                                 embed_file=embed_file)

        if self.with_label:
            self.INDEX = NumberField()
            self.fields = {'src': self.SRC, 'tgt': self.TGT, 'cue': self.CUE, 'index': self.INDEX}
        else:
            self.fields = {'src': self.SRC, 'tgt': self.TGT, 'cue': self.CUE}

        def src_filter_pred(src):
            return min_len <= len(self.SRC.tokenize_fn(src)) <= max_len

        def tgt_filter_pred(tgt):
            return min_len <= len(self.TGT.tokenize_fn(tgt)) <= max_len

        self.filter_pred = lambda ex: src_filter_pred(
            ex['src']) and tgt_filter_pred(ex['tgt'])

    def read_data(self, data_file, data_type="train"):
        data = []
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                if self.with_label:
                    src, tgt, knowledge, label = line.strip().split('\t')[:4]
                    filter_knowledge = []
                    # NOTE: the separator between knowledge sentences was
                    # garbled in this listing ("split('')" is invalid);
                    # '\x01' is assumed here.
                    for sent in knowledge.split('\x01'):
                        filter_knowledge.append(
                            ' '.join(sent.split()[:self.max_len]))
                    data.append({'src': src, 'tgt': tgt,
                                 'cue': filter_knowledge, 'index': label})
                else:
                    src, tgt, knowledge = line.strip().split('\t')[:3]
                    filter_knowledge = []
                    for sent in knowledge.split('\x01'):
                        filter_knowledge.append(
                            ' '.join(sent.split()[:self.max_len]))
                    data.append({'src': src, 'tgt': tgt,
                                 'cue': filter_knowledge})

        filtered_num = len(data)
        if self.filter_pred is not None:
            data = [ex for ex in data if self.filter_pred(ex)]
        filtered_num -= len(data)
        print(
            f"Read {len(data)} {data_type.upper()} examples ({filtered_num} filtered)")
        return data
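
A sketch of the tab-separated format this read_data consumes; note that '\x01' as the knowledge separator is an assumption (see the comment in read_data above):

# Hypothetical sample line: src, tgt, and '\x01'-joined knowledge sentences.
line = "hi there\thello !\tfact one\x01fact two\n"
src, tgt, knowledge = line.strip().split('\t')[:3]
assert knowledge.split('\x01') == ["fact one", "fact two"]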
Example #5
class KnowledgeCorpus(Corpus):
    """
    KnowledgeCorpus
    """
    def __init__(self,
                 data_dir,
                 data_prefix,
                 min_freq=0,
                 max_vocab_size=None,
                 vocab_file=None,
                 min_len=0,
                 max_len=100,
                 embed_file=None,
                 share_vocab=False,
                 with_label=False):
        super(KnowledgeCorpus, self).__init__(data_dir=data_dir,
                                              data_prefix=data_prefix,
                                              min_freq=min_freq,
                                              max_vocab_size=max_vocab_size,
                                              vocab_file=vocab_file)
        self.min_len = min_len
        self.max_len = max_len
        self.share_vocab = share_vocab
        self.with_label = with_label

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=embed_file)
        if self.share_vocab:
            self.TGT = self.SRC
            self.CUE = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize, embed_file=embed_file)
            self.CUE = TextField(tokenize_fn=tokenize, embed_file=embed_file)

        if self.with_label:
            self.INDEX = NumberField()
            self.fields = {
                'src': self.SRC,
                'tgt': self.TGT,
                'cue': self.CUE,
                'index': self.INDEX
            }
        else:
            self.fields = {'src': self.SRC, 'tgt': self.TGT, 'cue': self.CUE}

        # load vocab
        if not os.path.exists(self.prepared_vocab_file):
            self.build()
        else:
            self.load_vocab()
        self.padding_idx = self.TGT.stoi[self.TGT.pad_token]

        def src_filter_pred(src):
            """
            src_filter_pred
            """
            return min_len <= len(self.SRC.tokenize_fn(src)) <= max_len

        def tgt_filter_pred(tgt):
            """
            tgt_filter_pred
            """
            return min_len <= len(self.TGT.tokenize_fn(tgt)) <= max_len

        self.filter_pred = lambda ex: (src_filter_pred(ex['src'])
                                       and tgt_filter_pred(ex['tgt']))

    def read_data(self, data_file, data_type="train"):
        """
        read_data
        """
        num = 0
        data = []
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                dialog = json.loads(line)
                history = dialog["dialog"]
                uid = [int(i) for i in dialog["uid"]]
                profile = dialog["profile"]
                if "responder_profile" in dialog.keys():
                    responder_profile = dialog["responder_profile"]
                elif "response_profile" in dialog.keys():
                    responder_profile = dialog["response_profile"]
                else:
                    raise ValueError(
                        "No responder_profile or response_profile!")

                src = ""
                for idx, sent in zip(uid, history):
                    #tag_list = profile[idx]["tag"][0].split(';')
                    #loc_content = profile[idx]["loc"]
                    #tag_list.append(loc_content)
                    #tag_content = ' '.join(tag_list)
                    sent_content = sent[0]
                    #src += tag_content
                    #src += ' '
                    src += sent_content
                    src += ' '

                src = src.strip()
                tgt = dialog["golden_response"][0]
                filter_knowledge = []
                if type(responder_profile["tag"]) is list:
                    filter_knowledge.append(' '.join(
                        responder_profile["tag"][0].split(';')))
                else:
                    filter_knowledge.append(' '.join(
                        responder_profile["tag"].split(';')))
                filter_knowledge.append(responder_profile["loc"])
                data.append({'src': src, 'tgt': tgt, 'cue': filter_knowledge})

                num += 1
                if num < 10:
                    print("src:", src)
                    print("tgt:", tgt)
                    print("cue:", filter_knowledge)
                    print("\n")

        filtered_num = len(data)
        if not data_type == "test" and self.filter_pred is not None:
            data = [ex for ex in data if self.filter_pred(ex)]
        filtered_num -= len(data)
        print("Read {} {} examples ({} filtered)".format(
            len(data), data_type.upper(), filtered_num))
        return data
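
A hypothetical dialog record mirroring the fields this read_data accesses (all utterances, uids, and profiles are illustrative):

import json

# Each history item is a list whose first element is the utterance text,
# matching the sent[0] access in read_data above.
line = json.dumps({
    "dialog": [["hi there"], ["hello !"]],
    "uid": [0, 1],
    "profile": [{"tag": ["sports;music"], "loc": "beijing"},
                {"tag": ["movies"], "loc": "shanghai"}],
    "responder_profile": {"tag": ["movies"], "loc": "shanghai"},
    "golden_response": ["nice to meet you"],
})
dialog = json.loads(line)
assert dialog["responder_profile"]["loc"] == "shanghai"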
Example #6
class PersonaCorpus(Corpus):
    """
    PersonaCorpus
    """
    def __init__(self,
                 data_dir,
                 data_prefix,
                 min_freq=0,
                 max_vocab_size=None,
                 min_len=0,
                 max_len=100,
                 embed_file=None,
                 share_vocab=False,
                 with_label=False):
        super(PersonaCorpus, self).__init__(data_dir=data_dir,
                                            data_prefix=data_prefix,
                                            min_freq=min_freq,
                                            max_vocab_size=max_vocab_size)
        self.min_len = min_len
        self.max_len = max_len
        self.share_vocab = share_vocab
        self.with_label = with_label

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=embed_file)
        # self.LABEL = NumberField(dtype = float)
        # self.LABEL = NumberField(sequential=True, dtype = int)

        if self.share_vocab:
            self.TGT = self.SRC
            self.CUE = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize, embed_file=embed_file)
            self.CUE = TextField(tokenize_fn=tokenize, embed_file=embed_file)

        if self.with_label:
            self.LABEL = NumberField(sequential=False, dtype=int)
            self.INDEX = NumberField(sequential=True, dtype=int)
            self.fields = {
                'src': self.SRC,
                'tgt': self.TGT,
                'cue': self.CUE,
                'label': self.LABEL,
                'index': self.INDEX
            }
        else:
            self.fields = {'src': self.SRC, 'tgt': self.TGT, 'cue': self.CUE}

        def src_filter_pred(src):
            """
            src_filter_pred
            """
            # reject the example if any source sentence is out of range
            for sen in src:
                if not (min_len <= len(self.SRC.tokenize_fn(sen)) <= max_len):
                    return False
            return True

        def tgt_filter_pred(tgt):
            """
            tgt_filter_pred
            """
            return min_len <= len(self.TGT.tokenize_fn(tgt)) <= max_len

        self.filter_pred = lambda ex: (src_filter_pred(ex['src'])
                                       and tgt_filter_pred(ex['tgt']))

    def read_data(self, data_file, data_type="train"):
        """
        read_data
        """
        data = []
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                # print(self.with_label)
                if self.with_label:
                    (query, response, personas, persona_label,
                     key_index) = line.strip().split('\t')[:5]
                    filter_personas = []
                    for sent in personas.split('**'):
                        filter_personas.append(' '.join(
                            sent.split()[:self.max_len]))
                    index = key_index

                    data.append({
                        'src': query,
                        'tgt': response,
                        'cue': filter_personas,
                        'label': persona_label,
                        'index': index
                    })
                else:
                    queries, response, persona = line.strip().split('\t')[:3]
                    src = queries.split('**')
                    # filter_persona = ' '.join(persona.split()[:self.max_len])
                    filter_persona = persona
                    data.append({
                        'src': src,
                        'tgt': response,
                        'cue': filter_persona
                    })

        filtered_num = len(data)
        if self.filter_pred is not None:
            data = [ex for ex in data if self.filter_pred(ex)]
        filtered_num -= len(data)
        print("Read {} {} examples ({} filtered)".format(
            len(data), data_type.upper(), filtered_num))
        return data

    def read_data_multitask(self, data_file1, data_file2, data_type="train"):
        """
        read_data
        """
        data1 = []
        data2 = []
        with open(data_file2, "r", encoding="utf-8") as f:
            for line in f:
                # print(self.with_label)
                (query, response, personas, persona_label,
                 key_index) = line.strip().split('\t')[:5]
                filter_personas = []
                for sent in personas.split('**'):
                    filter_personas.append(' '.join(
                        sent.split()[:self.max_len]))
                index = key_index

                data2.append({
                    'src': query,
                    'tgt': response,
                    'cue': filter_personas,
                    'label': persona_label,
                    'index': index
                })
        filtered_num = len(data2)
        if self.filter_pred is not None:
            data2 = [ex for ex in data2 if self.filter_pred(ex)]
        filtered_num -= len(data2)
        print("Read {} {} examples ({} filtered)".format(
            len(data2),
            data_type.upper() + 'task2', filtered_num))

        with open(data_file1, "r", encoding="utf-8") as f:
            for line in f:
                queries, response, persona = line.strip().split('\t')[:3]
                src = queries.split('**')
                # filter_persona = ' '.join(persona.split()[:self.max_len])
                filter_persona = persona
                data1.append({
                    'src': src,
                    'tgt': response,
                    'cue': filter_persona
                })
        filtered_num = len(data1)
        if self.filter_pred is not None:
            data1 = [ex for ex in data1 if self.filter_pred(ex)]
        filtered_num -= len(data1)
        print("Read {} {} examples ({} filtered)".format(
            len(data1),
            data_type.upper() + 'task1', filtered_num))

        return data1, data2
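
A sketch of the two hypothetical line formats read_data_multitask consumes (labeled persona data for data_file2, '**'-joined queries for data_file1):

# data_file2: five tab-separated columns; personas joined with '**'.
line2 = "how are you\tfine thanks\ti like cats**i live alone\t0\t1 2"
query, response, personas, persona_label, key_index = line2.split('\t')[:5]
assert personas.split('**') == ["i like cats", "i live alone"]

# data_file1: three columns; multiple queries joined with '**'.
line1 = "hi**how are you\tfine thanks\ti like cats"
queries, response, persona = line1.split('\t')[:3]
assert queries.split('**') == ["hi", "how are you"]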