Example #1
    def __init__(self, config):
        super(SrcTgtCorpus, self).__init__(config)
        self.min_len = config.min_len
        self.max_len = config.max_len
        self.embed_file = config.embed_file
        self.share_vocab = config.share_vocab

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=self.embed_file)
        if self.share_vocab:
            self.TGT = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize,
                                 embed_file=self.embed_file)
        self.OUTPUT = self.TGT
        self.fields = {'src': self.SRC, 'tgt': self.TGT, 'output': self.OUTPUT}

        def src_filter_pred(src):
            return self.min_len <= len(
                self.SRC.tokenize_fn(src)) <= self.max_len

        def tgt_filter_pred(tgt):
            return self.min_len <= len(
                self.TGT.tokenize_fn(tgt)) <= self.max_len

        # self.filter_pred = lambda ex: src_filter_pred(ex['src']) and tgt_filter_pred(ex['tgt'])
        self.filter_pred = None
Example #2
    def __init__(self,
                 data_dir,
                 data_prefix,
                 min_freq=0,
                 max_vocab_size=None,
                 min_len=0,
                 max_len=100,
                 embed_file=None,
                 share_vocab=False):
        super(HieraSrcCorpus, self).__init__(data_dir=data_dir,
                                             data_prefix=data_prefix,
                                             min_freq=min_freq,
                                             max_vocab_size=max_vocab_size)
        self.min_len = min_len
        self.max_len = max_len
        self.share_vocab = share_vocab

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=embed_file)
        if self.share_vocab:
            self.TGT = self.SRC
            self.CUE = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize, embed_file=embed_file)
            self.CUE = TextField(tokenize_fn=tokenize, embed_file=embed_file)

        self.fields = {'src': self.SRC, 'tgt': self.TGT, 'cue': self.CUE}
Example #3
    def __init__(self,
                 data_dir,
                 data_prefix,
                 min_freq=0,
                 max_vocab_size=None,
                 min_len=0,
                 max_len=100,
                 embed_file=None,
                 share_vocab=False):
        super(SrcTgtCorpus, self).__init__(data_dir=data_dir,
                                           data_prefix=data_prefix,
                                           min_freq=min_freq,
                                           max_vocab_size=max_vocab_size)
        self.min_len = min_len
        self.max_len = max_len
        self.share_vocab = share_vocab

        self.SRC = TextField(tokenize_fn=tokenize,
                             embed_file=embed_file)
        if self.share_vocab:
            self.TGT = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize,
                                 embed_file=embed_file)

        self.fields = {'src': self.SRC, 'tgt': self.TGT}

        def src_filter_pred(src):
            return min_len <= len(self.SRC.tokenize_fn(src)) <= max_len

        def tgt_filter_pred(tgt):
            return min_len <= len(self.TGT.tokenize_fn(tgt)) <= max_len

        self.filter_pred = lambda ex: src_filter_pred(
            ex['src']) and tgt_filter_pred(ex['tgt'])
Example #4
class SrcTgtCorpus(Corpus):
    def __init__(self,
                 data_dir,
                 data_prefix,
                 min_freq=0,
                 max_vocab_size=None,
                 min_len=0,
                 max_len=100,
                 embed_file=None,
                 share_vocab=False):
        super(SrcTgtCorpus, self).__init__(data_dir=data_dir,
                                           data_prefix=data_prefix,
                                           min_freq=min_freq,
                                           max_vocab_size=max_vocab_size)
        self.min_len = min_len
        self.max_len = max_len
        self.share_vocab = share_vocab

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=embed_file)
        if self.share_vocab:
            self.TGT = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize, embed_file=embed_file)

        self.fields = {'src': self.SRC, 'tgt': self.TGT}

        def src_filter_pred(src):

            return min_len <= len(self.SRC.tokenize_fn(src)) <= max_len

        def tgt_filter_pred(tgt):

            return min_len <= len(self.TGT.tokenize_fn(tgt)) <= max_len

        self.filter_pred = lambda ex: src_filter_pred(ex[
            'src']) and tgt_filter_pred(ex['tgt'])

    def read_data(self, data_file, data_type="train"):

        data = []
        filtered = 0
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                src, tgt = line.strip().split('\t')[:2]
                data.append({'src': src, 'tgt': tgt})

        filtered_num = len(data)
        if self.filter_pred is not None:
            data = [ex for ex in data if self.filter_pred(ex)]
        filtered_num -= len(data)
        print("Read {} {} examples ({} filtered)".format(
            len(data), data_type.upper(), filtered_num))
        return data
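
The read_data method above expects one tab-separated src/tgt pair per line. A minimal usage sketch, assuming tokenize is a plain whitespace tokenizer and that the Corpus base constructor only records paths (the file and directory names below are made up for illustration):

# Hypothetical driver for the SrcTgtCorpus listed above.
with open("demo.train", "w", encoding="utf-8") as f:
    f.write("how are you ?\ti am fine .\n")
    f.write("hi\thello there , nice to meet you\n")

corpus = SrcTgtCorpus(data_dir=".", data_prefix="demo",
                      min_len=2, max_len=10)
examples = corpus.read_data("demo.train", data_type="train")
# With min_len=2 the single-token source "hi" fails src_filter_pred,
# so only the first pair survives filtering.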
Example #5
    def __init__(self,
                 data_dir,
                 data_prefix,
                 min_freq=0,
                 max_vocab_size=None,
                 min_len=0,
                 max_len=100,
                 embed_file=None,
                 share_vocab=False,
                 with_label=False):
        super(KnowledgeCorpus, self).__init__(data_dir=data_dir,
                                              data_prefix=data_prefix,
                                              min_freq=min_freq,
                                              max_vocab_size=max_vocab_size)
        self.min_len = min_len
        self.max_len = max_len
        self.share_vocab = share_vocab
        self.with_label = with_label

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=embed_file)
        if self.share_vocab:
            self.TGT = self.SRC
            self.CUE = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize, embed_file=embed_file)
            self.CUE = TextField(tokenize_fn=tokenize, embed_file=embed_file)

        if self.with_label:
            self.INDEX = NumberField()
            self.fields = {
                'src': self.SRC,
                'tgt': self.TGT,
                'cue': self.CUE,
                'index': self.INDEX
            }
        else:
            self.fields = {'src': self.SRC, 'tgt': self.TGT, 'cue': self.CUE}

        def src_filter_pred(src):
            """
            src_filter_pred
            """
            return min_len <= len(self.SRC.tokenize_fn(src)) <= max_len

        def tgt_filter_pred(tgt):
            """
            tgt_filter_pred
            """
            return min_len <= len(self.TGT.tokenize_fn(tgt)) <= max_len

        self.filter_pred = lambda ex: src_filter_pred(ex[
            'src']) and tgt_filter_pred(ex['tgt'])
Example #6
class SrcTgtCorpus(Corpus):
    """
    SrcTgtCorpus
    """
    def __init__(self, config):
        super(SrcTgtCorpus, self).__init__(config)
        self.min_len = config.min_len
        self.max_len = config.max_len
        self.embed_file = config.embed_file
        self.share_vocab = config.share_vocab

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=self.embed_file)
        if self.share_vocab:
            self.TGT = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize,
                                 embed_file=self.embed_file)
        self.OUTPUT = self.TGT
        self.fields = {'src': self.SRC, 'tgt': self.TGT, 'output': self.OUTPUT}

        def src_filter_pred(src):
            return self.min_len <= len(
                self.SRC.tokenize_fn(src)) <= self.max_len

        def tgt_filter_pred(tgt):
            return self.min_len <= len(
                self.TGT.tokenize_fn(tgt)) <= self.max_len

        # self.filter_pred = lambda ex: src_filter_pred(ex['src']) and tgt_filter_pred(ex['tgt'])
        self.filter_pred = None

    def read_data(self, data_file, data_type="train"):
        data = []
        filtered = 0

        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                src = json.loads(line.strip())['src']
                tgt = json.loads(line.strip())['tgt']
                data.append({'src': src, 'tgt': tgt})

        filtered_num = len(data)
        if self.filter_pred is not None:
            data = [ex for ex in data if self.filter_pred(ex)]
        filtered_num -= len(data)
        print("Read {} {} examples ({} filtered)".format(
            len(data), data_type.upper(), filtered_num))
        return data
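
Here read_data parses one JSON object per line and keeps only the src and tgt fields. A minimal sketch of the expected input format (the file name is made up for illustration):

# Hypothetical JSONL input for the read_data above: one object per line
# with at least "src" and "tgt" keys.
import json

with open("dial.train", "w", encoding="utf-8") as f:
    f.write(json.dumps({"src": "how are you ?", "tgt": "i am fine ."}) + "\n")
    f.write(json.dumps({"src": "hi", "tgt": "hello there"}) + "\n")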
Example #7
    def __init__(self,
                 data_dir,
                 data_prefix,
                 min_freq=0,
                 entity_file=None,
                 max_vocab_size=None):
        self.data_dir = data_dir
        self.data_prefix = data_prefix
        self.min_freq = min_freq
        self.max_vocab_size = max_vocab_size

        prepared_data_file = data_prefix + "_" + str(
            max_vocab_size) + ".data.pt"
        prepared_vocab_file = data_prefix + "_" + str(
            max_vocab_size) + ".vocab.pt"

        self.prepared_data_file = os.path.join(data_dir, prepared_data_file)
        self.prepared_vocab_file = os.path.join(data_dir, prepared_vocab_file)
        self.SRC = TextField(entiy_dict_file=entity_file)
        self.filter_pred = None
        self.sort_fn = None
        self.data = None
Example #8
    def __init__(self, config):
        super(TopicGuide2Corpus, self).__init__(config)
        self.min_len = config.min_len
        self.max_len = config.max_len
        self.share_vocab = config.share_vocab
        self.embed_file = config.embed_file
        self.topic_words_num = config.topic_words_num
        self.topic_vocab_file = config.topic_vocab_file

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=self.embed_file)

        self.TGT = TextField(tokenize_fn=tokenize, embed_file=self.embed_file)

        self.LABEL = NumberField(dtype=int)
        self.OUTPUT = TextField(tokenize_fn=tokenize,
                                embed_file=self.embed_file)

        self.TOPIC = FixedField(fix_file=self.topic_vocab_file,
                                embed_file=self.embed_file)

        self.fields = {
            'src': self.SRC,
            'tgt': self.TGT,
            'output': self.OUTPUT,
            'topic_src_label': self.LABEL,
            'topic_tgt_label': self.LABEL,
            'topic': self.TOPIC,
        }

        def src_filter_pred(src):
            return self.min_len <= len(
                self.SRC.tokenize_fn(src)) <= self.max_len

        def tgt_filter_pred(tgt):
            return self.min_len <= len(
                self.TGT.tokenize_fn(tgt)) <= self.max_len

        # self.filter_pred = lambda ex: src_filter_pred(ex['src']) and tgt_filter_pred(ex['tgt'])
        self.filter_pred = None
Example #9
class TopicGuide2Corpus(Corpus):
    """
    CueCorpus
    """
    def __init__(self, config):
        super(TopicGuide2Corpus, self).__init__(config)
        self.min_len = config.min_len
        self.max_len = config.max_len
        self.share_vocab = config.share_vocab
        self.embed_file = config.embed_file
        self.topic_words_num = config.topic_words_num
        self.topic_vocab_file = config.topic_vocab_file

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=self.embed_file)

        self.TGT = TextField(tokenize_fn=tokenize, embed_file=self.embed_file)

        self.LABEL = NumberField(dtype=int)
        self.OUTPUT = TextField(tokenize_fn=tokenize,
                                embed_file=self.embed_file)

        self.TOPIC = FixedField(fix_file=self.topic_vocab_file,
                                embed_file=self.embed_file)

        self.fields = {
            'src': self.SRC,
            'tgt': self.TGT,
            'output': self.OUTPUT,
            'topic_src_label': self.LABEL,
            'topic_tgt_label': self.LABEL,
            'topic': self.TOPIC,
        }

        def src_filter_pred(src):
            return self.min_len <= len(
                self.SRC.tokenize_fn(src)) <= self.max_len

        def tgt_filter_pred(tgt):
            return self.min_len <= len(
                self.TGT.tokenize_fn(tgt)) <= self.max_len

        # self.filter_pred = lambda ex: src_filter_pred(ex['src']) and tgt_filter_pred(ex['tgt'])
        self.filter_pred = None

    def read_data(self, data_file, data_type="train"):
        """
        read_data
        """
        data = []
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                line = json.loads(line.strip())
                src = line['src']
                tgt = line['tgt']
                topic_src_label = line['topic_src_label']
                topic_tgt_label = line['topic_tgt_label']

                data.append({
                    'src': src,
                    'tgt': tgt,
                    'output': tgt,
                    'topic_src_label': topic_src_label,
                    'topic_tgt_label': topic_tgt_label,
                    'topic': src,
                })

        filtered_num = len(data)
        if self.filter_pred is not None:
            data = [ex for ex in data if self.filter_pred(ex)]
        filtered_num -= len(data)
        print("Read {} {} examples ({} filtered)".format(
            len(data), data_type.upper(), filtered_num))
        return data

    def build_vocab(self, data):
        """
        Args
        ----
        data: ``List[Dict]``
        """
        field_data_dict = {}
        for name in data[0].keys():
            field = self.fields.get(name)
            if isinstance(field, TextField):
                xs = [x[name] for x in data]
                if field not in field_data_dict:
                    field_data_dict[field] = xs
                else:
                    field_data_dict[field] += xs

        vocab_dict = {}
        field_dict = {}
        for name, field in self.fields.items():
            if name == 'topic':
                field.build_vocab()
                field_dict[name] = field
                continue

            if field in field_data_dict:
                print("Building vocabulary of field {} ...".format(
                    name.upper()))
                if field.vocab_size == 0:
                    if name != 'output':
                        field.build_vocab(field_data_dict[field],
                                          min_freq=self.min_freq,
                                          max_size=self.max_vocab_size)
                    field_dict[name] = field

        field_dict['output'].add_with_other_field(field_dict['tgt'])
        field_dict['output'].add_with_other_field(field_dict['topic'])
        if self.embed_file is not None:
            field_dict['output'].embeddings = field_dict[
                'output'].build_word_embeddings(self.embed_file)

        for name, field in field_dict.items():
            vocab_dict[name] = field.dump_vocab()

        return vocab_dict

    def load_vocab(self, prepared_vocab_file):
        super().load_vocab(prepared_vocab_file)

        self.topic_bow_vocab_size = self.TOPIC.vocab_size
        self.Dataset = lambda x: WithBowDataset(
            data=x, bow_vocab_size=self.topic_bow_vocab_size)
Example #10
class KnowledgeCorpus(Corpus):
    def __init__(self,
                 data_dir,
                 data_prefix,
                 min_freq=0,
                 max_vocab_size=None,
                 min_len=0,
                 max_len=100,
                 embed_file=None,
                 share_vocab=False,
                 with_label=False):
        super(KnowledgeCorpus, self).__init__(data_dir=data_dir,
                                              data_prefix=data_prefix,
                                              min_freq=min_freq,
                                              max_vocab_size=max_vocab_size)
        self.min_len = min_len
        self.max_len = max_len
        self.share_vocab = share_vocab
        self.with_label = with_label

        self.SRC = TextField(tokenize_fn=tokenize,
                             embed_file=embed_file)
        if self.share_vocab:
            self.TGT = self.SRC
            self.CUE = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize,
                                 embed_file=embed_file)
            self.CUE = TextField(tokenize_fn=tokenize,
                                 embed_file=embed_file)

        if self.with_label:
            self.INDEX = NumberField()
            self.fields = {'src': self.SRC, 'tgt': self.TGT, 'cue': self.CUE, 'index': self.INDEX}
        else:
            self.fields = {'src': self.SRC, 'tgt': self.TGT, 'cue': self.CUE}

        def src_filter_pred(src):
            return min_len <= len(self.SRC.tokenize_fn(src)) <= max_len

        def tgt_filter_pred(tgt):
            return min_len <= len(self.TGT.tokenize_fn(tgt)) <= max_len

        self.filter_pred = lambda ex: src_filter_pred(
            ex['src']) and tgt_filter_pred(ex['tgt'])

    def read_data(self, data_file, data_type="train"):
        data = []
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                if self.with_label:
                    src, tgt, knowledge, label = line.strip().split('\t')[:4]
                    filter_knowledge = []
                    # '\1' is assumed as the knowledge-sentence separator
                    for sent in knowledge.split('\1'):
                        filter_knowledge.append(' '.join(sent.split()[:self.max_len]))
                    data.append({'src': src, 'tgt': tgt, 'cue': filter_knowledge, 'index': label})
                else:
                    src, tgt, knowledge = line.strip().split('\t')[:3]
                    filter_knowledge = []
                    for sent in knowledge.split('\1'):  # '\1' separator assumed
                        filter_knowledge.append(' '.join(sent.split()[:self.max_len]))
                    data.append({'src': src, 'tgt': tgt, 'cue':filter_knowledge})

        filtered_num = len(data)
        if self.filter_pred is not None:
            data = [ex for ex in data if self.filter_pred(ex)]
        filtered_num -= len(data)
        print(
            f"Read {len(data)} {data_type.upper()} examples ({filtered_num} filtered)")
        return data
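
Each input line here packs src, tgt and the knowledge sentences into tab-separated columns, with the knowledge column itself holding several sentences. A minimal sketch of one unlabeled line, assuming the '\1' separator used above:

# Hypothetical line layout for read_data with with_label=False.
src = "do you like music ?"
tgt = "yes , i play the guitar ."
knowledge = "\1".join(["i play guitar", "i live in ohio"])
line = "\t".join([src, tgt, knowledge]) + "\n"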
Example #11
class KnowledgeCorpus(Corpus):
    """
    KnowledgeCorpus
    """
    def __init__(self,
                 data_dir,
                 data_prefix,
                 min_freq=0,
                 max_vocab_size=None,
                 vocab_file=None,
                 min_len=0,
                 max_len=100,
                 embed_file=None,
                 share_vocab=False,
                 with_label=False):
        super(KnowledgeCorpus, self).__init__(data_dir=data_dir,
                                              data_prefix=data_prefix,
                                              min_freq=min_freq,
                                              max_vocab_size=max_vocab_size,
                                              vocab_file=vocab_file)
        self.min_len = min_len
        self.max_len = max_len
        self.share_vocab = share_vocab
        self.with_label = with_label

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=embed_file)
        if self.share_vocab:
            self.TGT = self.SRC
            self.CUE = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize, embed_file=embed_file)
            self.CUE = TextField(tokenize_fn=tokenize, embed_file=embed_file)

        if self.with_label:
            self.INDEX = NumberField()
            self.fields = {
                'src': self.SRC,
                'tgt': self.TGT,
                'cue': self.CUE,
                'index': self.INDEX
            }
        else:
            self.fields = {'src': self.SRC, 'tgt': self.TGT, 'cue': self.CUE}

        # load vocab
        if not os.path.exists(self.prepared_vocab_file):
            self.build()
        else:
            self.load_vocab()
        self.padding_idx = self.TGT.stoi[self.TGT.pad_token]

        def src_filter_pred(src):
            """
            src_filter_pred
            """
            return min_len <= len(self.SRC.tokenize_fn(src)) <= max_len

        def tgt_filter_pred(tgt):
            """
            tgt_filter_pred
            """
            return min_len <= len(self.TGT.tokenize_fn(tgt)) <= max_len

        self.filter_pred = lambda ex: src_filter_pred(ex[
            'src']) and tgt_filter_pred(ex['tgt'])

    def read_data(self, data_file, data_type="train"):
        """
        read_data
        """
        num = 0
        data = []
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                dialog = json.loads(line)
                history = dialog["dialog"]
                uid = [int(i) for i in dialog["uid"]]
                profile = dialog["profile"]
                if "responder_profile" in dialog.keys():
                    responder_profile = dialog["responder_profile"]
                elif "response_profile" in dialog.keys():
                    responder_profile = dialog["response_profile"]
                else:
                    raise ValueError(
                        "No responder_profile or response_profile!")

                src = ""
                for idx, sent in zip(uid, history):
                    #tag_list = profile[idx]["tag"][0].split(';')
                    #loc_content = profile[idx]["loc"]
                    #tag_list.append(loc_content)
                    #tag_content = ' '.join(tag_list)
                    sent_content = sent[0]
                    #src += tag_content
                    #src += ' '
                    src += sent_content
                    src += ' '

                src = src.strip()
                tgt = dialog["golden_response"][0]
                filter_knowledge = []
                if type(responder_profile["tag"]) is list:
                    filter_knowledge.append(' '.join(
                        responder_profile["tag"][0].split(';')))
                else:
                    filter_knowledge.append(' '.join(
                        responder_profile["tag"].split(';')))
                filter_knowledge.append(responder_profile["loc"])
                data.append({'src': src, 'tgt': tgt, 'cue': filter_knowledge})

                num += 1
                if num < 10:
                    print("src:", src)
                    print("tgt:", tgt)
                    print("cue:", filter_knowledge)
                    print("\n")

        filtered_num = len(data)
        if not data_type == "test" and self.filter_pred is not None:
            data = [ex for ex in data if self.filter_pred(ex)]
        filtered_num -= len(data)
        print("Read {} {} examples ({} filtered)".format(
            len(data), data_type.upper(), filtered_num))
        return data
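
This read_data expects one dialog JSON object per line, carrying only the keys the loop actually touches. A minimal sketch of such a line (all values are made up for illustration):

# Hypothetical single-line input for the read_data above.
import json

example = {
    "dialog": [["hello there"], ["hi , how are you"]],
    "uid": ["0", "1"],
    "profile": [{"tag": ["male;sports"], "loc": "beijing"},
                {"tag": ["female;music"], "loc": "shanghai"}],
    "responder_profile": {"tag": ["female;music"], "loc": "shanghai"},
    "golden_response": ["i am fine , thanks"],
}
line = json.dumps(example, ensure_ascii=False) + "\n"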
Example #12
class Entity_Corpus_pos(object):
    """
    Corpus
    """
    def __init__(self,
                 data_dir,
                 data_prefix,
                 min_freq=0,
                 entity_file=None,
                 max_vocab_size=None):
        self.data_dir = data_dir
        self.data_prefix = data_prefix
        self.min_freq = min_freq
        self.max_vocab_size = max_vocab_size

        prepared_data_file = data_prefix + "_" + str(
            max_vocab_size) + ".data.pt"
        prepared_vocab_file = data_prefix + "_" + str(
            max_vocab_size) + ".vocab.pt"

        self.prepared_data_file = os.path.join(data_dir, prepared_data_file)
        self.prepared_vocab_file = os.path.join(data_dir, prepared_vocab_file)
        self.SRC = TextField(entiy_dict_file=entity_file)
        self.POS = TextField(bos_token=None, eos_token=None)

        self.filter_pred = None
        self.sort_fn = None
        self.data = None

    def load(self):
        """
        load
        """
        if not (os.path.exists(self.prepared_data_file)
                and os.path.exists(self.prepared_vocab_file)):
            self.build()
        self.load_vocab(self.prepared_vocab_file)
        self.load_data(self.prepared_data_file)

        self.padding_idx = self.SRC.stoi[self.SRC.pad_token]

    def reload(self, data_type='test'):
        """
        reload
        """
        data_file = os.path.join(self.data_dir,
                                 self.data_prefix + "." + data_type)
        if os.path.exists(data_file):
            self.data[data_type] = Entity_Dataset_pos(torch.load(data_file))
        else:
            data_file_raw = data_file + '.raw'
            data_raw = self.read_data(data_file_raw,
                                      data_type="test",
                                      has_id=True)
            data_examples = self.build_examples(data_raw)
            torch.save(data_examples, data_file)
            self.data[data_type] = Entity_Dataset_pos(data_examples)

        print(
            "Number of examples:", " ".join("{}-{}".format(k.upper(), len(v))
                                            for k, v in self.data.items()))

    def load_data(self, prepared_data_file=None):
        """
        load_data
        """
        prepared_data_file = prepared_data_file or self.prepared_data_file
        print("Loading prepared data from {} ...".format(prepared_data_file))
        data = torch.load(prepared_data_file)
        self.data = {
            "train": Entity_Dataset_pos(data['train']),
            "valid": Entity_Dataset_pos(data["valid"]),
        }
        print(
            "Number of examples:", " ".join("{}-{}".format(k.upper(), len(v))
                                            for k, v in self.data.items()))

    def load_vocab(self, prepared_vocab_file):
        """
        load_vocab
        """
        prepared_vocab_file = prepared_vocab_file or self.prepared_vocab_file
        print("Loading prepared vocab from {} ...".format(prepared_vocab_file))
        vocab_dict = torch.load(prepared_vocab_file)
        self.SRC.load_vocab(vocab_dict['src'])
        self.POS.load_vocab(vocab_dict['pos'])
        print('Finished loading vocab, size: %d' % (self.SRC.vocab_size))
        print('Finished loading POS vocab, size: %d' % (self.POS.vocab_size))

    def read_data(self, data_file, data_type="train", has_id=False):
        """
        read_data
        """
        data = []
        with open(data_file, "r", encoding="utf-8") as f:
            LTP_DATA_DIR = './extend/ltp_data_v3.4.0'
            from pyltp import Postagger
            pos_model_path = os.path.join(LTP_DATA_DIR, 'pos.model')
            postagger = Postagger()  # initialize the tagger
            postagger.load(pos_model_path)  # load the POS model
            for line in f:
                line = json.loads(line)
                src = line['text']
                mask = line['mask']
                true_Entity = line['true_Entity']

                pos = list(postagger.postag(src))

                assert len(src) == len(pos)
                # build the tgt sequence
                tgt = []
                # start with a beginning-of-sequence marker
                pre = '<bos>'
                for item in true_Entity:
                    tgt.append([pre, item[0], item[1]['emotion']])
                    pre = item[1]['entity']
                tgt.append([pre, len(src) - 1, 'NORM'])
                # points to the last position
                d = {'src': src, 'mask': mask, 'tgt': tgt, 'pos': pos}
                if has_id:
                    d['id'] = line['newsId']
                data.append(d)
        # the train/valid split is done later, in build()
        print('finished data read')
        return data

    def build_vocab(self, data):
        """
        Args
        ----
        data: ``List[Dict]``
        """

        temp1 = [x['src'] for x in data]
        self.SRC.build_vocab(temp1,
                             min_freq=self.min_freq,
                             max_size=self.max_vocab_size)
        temp2 = [x['pos'] for x in data]
        self.POS.build_vocab(temp2,
                             min_freq=self.min_freq,
                             max_size=self.max_vocab_size)
        d = {}
        d['src'] = self.SRC.dump_vocab()
        d['pos'] = self.POS.dump_vocab()

        return d

    def build_examples(self, data):
        """
        Args
        ----
        data: ``List[Dict]``

        raw_data  :
        text
        mask
        tgt

        """
        examples = []
        for raw_data in tqdm(data):
            example = {}
            raw_text = raw_data['src']
            num_text = self.SRC.str2num(raw_text)

            example['num_pos'] = self.POS.easy_str2num(raw_data['pos'])
            example['num_src'] = num_text
            example['raw_src'] = raw_text

            example['mask'] = raw_data['mask']
            # raw_tgt,tgt_output,tgt_emo=zip(*raw_data['tgt'])

            # tgt_input = []
            # tgt_output=list(tgt_output)
            # tgt_input.append(self.SRC.word2num(raw_tgt[0]))
            # tgt_input+=[num_text[x] for x in tgt_output[:-1]]
            # tgt_emo=[self.SRC.emotoi.get(emotion, 0) for emotion in tgt_emo]
            # assert len(tgt_input) ==len(tgt_output)

            tgt_input = []
            tgt_output = []
            raw_tgt = []
            tgt_emo = []

            for [input, output, emotion] in raw_data['tgt']:
                raw_tgt.append(input)
                tgt_input.append(self.SRC.word2num(input))
                # tgt_input.append(self.SRC.target2num(input))
                tgt_emo.append(self.SRC.emotoi.get(emotion, 0))
                # if input in self.SRC.stoi:
                #     tgt_input.append(self.SRC.stoi[input])
                # elif input in self.SRC.entiy_dict:
                #     tgt_input.append(self.SRC.stoi[self.SRC.entiy_token])
                # else:
                #     tgt_input.append(self.SRC.stoi[self.SRC.unk_token])
                tgt_output.append(output)

            example['num_tgt_input'] = tgt_input
            example['tgt_output'] = tgt_output
            example['tgt_emo'] = tgt_emo
            example['raw_tgt'] = raw_tgt
            if 'id' in raw_data:
                example['id'] = raw_data['id']
            examples.append(example)
        if self.sort_fn is not None:
            print("Sorting examples ...")
            examples = self.sort_fn(examples)
        return examples

    def build(self):
        """
        build
        """
        print("Start to build corpus!")
        train_file = os.path.join(self.data_dir, self.data_prefix + ".train")

        print("Reading data ...")
        train_raw = self.read_data(train_file, data_type="train")
        if os.path.exists(self.prepared_vocab_file):
            print('Loading existing vocab')
            self.load_vocab(self.prepared_vocab_file)
        else:
            vocab = self.build_vocab(train_raw)
        # vocab = self.build_vocab(train_raw)

        print("Building TRAIN examples ...")
        train_data = self.build_examples(train_raw)
        import random
        random.shuffle(train_data)
        # train_data=train_data
        # valid_data=train_data
        valid_data = train_data[:2000]
        train_data = train_data[2000:]

        data = {
            "train": train_data,
            "valid": valid_data,
        }

        print('num_train_data %d, num_valid_data %d' %
              (len(train_data), len(valid_data)))

        if not os.path.exists(self.prepared_vocab_file):
            print("Saving prepared vocab ...")
            torch.save(vocab, self.prepared_vocab_file)
            print("Saved prepared vocab to '{}'".format(
                self.prepared_vocab_file))
        # print("Saving prepared vocab ...")
        # torch.save(vocab, self.prepared_vocab_file)
        # print("Saved prepared vocab to '{}'".format(self.prepared_vocab_file))

        print("Saving prepared data ...")
        torch.save(data, self.prepared_data_file)
        print("Saved prepared data to '{}'".format(self.prepared_data_file))

    def create_batches(self,
                       batch_size,
                       data_type="train",
                       shuffle=False,
                       device=None):
        """
        create_batches
        """
        try:
            data = self.data[data_type]
            data_loader = data.create_batches(batch_size, shuffle, device)
            return data_loader
        except KeyError:
            raise KeyError("Unsported data type: {}!".format(data_type))

    def transform(self,
                  data_file,
                  batch_size,
                  data_type="test",
                  shuffle=False,
                  device=None):
        """
        Transform raw text from data_file to Dataset and create data loader.
        """
        raw_data = self.read_data(data_file, data_type=data_type)
        examples = self.build_examples(raw_data)
        data = Dataset(examples)
        data_loader = data.create_batches(batch_size, shuffle, device)
        return data_loader
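
A minimal driver for Entity_Corpus_pos, using only the methods shown in the listing above (directory, prefix and file names are assumptions):

# Hypothetical usage: build or load the prepared vocab/data, then batch.
corpus = Entity_Corpus_pos(data_dir="./data",
                           data_prefix="news",
                           min_freq=2,
                           entity_file="./data/entity_dict.txt",
                           max_vocab_size=30000)
corpus.load()  # builds the .data.pt / .vocab.pt files on the first run
train_loader = corpus.create_batches(batch_size=32,
                                     data_type="train",
                                     shuffle=True)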
Example #13
    def __init__(self,
                 data_dir,
                 data_prefix,
                 min_freq=0,
                 max_vocab_size=None,
                 min_len=0,
                 max_len=100,
                 embed_file=None,
                 share_vocab=False,
                 with_label=False):
        super(PersonaCorpus, self).__init__(data_dir=data_dir,
                                            data_prefix=data_prefix,
                                            min_freq=min_freq,
                                            max_vocab_size=max_vocab_size)
        self.min_len = min_len
        self.max_len = max_len
        self.share_vocab = share_vocab
        self.with_label = with_label

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=embed_file)
        # self.LABEL = NumberField(dtype = float)
        # self.LABEL = NumberField(sequential=True, dtype = int)

        if self.share_vocab:
            self.TGT = self.SRC
            self.CUE = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize, embed_file=embed_file)
            self.CUE = TextField(tokenize_fn=tokenize, embed_file=embed_file)

        if self.with_label:
            self.LABEL = NumberField(sequential=False, dtype=int)
            self.INDEX = NumberField(sequential=True, dtype=int)
            self.fields = {
                'src': self.SRC,
                'tgt': self.TGT,
                'cue': self.CUE,
                'label': self.LABEL,
                'index': self.INDEX
            }
        else:
            self.fields = {'src': self.SRC, 'tgt': self.TGT, 'cue': self.CUE}

        def src_filter_pred(src):
            """
            src_filter_pred
            """
            for sen in src:
                if not (min_len <= len(self.SRC.tokenize_fn(sen)) <= max_len):
                    return False
            return True

        def tgt_filter_pred(tgt):
            """
            tgt_filter_pred
            """
            return min_len <= len(self.TGT.tokenize_fn(tgt)) <= max_len

        self.filter_pred = lambda ex: src_filter_pred(ex[
            'src']) and tgt_filter_pred(ex['tgt'])
Example #14
class PersonaCorpus(Corpus):
    """
    PersonaCorpus
    """
    def __init__(self,
                 data_dir,
                 data_prefix,
                 min_freq=0,
                 max_vocab_size=None,
                 min_len=0,
                 max_len=100,
                 embed_file=None,
                 share_vocab=False,
                 with_label=False):
        super(PersonaCorpus, self).__init__(data_dir=data_dir,
                                            data_prefix=data_prefix,
                                            min_freq=min_freq,
                                            max_vocab_size=max_vocab_size)
        self.min_len = min_len
        self.max_len = max_len
        self.share_vocab = share_vocab
        self.with_label = with_label

        self.SRC = TextField(tokenize_fn=tokenize, embed_file=embed_file)
        # self.LABEL = NumberField(dtype = float)
        # self.LABEL = NumberField(sequential=True, dtype = int)

        if self.share_vocab:
            self.TGT = self.SRC
            self.CUE = self.SRC
        else:
            self.TGT = TextField(tokenize_fn=tokenize, embed_file=embed_file)
            self.CUE = TextField(tokenize_fn=tokenize, embed_file=embed_file)

        if self.with_label:
            self.LABEL = NumberField(sequential=False, dtype=int)
            self.INDEX = NumberField(sequential=True, dtype=int)
            self.fields = {
                'src': self.SRC,
                'tgt': self.TGT,
                'cue': self.CUE,
                'label': self.LABEL,
                'index': self.INDEX
            }
        else:
            self.fields = {'src': self.SRC, 'tgt': self.TGT, 'cue': self.CUE}

        def src_filter_pred(src):
            """
            src_filter_pred
            """
            for sen in src:
                if not (min_len <= len(self.SRC.tokenize_fn(sen)) <= max_len):
                    return False
            return True

        def tgt_filter_pred(tgt):
            """
            tgt_filter_pred
            """
            return min_len <= len(self.TGT.tokenize_fn(tgt)) <= max_len

        self.filter_pred = lambda ex: src_filter_pred(ex[
            'src']) and tgt_filter_pred(ex['tgt'])

    def read_data(self, data_file, data_type="train"):
        """
        read_data
        """
        data = []
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                # print(self.with_label)
                if self.with_label:
                    query, response, personas, persona_label, key_index = line.strip(
                    ).split('\t')[:5]
                    filter_personas = []
                    for sent in personas.split('**'):
                        filter_personas.append(' '.join(
                            sent.split()[:self.max_len]))
                    index = key_index

                    data.append({
                        'src': query,
                        'tgt': response,
                        'cue': filter_personas,
                        'label': persona_label,
                        'index': index
                    })
                else:
                    queries, response, persona = line.strip().split('\t')[:3]
                    src = queries.split('**')
                    # filter_persona = ' '.join(persona.split()[:self.max_len])
                    filter_persona = persona
                    data.append({
                        'src': src,
                        'tgt': response,
                        'cue': filter_persona
                    })

        filtered_num = len(data)
        if self.filter_pred is not None:
            data = [ex for ex in data if self.filter_pred(ex)]
        filtered_num -= len(data)
        print("Read {} {} examples ({} filtered)".format(
            len(data), data_type.upper(), filtered_num))
        return data

    def read_data_multitask(self, data_file1, data_file2, data_type="train"):
        """
        read_data
        """
        data1 = []
        data2 = []
        with open(data_file2, "r", encoding="utf-8") as f:
            for line in f:
                # print(self.with_label)
                query, response, personas, persona_label, key_index = line.strip(
                ).split('\t')[:5]
                filter_personas = []
                for sent in personas.split('**'):
                    filter_personas.append(' '.join(
                        sent.split()[:self.max_len]))
                index = key_index

                data2.append({
                    'src': query,
                    'tgt': response,
                    'cue': filter_personas,
                    'label': persona_label,
                    'index': index
                })
        filtered_num = len(data2)
        if self.filter_pred is not None:
            data2 = [ex for ex in data2 if self.filter_pred(ex)]
        filtered_num -= len(data2)
        print("Read {} {} examples ({} filtered)".format(
            len(data2),
            data_type.upper() + 'task2', filtered_num))

        with open(data_file1, "r", encoding="utf-8") as f:
            for line in f:
                queries, response, persona = line.strip().split('\t')[:3]
                src = queries.split('**')
                # filter_persona = ' '.join(persona.split()[:self.max_len])
                filter_persona = persona
                data1.append({
                    'src': src,
                    'tgt': response,
                    'cue': filter_persona
                })
        filtered_num = len(data1)
        if self.filter_pred is not None:
            data1 = [ex for ex in data1 if self.filter_pred(ex)]
        filtered_num -= len(data1)
        print("Read {} {} examples ({} filtered)".format(
            len(data1),
            data_type.upper() + 'task1', filtered_num))

        return data1, data2
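
read_data_multitask pairs two files with different layouts: data_file1 uses the unlabeled queries/response/persona columns, data_file2 the labeled five-column format. A minimal sketch of calling it (constructor arguments and file names are assumptions):

# Hypothetical call to the multi-task reader listed above.
corpus = PersonaCorpus(data_dir=".", data_prefix="persona",
                       min_len=1, max_len=50, with_label=True)
task1_data, task2_data = corpus.read_data_multitask(
    "convai2.train", "persona_label.train", data_type="train")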