Example #1
File: dataset.py Project: zwjyyc/BiBloSA
    def process_raw_dataset(raw_data, data_type):
        _logger.add()
        _logger.add('processing raw data: %s...' % data_type)
        for topic in tqdm(raw_data):
            for paragraph in topic['paragraphs']:
                # context
                paragraph['context'] = paragraph['context'].replace(
                    "''", '" ').replace("``", '" ')
                paragraph['context_token'] = [[
                    token.replace("''", '"').replace("``", '"')
                    for token in nltk.word_tokenize(sent)
                ] for sent in nltk.sent_tokenize(paragraph['context'])]
                paragraph['context_token'] = [
                    Dataset.further_tokenize(sent)
                    for sent in paragraph['context_token']
                ]
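                # context_token: one token list per sentence, after NLTK tokenization
                # plus the extra splitting done by Dataset.further_tokenize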

                # qas
                for qa in paragraph['qas']:
                    qa['question'] = qa['question'].replace("''", '" ').replace("``", '" ')
                    qa['question_token'] = Dataset.further_tokenize([
                        token.replace("''", '"').replace("``", '"')
                        for token in nltk.word_tokenize(qa['question'])
                    ])
                    # tag generation
                    for answer in qa['answers']:
                        answer['sent_label'] = Dataset.sentence_label_generation(
                            paragraph['context'], paragraph['context_token'],
                            answer['text'], answer['answer_start'])
        _logger.done()
        return raw_data
Example #2
    def restore(self, sess):
        _logger.add()
        # print(cfg.ckpt_dir)
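        # resolve which checkpoint to load: explicit load_step > explicit load_path > latest file in ckpt_dir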

        if cfg.load_step is None:
            if cfg.load_path is None:
                _logger.add('trying to restore from dir %s' % cfg.ckpt_dir)
                latest_checkpoint_path = tf.train.latest_checkpoint(
                    cfg.ckpt_dir)
            else:
                latest_checkpoint_path = cfg.load_path
        else:
            latest_checkpoint_path = cfg.ckpt_path + '-' + str(cfg.load_step)

        if latest_checkpoint_path is not None:
            _logger.add('trying to restore from ckpt file %s' %
                        latest_checkpoint_path)
            try:
                self.saver.restore(sess, latest_checkpoint_path)
                _logger.add('success to restore')
            except tf.errors.NotFoundError:
                _logger.add('failure to restore')
                if cfg.mode != 'train':
                    raise FileNotFoundError('cannot find model file')
        else:
            _logger.add('No check point file in dir %s ' % cfg.ckpt_dir)
            if cfg.mode != 'train':
                raise FileNotFoundError('cannot find model file')

        _logger.done()
Example #3
File: dataset.py Project: zwjyyc/BiBloSA
    def digitize_dataset(dataset, dicts, data_type):
        token2index = dict([
            (token, idx)
            for idx, token in enumerate(dicts['token'] + dicts['glove'])
        ])

        def digitize_token(token):
            token = token if not cfg.lower_word else token.lower()
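            # out-of-vocabulary tokens fall back to index 1, the '@@@unk' slot
            # reserved by add_ept_and_unk (see Example #22)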
            try:
                return token2index[token]
            except KeyError:
                return 1

        _logger.add()
        _logger.add('digitizing data: %s...' % data_type)

        for topic in tqdm(dataset):
            for paragraph in topic['paragraphs']:
                paragraph['context_token_digital'] = [[
                    digitize_token(token) for token in sent
                ] for sent in paragraph['context_token']]
                for qa in paragraph['qas']:
                    qa['question_token_digital'] = [
                        digitize_token(token) for token in qa['question_token']
                    ]
        _logger.done()
        return dataset
Example #4
File: model_disan.py Project: zxsted/DiSAN
    def build_network(self):
        _logger.add()
        _logger.add('building %s neural network structure...' % cfg.network_type)
        tds, cds = self.tds, self.cds
        tl = self.tl
        tel, cel, cos, ocd, fh = self.tel, self.cel, self.cos, self.ocd, self.fh
        hn = self.hn
        bs, sl, ol, mc = self.bs, self.sl, self.ol, self.mc

        with tf.variable_scope('emb'):
            token_emb_mat = generate_embedding_mat(tds, tel, init_mat=self.token_emb_mat,
                                                   extra_mat=self.glove_emb_mat, extra_trainable=self.finetune_emb,
                                                   scope='gene_token_emb_mat')
            emb = tf.nn.embedding_lookup(token_emb_mat, self.token_seq)  # bs,sl,tel
            self.tensor_dict['emb'] = emb

        rep = disan(
            emb, self.token_mask, 'DiSAN', cfg.dropout,
            self.is_train, cfg.wd, 'relu', tensor_dict=self.tensor_dict, name='')

        with tf.variable_scope('output'):
            pre_logits = tf.nn.relu(linear([rep], hn, True, scope='pre_logits_linear',
                                          wd=cfg.wd, input_keep_prob=cfg.dropout,
                                          is_train=self.is_train))  # bs, hn
            logits = linear([pre_logits], self.output_class, False, scope='get_output',
                            wd=cfg.wd, input_keep_prob=cfg.dropout, is_train=self.is_train) # bs, 5
        _logger.done()
        return logits
Example #5
File: dataset.py Project: zwjyyc/BiBloSA
    def digitize_data(self, data_list, dicts, dataset_type):
        token2index = dict([
            (token, idx)
            for idx, token in enumerate(dicts['token'] + dicts['glove'])
        ])
        char2index = dict([(token, idx)
                           for idx, token in enumerate(dicts['char'])])

        def digitize_token(token):
            token = token if not cfg.lower_word else token.lower()
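            # index 1 is the '@@@unk' placeholder reserved by add_ept_and_unk (see Example #13)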
            try:
                return token2index[token]
            except KeyError:
                return 1

        def digitize_char(char):
            try:
                return char2index[char]
            except KeyError:
                return 1

        _logger.add()
        _logger.add('digitizing data: %s...' % dataset_type)
        for sample in data_list:
            sample['token_digital'] = [
                digitize_token(token) for token in sample['token']
            ]
            sample['char_digital'] = [[
                digitize_char(char) for char in list(token)
            ] for token in sample['token']]
        _logger.done()
        return data_list
Example #6
    def build_network(self):
        _logger.add()
        _logger.add('building %s neural network structure...' % cfg.network_type)
        tds, cds = self.tds, self.cds
        tl = self.tl
        tel, cel, cos, ocd, fh = self.tel, self.cel, self.cos, self.ocd, self.fh
        hn = self.hn
        bs = self.bs

        with tf.variable_scope('emb'):
            token_emb_mat = generate_embedding_mat(tds, tel, init_mat=self.token_emb_mat,
                                                   extra_mat=self.glove_emb_mat, extra_trainable=self.finetune_emb,
                                                   scope='gene_token_emb_mat')
            emb = tf.nn.embedding_lookup(token_emb_mat, self.token_seq)  # bs,sl1,tel

        with tf.variable_scope('sent_encoding'):
            rep = sentence_encoding_models(
                emb, self.token_mask, cfg.context_fusion_method, 'relu',
                'ct_based_sent2vec', cfg.wd, self.is_train, cfg.dropout,
                block_len=cfg.block_len)

        with tf.variable_scope('output'):
            pre_logits = tf.nn.relu(linear([rep], hn, True, scope='pre_logits_linear',
                                           wd=cfg.wd, input_keep_prob=cfg.dropout,
                                           is_train=self.is_train))  # bs, hn
            logits = linear([pre_logits], self.output_class, False, scope='get_output',
                            wd=cfg.wd, input_keep_prob=cfg.dropout, is_train=self.is_train) # bs, 5
        _logger.done()
        return logits
Example #7
    def process_raw_data(self, dataset, data_type):
        def further_tokenize(temp_tokens):
            # extra characters to split on after the initial tokenization
            split_chars = ("-", "\u2212", "\u2014", "\u2013", "/", "~", '"', "'",
                           "\u201C", "\u2019", "\u201D", "\u2018", "\u00B0")
            split_pattern = "([{}])".format("".join(split_chars))
            tokens = []
            for token in temp_tokens:
                tokens.extend(re.split(split_pattern, token))
            return tokens

        # tokens
        _logger.add()
        _logger.add('processing raw data for %s' % data_type)

        for sample in tqdm(dataset):
            sample['sentence1_token'] = [node.token
                                         for node in sample['sentence1_binary_parse_node_list'] if node.is_leaf]
            sample['sentence1_tag'] = [node.tag
                                       for node in sample['sentence1_binary_parse_node_list'] if node.is_leaf]

            sample['sentence2_token'] = [node.token
                                         for node in sample['sentence2_binary_parse_node_list'] if node.is_leaf]
            sample['sentence2_tag'] = [node.tag
                                       for node in sample['sentence2_binary_parse_node_list'] if node.is_leaf]

            if cfg.data_clip_method == 'no_tree':
                sample['sentence1_token'] = further_tokenize(sample['sentence1_token'])
                sample['sentence2_token'] = further_tokenize(sample['sentence2_token'])
        _logger.done()
        return dataset
Example #8
    def load_data_pickle(self, data_file_path, data_type):
        _logger.add()
        _logger.add('load file for %s' % data_type)
        dataset = None
        with open(data_file_path, 'rb') as file:  # binary mode; pickle files take no encoding argument
            dataset = pickle.load(file)
        _logger.done()
        return dataset
Example #9
    def load_data(self, data_file_path, data_type):
        _logger.add()
        _logger.add('load file for %s' % data_type)
        dataset = []
        with open(data_file_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)
        _logger.done()
        return dataset
Example #10
    def transform_str_to_tree(self, dataset, data_type):
        _logger.add()
        _logger.add('transforming str format tree into real tree for %s' %
                    data_type)
        for sample in tqdm(dataset):
            sample['sentence1_binary_parse_tree'] = recursive_build_binary(
                tokenize_str_format_tree(sample['sentence1_binary_parse']))
            sample['sentence2_binary_parse_tree'] = recursive_build_binary(
                tokenize_str_format_tree(sample['sentence2_binary_parse']))
            # sample['sentence1_parse_tree'] = recursive_build_penn_format(
            #     tokenize_str_format_tree(sample['sentence1_parse']))
            # sample['sentence2_parse_tree'] = recursive_build_penn_format(
            #     tokenize_str_format_tree(sample['sentence2_parse']))

            # to node_list
            sample['sentence1_binary_parse_tree'], sample['sentence1_binary_parse_node_list'] = \
                transform_tree_to_parent_index(sample['sentence1_binary_parse_tree'])
            sample['sentence2_binary_parse_tree'], sample['sentence2_binary_parse_node_list'] = \
                transform_tree_to_parent_index(sample['sentence2_binary_parse_tree'])
            # sample['sentence1_parse_tree'], sample['sentence1_parse_node_list'] = \
            #     transform_tree_to_parent_index(sample['sentence1_parse_tree'])
            # sample['sentence2_parse_tree'], sample['sentence2_parse_node_list'] = \
            #     transform_tree_to_parent_index(sample['sentence2_parse_tree'])

            # shift reduce info
            # # s1
            s1_child_parent_node_indices = [
                (new_tree_node.node_index, new_tree_node.parent_index)
                for new_tree_node in sample['sentence1_binary_parse_node_list']
            ]
            s1_sr = shift_reduce_constituency_forest(
                s1_child_parent_node_indices)
            s1_op_list, s1_node_list_in_stack, s1_reduce_mat = zip(*s1_sr)
            s1_sr_info = {
                'op_list': s1_op_list,
                'reduce_mat': s1_reduce_mat,
                'node_list_in_stack': s1_node_list_in_stack
            }
            sample['s1_sr_info'] = s1_sr_info

            # # s2
            s2_child_parent_node_indices = [
                (new_tree_node.node_index, new_tree_node.parent_index)
                for new_tree_node in sample['sentence2_binary_parse_node_list']
            ]
            s2_sr = shift_reduce_constituency_forest(
                s2_child_parent_node_indices)
            s2_op_list, s2_node_list_in_stack, s2_reduce_mat = zip(*s2_sr)
            s2_sr_info = {
                'op_list': s2_op_list,
                'reduce_mat': s2_reduce_mat,
                'node_list_in_stack': s2_node_list_in_stack
            }
            sample['s2_sr_info'] = s2_sr_info

        _logger.done()
        return dataset
Example #11
    def load_snli_data(self, data_path, data_type):
        _logger.add()
        _logger.add('load file for %s' % data_type)
        dataset = []
        with open(data_path, 'r', encoding='utf-8') as file:
            for line in file:
                json_obj = json.loads(line)
                dataset.append(json_obj)
        _logger.done()
        return dataset
Example #12
    def process_raw_data(self, data_list, data_type):
        _logger.add()
        _logger.add('processing raw data: %s...' % data_type)
        for sample in data_list:
            for tree_node in sample:
                # node_index, parent_index, token_seq, leaf_node_index_seq, is_leaf, token, sentiment_label
                # char_seq
                tree_node['char_seq'] = [list(token) for token in tree_node['token_seq']]
        _logger.done()
        return data_list
Example #13
    def count_data_and_build_dict(self, data_list, gene_dicts=True):
        def add_ept_and_unk(a_list):
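            # reserve index 0 for the '@@@empty' placeholder and index 1 for '@@@unk'
            # (the fallback index returned by digitize_token in Examples #3 and #5)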
            a_list.insert(0, '@@@empty')
            a_list.insert(1, '@@@unk')
            return a_list

        _logger.add()
        _logger.add('counting and building dictionaries')

        token_collection = []
        char_collection = []

        sent_len_collection = []
        token_len_collection = []

        for sample in data_list:
            for tree_node in sample:
                token_collection += tree_node['token_seq']
                sent_len_collection.append(len(tree_node['token_seq']))
                for char_seq in tree_node['char_seq']:
                    char_collection += char_seq
                    token_len_collection.append(len(char_seq))

        max_sent_len = dynamic_length(sent_len_collection, 1, security=False)[0]
        max_token_len = dynamic_length(token_len_collection, 0.99, security=False)[0]

        if gene_dicts:
            # token & char
            tokenSet = dynamic_keep(token_collection, 1)
            charSet = dynamic_keep(char_collection, 1)
            if cfg.use_glove_unk_token:
                gloveData = load_glove(cfg.word_embedding_length)
                gloveTokenSet = list(gloveData.keys())
                if cfg.lower_word:
                    tokenSet = list(set([token.lower() for token in tokenSet]))  ##!!!
                    gloveTokenSet = list(set([token.lower() for token in gloveTokenSet]))  ##!!!

                # delete token from gloveTokenSet which appears in tokenSet
                for token in tokenSet:
                    try:
                        gloveTokenSet.remove(token)
                    except ValueError:
                        pass
            else:
                if cfg.lower_word:
                    tokenSet = list(set([token.lower() for token in tokenSet]))
                gloveTokenSet = []
            tokenSet = add_ept_and_unk(tokenSet)
            charSet = add_ept_and_unk(charSet)
            dicts = {'token': tokenSet, 'char': charSet, 'glove': gloveTokenSet}
        else:
            dicts = {}

        _logger.done()
        return dicts, {'sent': max_sent_len, 'token': max_token_len}
Example #14
File: dataset.py Project: zwjyyc/BiBloSA
    def load_question_classification_data(self, data_file_path, data_type):
        _logger.add()
        _logger.add('load file for %s' % data_type)
        dataset = []
        with open(data_file_path, 'r', encoding='latin-1') as file:
            for line in file:
                line_split = line.strip().split(' ')
                cls = line_split[0].split(':')[0]
                sub_cls = line_split[0]
                token = line_split[1:]
                sample = {'token': token, 'cls': cls, 'sub_cls': sub_cls}
                dataset.append(sample)
        _logger.done()
        return dataset
Example #15
    def clip_filter_data(self, data_list, data_clip_method, data_type):
        _logger.add()
        _logger.add('%s clipping data for %s...' %
                    (data_clip_method, data_type))

        for sample in data_list:
            if data_clip_method == 'no_tree':
                sample.pop('sentence1_parse')
                sample.pop('sentence2_parse')
                # sample.pop('sentence1_parse_tree')
                # sample.pop('sentence2_parse_tree')
                # sample.pop('sentence1_parse_node_list')
                # sample.pop('sentence2_parse_node_list')
                sample.pop('sentence1_binary_parse')
                sample.pop('sentence2_binary_parse')
                sample.pop('sentence1_binary_parse_tree')
                sample.pop('sentence2_binary_parse_tree')
                sample.pop('sentence1_binary_parse_node_list')
                sample.pop('sentence2_binary_parse_node_list')
                sample.pop('s1_sr_info')
                sample.pop('s2_sr_info')
                # sample.pop('s1_tree_tag')
                # sample.pop('s2_tree_tag')
            elif data_clip_method == 'no_redundancy':
                sample.pop('sentence1_parse')
                sample.pop('sentence2_parse')
                # sample.pop('sentence1_parse_tree')
                # sample.pop('sentence2_parse_tree')
                # sample.pop('sentence1_parse_node_list')
                # sample.pop('sentence2_parse_node_list')

                sample.pop('sentence1_binary_parse')
                sample.pop('sentence2_binary_parse')
                sample.pop('sentence1_binary_parse_tree')
                sample.pop('sentence2_binary_parse_tree')

                for node in sample['sentence1_binary_parse_node_list']:
                    node.children_nodes = None
                    node.leaf_node_index_seq = None

                for node in sample['sentence2_binary_parse_node_list']:
                    node.children_nodes = None
                    node.leaf_node_index_seq = None

            else:
                raise AttributeError('no data clip method named as %s' %
                                     data_clip_method)
        _logger.done()
        return data_list
Example #16
    def restore(self, sess):
        _logger.add()

        if cfg.load_path is not None:
            _logger.add('trying to restore from ckpt file %s' % cfg.load_path)
            try:
                self.saver.restore(sess, cfg.load_path)
                _logger.add('success to restore')
            except tf.errors.NotFoundError:
                _logger.add('failure to restore')
                if cfg.mode != 'train': raise FileNotFoundError('cannot find model file')
        else:
            _logger.add('No check point file')
            if cfg.mode != 'train': raise FileNotFoundError('cannot find model file')

        _logger.done()
Example #17
    def restore_part(self, sess):

        _logger.add()
        # print(cfg.ckpt_dir)
        all_variable = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.model.scope)
        trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.model.scope)
        load_var_list = trainable_vars

        variables_to_restore = [var for var in load_var_list
                                if not var.op.name.startswith(self.model.scope + '/hard_network')]
        private_saver = tf.train.Saver(variables_to_restore)

        load_path = cfg.load_path or cfg.default_pretrained_model_path

        if cfg.load_step is None:
            if load_path is None:
                _logger.add('trying to restore from dir %s' % cfg.ckpt_dir)
                latest_checkpoint_path = tf.train.latest_checkpoint(cfg.ckpt_dir)
            else:
                latest_checkpoint_path = load_path
        else:
            latest_checkpoint_path = cfg.ckpt_path+'-'+str(cfg.load_step)

        if latest_checkpoint_path is not None:
            _logger.add('trying to restore from ckpt file %s' % latest_checkpoint_path)
            try:
                private_saver.restore(sess, latest_checkpoint_path)
                _logger.add('success to restore')
            except tf.errors.NotFoundError:
                _logger.add('failure to restore')
                if cfg.mode != 'train': raise FileNotFoundError('cannot find model file')
        else:
            _logger.add('No check point file in dir %s '% cfg.ckpt_dir)
            if cfg.mode != 'train': raise FileNotFoundError('cannot find model file')

        _logger.done()
Example #18
File: dataset.py Project: zwjyyc/BiBloSA
    def divide_data_into_shared_data(dataset, data_type):
        _logger.add()
        _logger.add('dividing data into shared data: %s' % data_type)

        shared_data = []
        nn_data = []

        shared_idx = 0
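        # each qa sample records a 'shared_index' pointing back to its paragraph's entry in shared_data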
        for topic in tqdm(dataset):
            for paragraph in topic['paragraphs']:
                shared_data.append({
                    'context': paragraph['context'],
                    'context_token': paragraph['context_token'],
                    'context_token_digital': paragraph['context_token_digital'],
                })
                for qa in paragraph['qas']:
                    qa['shared_index'] = shared_idx
                    nn_data.append(qa)
                shared_idx += 1
        _logger.done()
        return shared_data, nn_data
Example #19
    def add_summary(self, summary, global_step):
        _logger.add()
        _logger.add('saving summary...')
        self.writer.add_summary(summary, global_step)
        _logger.done()
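A minimal usage sketch for the summary path (assumed context; cfg.summary_dir, merged, global_step and feed_dict are illustrative names, not taken from the project):

    writer = tf.summary.FileWriter(cfg.summary_dir, sess.graph)  # assumption: self.writer is built like this
    merged = tf.summary.merge_all()
    summary_val, step = sess.run([merged, global_step], feed_dict=feed_dict)
    writer.add_summary(summary_val, step)  # the call wrapped by add_summary above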
Example #20
File: m_mtsa.py Project: taoshen58/mtsa
    def build_network(self):
        _logger.add()
        _logger.add('building %s neural network structure...' %
                    cfg.network_type)
        tds, cds = self.tds, self.cds
        tl = self.tl
        tel, cel, cos, ocd, fh = self.tel, self.cel, self.cos, self.ocd, self.fh
        hn = self.hn
        bs, sl, ol, mc = self.bs, self.sl, self.ol, self.mc

        with tf.variable_scope('emb'):
            token_emb_mat = generate_embedding_mat(
                tds,
                tel,
                init_mat=self.token_emb_mat,
                extra_mat=self.glove_emb_mat,
                extra_trainable=self.finetune_emb,
                scope='gene_token_emb_mat')
            emb = tf.nn.embedding_lookup(token_emb_mat,
                                         self.token_seq)  # bs,sl,tel
            self.tensor_dict['emb'] = emb

        with tf.variable_scope('sent_encoding'):
            act_name = 'relu'
            seq_rep = multi_mask_tensorized_self_attn(
                emb,
                self.token_mask,
                hn=2 * hn,
                head_num=2,
                is_train=self.is_train,
                attn_keep_prob=1.,
                dense_keep_prob=cfg.dropout,
                wd=cfg.wd,
                use_direction=True,
                attn_self=False,
                use_fusion_gate=True,
                final_mask_ft=None,
                dot_activation_name='sigmoid',
                use_input_for_attn=False,
                add_layer_for_multi=True,
                activation_func_name=act_name,
                apply_act_for_v=True,
                input_hn=None,
                output_hn=None,
                accelerate=False,
                merge_var=False,
                scope='proposed_model')
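            # seq_rep: token-level representations after the masked multi-head self-attention block above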

            rep = multi_dim_souce2token_self_attn(seq_rep, self.token_mask,
                                                  's2t_self_attn', cfg.dropout,
                                                  self.is_train, cfg.wd,
                                                  act_name)

        with tf.variable_scope('output'):
            pre_logits = tf.nn.relu(
                linear([rep],
                       hn,
                       True,
                       scope='pre_logits_linear',
                       wd=cfg.wd,
                       input_keep_prob=cfg.dropout,
                       is_train=self.is_train))  # bs, hn
            logits = linear([pre_logits],
                            self.output_class,
                            False,
                            scope='get_output',
                            wd=cfg.wd,
                            input_keep_prob=cfg.dropout,
                            is_train=self.is_train)  # bs, 5
        _logger.done()
        return logits
Example #21
    def save(self, sess, global_step=None):
        _logger.add()
        _logger.add('saving model to %s' % cfg.ckpt_path)
        self.saver.save(sess, cfg.ckpt_path, global_step)
        _logger.done()
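A minimal companion sketch for the saver used above (assumed setup; max_to_keep is an illustrative value):

    saver = tf.train.Saver(max_to_keep=3)  # assumption: self.saver is a tf.train.Saver
    saver.save(sess, cfg.ckpt_path, global_step=global_step)
    # a later run can restore with saver.restore(sess, tf.train.latest_checkpoint(cfg.ckpt_dir)), as in Example #2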
Example #22
File: dataset.py Project: zwjyyc/BiBloSA
    def count_data_and_build_dict(dataset, sent_len_rate, gene_dicts=True):
        def add_ept_and_unk(a_list):
            a_list.insert(0, '@@@empty')
            a_list.insert(1, '@@@unk')
            return a_list

        _logger.add()
        _logger.add('counting and building dictionaries')

        token_collection = []
        sent_num_collection = []
        sent_len_collection = []
        question_len_collection = []

        for topic in dataset:
            for paragraph in topic['paragraphs']:
                sent_num_collection.append(len(paragraph['context_token']))
                for sent_token in paragraph['context_token']:
                    sent_len_collection.append(len(sent_token))
                    token_collection += sent_token
                for qa in paragraph['qas']:
                    question_len_collection.append(len(qa['question_token']))
                    token_collection += qa['question_token']

        _logger.done()

        max_sent_num, _ = dynamic_length(sent_num_collection, 1.)
        max_sent_len, _ = dynamic_length(sent_len_collection, sent_len_rate)
        max_question_len, _ = dynamic_length(question_len_collection, 0.995)

        if gene_dicts:
            tokenSet = dynamic_keep(token_collection, 0.995)
            if cfg.use_glove_unk_token:
                gloveData = load_glove(cfg.word_embedding_length)
                gloveTokenSet = list(gloveData.keys())
                if cfg.lower_word:
                    tokenSet = list(set([token.lower()
                                         for token in tokenSet]))  ##!!!
                    gloveTokenSet = list(
                        set([token.lower() for token in gloveTokenSet]))  ##!!!

                # delete token from gloveTokenSet which appears in tokenSet
                for token in tokenSet:
                    try:
                        gloveTokenSet.remove(token)
                    except ValueError:
                        pass
            else:
                if cfg.lower_word:
                    tokenSet = list(set([token.lower() for token in tokenSet]))
                gloveTokenSet = []
            tokenSet = add_ept_and_unk(tokenSet)
            dicts = {'token': tokenSet, 'glove': gloveTokenSet}
        else:
            dicts = {}
        _logger.done()
        return dicts, {
            'sent_num': max_sent_num,
            'sent_len': max_sent_len,
            'question': max_question_len
        }
Example #23
    def build_network(self):
        _logger.add()
        _logger.add('building %s neural network structure...' %
                    cfg.network_type)
        tds, cds = self.tds, self.cds
        tl = self.tl
        tel, cel, cos, ocd, fh = self.tel, self.cel, self.cos, self.ocd, self.fh
        hn = self.hn
        bs, sl, ol, mc = self.bs, self.sl, self.ol, self.mc

        with tf.variable_scope('emb'):
            token_emb_mat = generate_embedding_mat(
                tds,
                tel,
                init_mat=self.token_emb_mat,
                extra_mat=self.glove_emb_mat,
                extra_trainable=self.finetune_emb,
                scope='gene_token_emb_mat')
            emb = tf.nn.embedding_lookup(token_emb_mat,
                                         self.token_seq)  # bs,sl,tel
            self.tensor_dict['emb'] = emb

        with tf.variable_scope('ct_attn'):
            rep_fw = directional_attention_with_dense(
                emb,
                self.token_mask,
                'forward',
                'dir_attn_fw',
                cfg.dropout,
                self.is_train,
                cfg.wd,
                'relu',
                tensor_dict=self.tensor_dict,
                name='fw_attn')
            rep_bw = directional_attention_with_dense(
                emb,
                self.token_mask,
                'backward',
                'dir_attn_bw',
                cfg.dropout,
                self.is_train,
                cfg.wd,
                'relu',
                tensor_dict=self.tensor_dict,
                name='bw_attn')

            seq_rep = tf.concat([rep_fw, rep_bw], -1)

        with tf.variable_scope('sent_enc_attn'):
            rep = multi_dimensional_attention(seq_rep,
                                              self.token_mask,
                                              'multi_dimensional_attention',
                                              cfg.dropout,
                                              self.is_train,
                                              cfg.wd,
                                              'relu',
                                              tensor_dict=self.tensor_dict,
                                              name='attn')

        with tf.variable_scope('output'):
            pre_logits = tf.nn.relu(
                linear([rep],
                       hn,
                       True,
                       scope='pre_logits_linear',
                       wd=cfg.wd,
                       input_keep_prob=cfg.dropout,
                       is_train=self.is_train))  # bs, hn
            logits = linear([pre_logits],
                            self.output_class,
                            False,
                            scope='get_output',
                            wd=cfg.wd,
                            input_keep_prob=cfg.dropout,
                            is_train=self.is_train)  # bs, 5
        _logger.done()
        return logits