Code Example #1
    def __init__(self, word2id, gram2id, feature2id, labelmap, processor,
                 hpara, args):
        super().__init__()
        self.spec = locals()
        self.spec.pop("self")
        self.spec.pop("__class__")
        self.spec.pop('args')
        self.word2id = word2id

        self.hpara = hpara
        self.max_seq_length = self.hpara['max_seq_length']
        self.max_ngram_size = self.hpara['max_ngram_size']
        self.use_attention = self.hpara['use_attention']

        self.gram2id = gram2id
        self.feature2id = feature2id
        self.feature_processor = processor

        if self.hpara['use_attention']:
            self.source = self.hpara['source']
            self.feature_flag = self.hpara['feature_flag']
        else:
            self.source = None
            self.feature_flag = None

        self.labelmap = labelmap
        self.num_labels = len(self.labelmap) + 1

        self.bert_tokenizer = None
        self.bert = None
        self.zen_tokenizer = None
        self.zen = None
        self.zen_ngram_dict = None

        if self.hpara['use_bert']:
            if args.do_train:
                cache_dir = args.cache_dir if args.cache_dir else os.path.join(
                    str(PYTORCH_PRETRAINED_BERT_CACHE),
                    'distributed_{}'.format(args.local_rank))
                self.bert_tokenizer = BertTokenizer.from_pretrained(
                    args.bert_model, do_lower_case=self.hpara['do_lower_case'])
                self.bert = BertModel.from_pretrained(args.bert_model,
                                                      cache_dir=cache_dir)
                self.hpara['bert_tokenizer'] = self.bert_tokenizer
                self.hpara['config'] = self.bert.config
            else:
                self.bert_tokenizer = self.hpara['bert_tokenizer']
                self.bert = BertModel(self.hpara['config'])
            hidden_size = self.bert.config.hidden_size
            self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        elif self.hpara['use_zen']:
            if args.do_train:
                cache_dir = args.cache_dir if args.cache_dir else os.path.join(
                    str(zen.PYTORCH_PRETRAINED_BERT_CACHE),
                    'distributed_{}'.format(args.local_rank))
                self.zen_tokenizer = zen.BertTokenizer.from_pretrained(
                    args.bert_model, do_lower_case=self.hpara['do_lower_case'])
                self.zen_ngram_dict = zen.ZenNgramDict(
                    args.bert_model, tokenizer=self.zen_tokenizer)
                self.zen = zen.modeling.ZenModel.from_pretrained(
                    args.bert_model, cache_dir=cache_dir)
                self.hpara['zen_tokenizer'] = self.zen_tokenizer
                self.hpara['zen_ngram_dict'] = self.zen_ngram_dict
                self.hpara['config'] = self.zen.config
            else:
                self.zen_tokenizer = self.hpara['zen_tokenizer']
                self.zen_ngram_dict = self.hpara['zen_ngram_dict']
                self.zen = zen.modeling.ZenModel(self.hpara['config'])
            hidden_size = self.zen.config.hidden_size
            self.dropout = nn.Dropout(self.zen.config.hidden_dropout_prob)
        else:
            raise ValueError()

        if self.hpara['use_attention']:
            self.context_attention = Attention(hidden_size, len(self.gram2id))
            self.feature_attention = Attention(hidden_size,
                                               len(self.feature2id))
            self.classifier = nn.Linear(hidden_size * 3,
                                        self.num_labels,
                                        bias=False)
        else:
            self.context_attention = None
            self.feature_attention = None
            self.classifier = nn.Linear(hidden_size,
                                        self.num_labels,
                                        bias=False)

        self.crf = CRF(tagset_size=self.num_labels - 3, gpu=True)

        if args.do_train:
            self.spec['hpara'] = self.hpara
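
The constructor above stores its own arguments in self.spec (captured with locals()) so that an identical model can later be rebuilt from a saved specification plus a state dict (see from_spec in the next example). The following is a minimal, self-contained sketch of that pattern using a toy class; all names here are illustrative and not part of the original project.

class ToyModel:
    def __init__(self, vocab_size, hidden_size, args):
        # capture the constructor arguments so the object can be re-created later
        self.spec = locals()
        self.spec.pop('self')
        self.spec.pop('args')  # runtime-only options are not part of the spec
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    @classmethod
    def from_spec(cls, spec, args):
        # rebuild an identical instance from the saved spec
        return cls(args=args, **spec.copy())

toy = ToyModel(vocab_size=100, hidden_size=8, args=None)
clone = ToyModel.from_spec(toy.spec, args=None)
assert clone.vocab_size == toy.vocab_size
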
Code Example #2
class TwASP(nn.Module):
    def __init__(self, word2id, gram2id, feature2id, labelmap, processor,
                 hpara, args):
        super().__init__()
        self.spec = locals()
        self.spec.pop("self")
        self.spec.pop("__class__")
        self.spec.pop('args')
        self.word2id = word2id

        self.hpara = hpara
        self.max_seq_length = self.hpara['max_seq_length']
        self.max_ngram_size = self.hpara['max_ngram_size']
        self.use_attention = self.hpara['use_attention']

        self.gram2id = gram2id
        self.feature2id = feature2id
        self.feature_processor = processor

        if self.hpara['use_attention']:
            self.source = self.hpara['source']
            self.feature_flag = self.hpara['feature_flag']
        else:
            self.source = None
            self.feature_flag = None

        self.labelmap = labelmap
        self.num_labels = len(self.labelmap) + 1

        self.bert_tokenizer = None
        self.bert = None
        self.zen_tokenizer = None
        self.zen = None
        self.zen_ngram_dict = None

        if self.hpara['use_bert']:
            if args.do_train:
                cache_dir = args.cache_dir if args.cache_dir else os.path.join(
                    str(PYTORCH_PRETRAINED_BERT_CACHE),
                    'distributed_{}'.format(args.local_rank))
                self.bert_tokenizer = BertTokenizer.from_pretrained(
                    args.bert_model, do_lower_case=self.hpara['do_lower_case'])
                self.bert = BertModel.from_pretrained(args.bert_model,
                                                      cache_dir=cache_dir)
                self.hpara['bert_tokenizer'] = self.bert_tokenizer
                self.hpara['config'] = self.bert.config
            else:
                self.bert_tokenizer = self.hpara['bert_tokenizer']
                self.bert = BertModel(self.hpara['config'])
            hidden_size = self.bert.config.hidden_size
            self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        elif self.hpara['use_zen']:
            if args.do_train:
                cache_dir = args.cache_dir if args.cache_dir else os.path.join(
                    str(zen.PYTORCH_PRETRAINED_BERT_CACHE),
                    'distributed_{}'.format(args.local_rank))
                self.zen_tokenizer = zen.BertTokenizer.from_pretrained(
                    args.bert_model, do_lower_case=self.hpara['do_lower_case'])
                self.zen_ngram_dict = zen.ZenNgramDict(
                    args.bert_model, tokenizer=self.zen_tokenizer)
                self.zen = zen.modeling.ZenModel.from_pretrained(
                    args.bert_model, cache_dir=cache_dir)
                self.hpara['zen_tokenizer'] = self.zen_tokenizer
                self.hpara['zen_ngram_dict'] = self.zen_ngram_dict
                self.hpara['config'] = self.zen.config
            else:
                self.zen_tokenizer = self.hpara['zen_tokenizer']
                self.zen_ngram_dict = self.hpara['zen_ngram_dict']
                self.zen = zen.modeling.ZenModel(self.hpara['config'])
            hidden_size = self.zen.config.hidden_size
            self.dropout = nn.Dropout(self.zen.config.hidden_dropout_prob)
        else:
            raise ValueError()

        if self.hpara['use_attention']:
            self.context_attention = Attention(hidden_size, len(self.gram2id))
            self.feature_attention = Attention(hidden_size,
                                               len(self.feature2id))
            self.classifier = nn.Linear(hidden_size * 3,
                                        self.num_labels,
                                        bias=False)
        else:
            self.context_attention = None
            self.feature_attention = None
            self.classifier = nn.Linear(hidden_size,
                                        self.num_labels,
                                        bias=False)

        self.crf = CRF(tagset_size=self.num_labels - 3, gpu=True)

        if args.do_train:
            self.spec['hpara'] = self.hpara

    @staticmethod
    def init_hyper_parameters(args):
        hyper_parameters = DEFAULT_HPARA.copy()
        hyper_parameters['max_seq_length'] = args.max_seq_length
        hyper_parameters['max_ngram_size'] = args.max_ngram_size
        hyper_parameters['use_bert'] = args.use_bert
        hyper_parameters['use_zen'] = args.use_zen
        hyper_parameters['do_lower_case'] = args.do_lower_case
        hyper_parameters['use_attention'] = args.use_attention
        hyper_parameters['feature_flag'] = args.feature_flag
        hyper_parameters['source'] = args.source
        return hyper_parameters

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                labels=None,
                valid_ids=None,
                attention_mask_label=None,
                word_seq=None,
                feature_seq=None,
                word_matrix=None,
                feature_matrix=None,
                input_ngram_ids=None,
                ngram_position_matrix=None):

        if self.bert is not None:
            sequence_output, _ = self.bert(input_ids,
                                           token_type_ids,
                                           attention_mask,
                                           output_all_encoded_layers=False)
        elif self.zen is not None:
            sequence_output, _ = self.zen(
                input_ids,
                input_ngram_ids=input_ngram_ids,
                ngram_position_matrix=ngram_position_matrix,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                output_all_encoded_layers=False)
        else:
            raise ValueError()

        if self.context_attention is not None:
            word_attention = self.context_attention(word_seq, sequence_output,
                                                    word_matrix)
            feature_attention = self.feature_attention(feature_seq,
                                                       sequence_output,
                                                       feature_matrix)
            conc = torch.cat(
                [sequence_output, word_attention, feature_attention], dim=2)
        else:
            conc = sequence_output

        conc = self.dropout(conc)

        logits = self.classifier(conc)

        total_loss = self.crf.neg_log_likelihood_loss(logits, attention_mask,
                                                      labels)
        scores, tag_seq = self.crf._viterbi_decode(logits, attention_mask)

        return total_loss, tag_seq

    @property
    def model(self):
        return self.state_dict()

    @classmethod
    def from_spec(cls, spec, model, args):
        spec = spec.copy()
        res = cls(args=args, **spec)
        res.load_state_dict(model)
        return res

    def load_data(self, data_path, do_predict=False):

        if do_predict:
            lines = read_sentence(data_path)
        else:
            lines = readfile(data_path)

        flag = data_path[data_path.rfind('/') + 1:data_path.rfind('.')]

        data = []
        if self.feature_flag is None:
            for sentence, label in lines:
                data.append((sentence, label, None, None, None, None))

        elif self.feature_flag == 'pos':
            all_feature_data = self.feature_processor.read_features(data_path,
                                                                    flag=flag)
            for (sentence, label), feature_list in zip(lines,
                                                       all_feature_data):
                word_list = []
                syn_feature_list = []
                word_matching_position = []
                syn_matching_position = []
                for token_index, token in enumerate(feature_list):
                    current_token_pos = token['pos']
                    current_token = token['word']
                    current_feature = current_token + '_' + current_token_pos
                    if current_token not in self.gram2id:
                        current_token = '<UNK>'
                    if current_feature not in self.feature2id:
                        if current_token_pos not in self.feature2id:
                            current_feature = '<UNK>'
                        else:
                            current_feature = current_token_pos
                    word_list.append(current_token)
                    syn_feature_list.append(current_feature)

                    assert current_token in self.gram2id
                    assert current_feature in self.feature2id

                    char_index_list = token['char_index']
                    begin_char_index = max(char_index_list[0] - 2, 0)
                    end_char_index = min(char_index_list[-1] + 3,
                                         len(sentence))
                    for i in range(begin_char_index, end_char_index):
                        word_matching_position.append((i, token_index))
                        syn_matching_position.append((i, token_index))
                data.append((sentence, label, word_list, syn_feature_list,
                             word_matching_position, syn_matching_position))
        elif self.feature_flag == 'chunk':
            all_feature_data = self.feature_processor.read_features(data_path,
                                                                    flag=flag)
            for (sentence, label), feature_list in zip(lines,
                                                       all_feature_data):
                word_list = []
                syn_feature_list = []
                word_matching_position = []
                syn_matching_position = []
                for token_index, token in enumerate(feature_list):
                    current_token_chunk_tag = token['chunk_tags'][-1][
                        'chunk_tag']
                    assert token['chunk_tags'][-1]['height'] == 1
                    current_token = token['word']
                    current_feature = current_token + '_' + current_token_chunk_tag
                    if current_token not in self.gram2id:
                        current_token = '<UNK>'
                    if current_feature not in self.feature2id:
                        if current_token_chunk_tag not in self.feature2id:
                            current_feature = '<UNK>'
                        else:
                            current_feature = current_token_chunk_tag
                    word_list.append(current_token)
                    syn_feature_list.append(current_feature)

                    assert current_token in self.gram2id
                    assert current_feature in self.feature2id

                    token_index_range = token['chunk_tags'][-1]['range']
                    char_index_list = token['char_index']

                    for i in char_index_list:
                        for j in range(token_index_range[0],
                                       token_index_range[1]):
                            word_matching_position.append((i, j))
                            syn_matching_position.append((i, j))
                    word_matching_position = list(set(word_matching_position))
                    syn_matching_position = list(set(syn_matching_position))
                data.append((sentence, label, word_list, syn_feature_list,
                             word_matching_position, syn_matching_position))
        elif self.feature_flag == 'dep':
            all_feature_data = self.feature_processor.read_features(data_path,
                                                                    flag=flag)
            for (sentence, label), feature_list in zip(lines,
                                                       all_feature_data):
                word_list = []
                syn_feature_list = []
                word_matching_position = []
                syn_matching_position = []
                for token_index, token in enumerate(feature_list):
                    current_token_dep_tag = token['dep']
                    current_token = token['word']
                    current_feature = current_token + '_' + current_token_dep_tag
                    if current_token not in self.gram2id:
                        current_token = '<UNK>'
                    if current_feature not in self.feature2id:
                        if current_token_dep_tag not in self.feature2id:
                            current_feature = '<UNK>'
                        else:
                            current_feature = current_token_dep_tag
                    word_list.append(current_token)
                    syn_feature_list.append(current_feature)

                    assert current_token in self.gram2id
                    assert current_feature in self.feature2id

                    if token['governed_index'] < 0:
                        token_index_list = [token_index]
                        char_index_list = token['char_index']
                    else:
                        governed_index = token['governed_index']
                        token_index_list = [token_index, governed_index]
                        governed_token = feature_list[governed_index]
                        char_index_list = token['char_index'] + governed_token[
                            'char_index']

                    for i in char_index_list:
                        for j in token_index_list:
                            word_matching_position.append((i, j))
                            syn_matching_position.append((i, j))
                    word_matching_position = list(set(word_matching_position))
                    syn_matching_position = list(set(syn_matching_position))
                data.append((sentence, label, word_list, syn_feature_list,
                             word_matching_position, syn_matching_position))
        else:
            raise ValueError()

        examples = []
        for i, (sentence, label, word_list, syn_feature_list,
                word_matching_position,
                syn_matching_position) in enumerate(data):
            guid = "%s-%s" % (flag, i)
            text_a = ' '.join(sentence)
            text_b = None
            if word_list is not None:
                word = ' '.join(word_list)
                word_list_len = len(word_list)
            else:
                word = None
                word_list_len = 0
            label = label
            examples.append(
                InputExample(guid=guid,
                             text_a=text_a,
                             text_b=text_b,
                             label=label,
                             word=word,
                             syn_feature=syn_feature_list,
                             word_matrix=word_matching_position,
                             syn_matrix=syn_matching_position,
                             sent_len=len(sentence),
                             word_list_len=word_list_len))
        return examples

    def convert_examples_to_features(self, examples):

        # Dynamic cap: longest sentence length plus ~10% headroom plus two
        # extra slots for [CLS] and [SEP], never above the configured limit.
        max_seq_length = min(
            int(max([e.sent_len for e in examples]) * 1.1 + 2),
            self.max_seq_length)

        if self.use_attention:
            max_ngram_size = max(
                min(max([e.word_list_len for e in examples]),
                    self.max_ngram_size), 1)

        features = []
        tokenizer = self.bert_tokenizer if self.bert_tokenizer is not None else self.zen_tokenizer

        for (ex_index, example) in enumerate(examples):
            textlist = example.text_a.split(' ')
            labellist = example.label
            tokens = []
            labels = []
            valid = []
            label_mask = []

            for i, word in enumerate(textlist):
                token = tokenizer.tokenize(word)
                tokens.extend(token)
                label_1 = labellist[i]
                for m in range(len(token)):
                    if m == 0:
                        labels.append(label_1)
                        valid.append(1)
                        label_mask.append(1)
                    else:
                        valid.append(0)
            if len(tokens) >= max_seq_length - 1:
                tokens = tokens[0:(max_seq_length - 2)]
                labels = labels[0:(max_seq_length - 2)]
                valid = valid[0:(max_seq_length - 2)]
                label_mask = label_mask[0:(max_seq_length - 2)]

            ntokens = []
            segment_ids = []
            label_ids = []

            ntokens.append("[CLS]")
            segment_ids.append(0)

            valid.insert(0, 1)
            label_mask.insert(0, 1)
            label_ids.append(self.labelmap["[CLS]"])
            for i, token in enumerate(tokens):
                ntokens.append(token)
                segment_ids.append(0)
                if len(labels) > i:
                    if labels[i] in self.labelmap:
                        label_ids.append(self.labelmap[labels[i]])
                    else:
                        label_ids.append(self.labelmap['<UNK>'])
            ntokens.append("[SEP]")

            segment_ids.append(0)
            valid.append(1)
            label_mask.append(1)
            label_ids.append(self.labelmap["[SEP]"])

            input_ids = tokenizer.convert_tokens_to_ids(ntokens)
            input_mask = [1] * len(input_ids)
            label_mask = [1] * len(label_ids)
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
                label_ids.append(0)
                valid.append(1)
                label_mask.append(0)
            while len(label_ids) < max_seq_length:
                label_ids.append(0)
                label_mask.append(0)
            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
            assert len(label_ids) == max_seq_length
            assert len(valid) == max_seq_length
            assert len(label_mask) == max_seq_length

            if self.use_attention:
                wordlist = example.word
                wordlist = wordlist.split(' ') if len(wordlist) > 0 else []
                syn_features = example.syn_feature
                word_matching_position = example.word_matrix
                syn_matching_position = example.syn_matrix
                word_ids = []
                feature_ids = []
                word_matching_matrix = np.zeros(
                    (max_seq_length, max_ngram_size), dtype=int)
                syn_matching_matrix = np.zeros(
                    (max_seq_length, max_ngram_size), dtype=int)

                if len(wordlist) > max_ngram_size:
                    wordlist = wordlist[:max_ngram_size]
                    syn_features = syn_features[:max_ngram_size]

                for word in wordlist:
                    if word == '':
                        continue
                    try:
                        word_ids.append(self.gram2id[word])
                    except KeyError:
                        print(word)
                        print(wordlist)
                        print(textlist)
                        raise KeyError()
                for feature in syn_features:
                    feature_ids.append(self.feature2id[feature])

                while len(word_ids) < max_ngram_size:
                    word_ids.append(0)
                    feature_ids.append(0)
                for position in word_matching_position:
                    char_p = position[0] + 1
                    word_p = position[1]
                    if char_p > max_seq_length - 2 or word_p > max_ngram_size - 1:
                        continue
                    else:
                        word_matching_matrix[char_p][word_p] = 1

                for position in syn_matching_position:
                    char_p = position[0] + 1
                    word_p = position[1]
                    if char_p > max_seq_length - 2 or word_p > max_ngram_size - 1:
                        continue
                    else:
                        syn_matching_matrix[char_p][word_p] = 1

                assert len(word_ids) == max_ngram_size
                assert len(feature_ids) == max_ngram_size
            else:
                word_ids = None
                feature_ids = None
                word_matching_matrix = None
                syn_matching_matrix = None

            if self.zen_ngram_dict is not None:
                ngram_matches = []
                # Scan n-gram segments of length 2 to 7 and check whether each one is in the ZEN n-gram dictionary
                for p in range(2, 8):
                    for q in range(0, len(tokens) - p + 1):
                        character_segment = tokens[q:q + p]
                        # q is the starting position of the ngram
                        # p is the length of the current ngram
                        character_segment = tuple(character_segment)
                        if character_segment in self.zen_ngram_dict.ngram_to_id_dict:
                            ngram_index = self.zen_ngram_dict.ngram_to_id_dict[
                                character_segment]
                            ngram_matches.append(
                                [ngram_index, q, p, character_segment])

                random.shuffle(ngram_matches)

                max_ngram_in_seq_proportion = math.ceil(
                    (len(tokens) / max_seq_length) *
                    self.zen_ngram_dict.max_ngram_in_seq)
                if len(ngram_matches) > max_ngram_in_seq_proportion:
                    ngram_matches = ngram_matches[:max_ngram_in_seq_proportion]

                ngram_ids = [ngram[0] for ngram in ngram_matches]
                ngram_positions = [ngram[1] for ngram in ngram_matches]
                ngram_lengths = [ngram[2] for ngram in ngram_matches]
                ngram_tuples = [ngram[3] for ngram in ngram_matches]
                ngram_seg_ids = [
                    0 if position < (len(tokens) + 2) else 1
                    for position in ngram_positions
                ]

                ngram_mask_array = np.zeros(
                    self.zen_ngram_dict.max_ngram_in_seq, dtype=bool)
                ngram_mask_array[:len(ngram_ids)] = 1

                # mark, for each matched n-gram, the token positions it covers
                ngram_positions_matrix = np.zeros(
                    shape=(max_seq_length,
                           self.zen_ngram_dict.max_ngram_in_seq),
                    dtype=np.int32)
                for i in range(len(ngram_ids)):
                    ngram_positions_matrix[
                        ngram_positions[i]:ngram_positions[i] +
                        ngram_lengths[i], i] = 1.0

                # Zero-pad up to the max ngram in seq length.
                padding = [0] * (self.zen_ngram_dict.max_ngram_in_seq -
                                 len(ngram_ids))
                ngram_ids += padding
                ngram_lengths += padding
                ngram_seg_ids += padding
            else:
                ngram_ids = None
                ngram_positions_matrix = None
                ngram_lengths = None
                ngram_tuples = None
                ngram_seg_ids = None
                ngram_mask_array = None

            features.append(
                InputFeatures(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              label_id=label_ids,
                              valid_ids=valid,
                              label_mask=label_mask,
                              word_ids=word_ids,
                              syn_feature_ids=feature_ids,
                              word_matching_matrix=word_matching_matrix,
                              syn_matching_matrix=syn_matching_matrix,
                              ngram_ids=ngram_ids,
                              ngram_positions=ngram_positions_matrix,
                              ngram_lengths=ngram_lengths,
                              ngram_tuples=ngram_tuples,
                              ngram_seg_ids=ngram_seg_ids,
                              ngram_masks=ngram_mask_array))
        return features

    def feature2input(self, device, feature):
        all_input_ids = torch.tensor([f.input_ids for f in feature],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in feature],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in feature],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in feature],
                                     dtype=torch.long)
        all_valid_ids = torch.tensor([f.valid_ids for f in feature],
                                     dtype=torch.long)
        all_lmask_ids = torch.tensor([f.label_mask for f in feature],
                                     dtype=torch.long)

        input_ids = all_input_ids.to(device)
        input_mask = all_input_mask.to(device)
        segment_ids = all_segment_ids.to(device)
        label_ids = all_label_ids.to(device)
        valid_ids = all_valid_ids.to(device)
        l_mask = all_lmask_ids.to(device)
        if self.hpara['use_attention']:
            all_word_ids = torch.tensor([f.word_ids for f in feature],
                                        dtype=torch.long)
            all_feature_ids = torch.tensor(
                [f.syn_feature_ids for f in feature], dtype=torch.long)
            all_word_matching_matrix = torch.tensor(
                [f.word_matching_matrix for f in feature], dtype=torch.float)

            word_ids = all_word_ids.to(device)
            feature_ids = all_feature_ids.to(device)
            word_matching_matrix = all_word_matching_matrix.to(device)
        else:
            word_ids = None
            feature_ids = None
            word_matching_matrix = None
        if self.hpara['use_zen']:
            all_ngram_ids = torch.tensor([f.ngram_ids for f in feature],
                                         dtype=torch.long)
            all_ngram_positions = torch.tensor(
                [f.ngram_positions for f in feature], dtype=torch.long)
            # all_ngram_lengths = torch.tensor([f.ngram_lengths for f in train_features], dtype=torch.long)
            # all_ngram_seg_ids = torch.tensor([f.ngram_seg_ids for f in train_features], dtype=torch.long)
            # all_ngram_masks = torch.tensor([f.ngram_masks for f in train_features], dtype=torch.long)

            ngram_ids = all_ngram_ids.to(device)
            ngram_positions = all_ngram_positions.to(device)
        else:
            ngram_ids = None
            ngram_positions = None
        return feature_ids, input_ids, input_mask, l_mask, label_ids, ngram_ids, ngram_positions, segment_ids, valid_ids, word_ids, word_matching_matrix
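
In convert_examples_to_features above, the (character index, n-gram index) pairs produced by load_data are turned into binary matrices that the attention modules consume. Below is a small, runnable sketch of just that conversion (the function name and sizes are made up for illustration); note the +1 offset that accounts for the [CLS] token prepended to every sequence.

import numpy as np

def build_matching_matrix(positions, max_seq_length, max_ngram_size):
    # positions is a list of (character_index, ngram_index) pairs
    matrix = np.zeros((max_seq_length, max_ngram_size), dtype=int)
    for char_index, ngram_index in positions:
        char_p = char_index + 1  # shift by one for the [CLS] token
        if char_p > max_seq_length - 2 or ngram_index > max_ngram_size - 1:
            continue  # would land on [SEP] / padding or past the n-gram window
        matrix[char_p][ngram_index] = 1
    return matrix

print(build_matching_matrix([(0, 0), (1, 0), (2, 1)],
                            max_seq_length=6, max_ngram_size=3))
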
Code Example #3
    def __init__(self,
                 word2id,
                 label2id,
                 hpara,
                 model_path,
                 department2id=None,
                 disease2id=None):
        super().__init__()

        self.word2id = word2id
        self.department2id = None
        self.disease2id = None
        self.label2id = label2id
        self.party2id = None
        self.hpara = hpara
        self.num_labels = len(self.label2id)
        self.max_seq_length = self.hpara['max_seq_length']
        self.use_memory = self.hpara['use_memory']
        self.use_department = self.hpara['use_department']
        self.use_party = self.hpara['use_party']
        self.use_disease = self.hpara['use_disease']
        self.decoder = self.hpara['decoder']
        self.lstm_hidden_size = self.hpara['lstm_hidden_size']
        self.max_dialog_length = self.hpara['max_dialog_length']

        self.bert_tokenizer = None
        self.bert = None
        self.zen_tokenizer = None
        self.zen = None
        self.zen_ngram_dict = None

        if self.hpara['use_bert']:
            self.bert_tokenizer = BertTokenizer.from_pretrained(
                model_path, do_lower_case=self.hpara['do_lower_case'])
            self.bert = BertModel.from_pretrained(model_path, cache_dir='')
            hidden_size = self.bert.config.hidden_size
            self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        elif self.hpara['use_zen']:
            self.zen_tokenizer = zen.BertTokenizer.from_pretrained(
                model_path, do_lower_case=self.hpara['do_lower_case'])
            self.zen_ngram_dict = zen.ZenNgramDict(
                model_path, tokenizer=self.zen_tokenizer)
            self.zen = zen.modeling.ZenModel.from_pretrained(model_path,
                                                             cache_dir='')
            hidden_size = self.zen.config.hidden_size
            self.dropout = nn.Dropout(self.zen.config.hidden_dropout_prob)
        else:
            raise ValueError()

        ori_hidden_size = hidden_size

        if self.use_memory:
            self.memory = Memory(hidden_size, len(word2id))
            hidden_size = hidden_size * 2
        else:
            self.memory = None

        if self.use_party:
            self.party_embedding = nn.Embedding(5, ori_hidden_size)
            hidden_size += ori_hidden_size
            self.party2id = {'<PAD>': 0, '<UNK>': 1, 'P': 2, 'D': 3}
        else:
            self.party_embedding = None

        utterance_hidden_size = hidden_size
        if self.hpara['utterance_encoder'] == 'LSTM':
            self.utterance_encoder = nn.LSTM(input_size=hidden_size,
                                             hidden_size=self.lstm_hidden_size,
                                             bidirectional=False,
                                             batch_first=True)
            utterance_hidden_size = self.lstm_hidden_size
        elif self.hpara['utterance_encoder'] == 'biLSTM':
            self.utterance_encoder = nn.LSTM(input_size=hidden_size,
                                             hidden_size=self.lstm_hidden_size,
                                             bidirectional=True,
                                             batch_first=True)
            utterance_hidden_size = self.lstm_hidden_size * 2
        else:
            self.utterance_encoder = None

        if self.use_department:
            self.department_embedding = nn.Embedding(len(department2id),
                                                     utterance_hidden_size)
            self.department2id = department2id
        else:
            self.department_embedding = None

        if self.use_disease:
            self.disease_embedding = nn.Embedding(len(disease2id),
                                                  utterance_hidden_size)
            self.disease2id = disease2id
        else:
            self.disease_embedding = None

        if self.use_department and self.use_disease:
            utterance_hidden_size = utterance_hidden_size * 3
        elif self.use_department or self.use_disease:
            utterance_hidden_size = utterance_hidden_size * 2

        self.classifier = nn.Linear(utterance_hidden_size, self.num_labels)

        if self.decoder == 'softmax':
            self.loss_fct = CrossEntropyLoss(ignore_index=0)
        elif self.decoder == 'crf':
            self.crf = CRF(self.num_labels, batch_first=True)
        else:
            raise ValueError()
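
The width of the final classifier in the constructor above depends on which optional components are enabled: the memory doubles the encoder hidden size, the party embedding adds another encoder-sized slice, an LSTM or biLSTM utterance encoder replaces the width with its own output size, and enabling department and/or disease embeddings multiplies the utterance width (they appear to be concatenated in the forward pass). A small sketch of that bookkeeping with example sizes (the function name and the numbers are illustrative):

def classifier_input_size(encoder_hidden, use_memory, use_party,
                          utterance_encoder, lstm_hidden,
                          use_department, use_disease):
    hidden = encoder_hidden
    if use_memory:
        hidden *= 2               # memory branch doubles the width
    if use_party:
        hidden += encoder_hidden  # party embedding keeps the original encoder width
    if utterance_encoder == 'LSTM':
        hidden = lstm_hidden
    elif utterance_encoder == 'biLSTM':
        hidden = lstm_hidden * 2
    if use_department and use_disease:
        hidden *= 3
    elif use_department or use_disease:
        hidden *= 2
    return hidden

# e.g. BERT-base with memory, party and department information, biLSTM encoder
print(classifier_input_size(768, True, True, 'biLSTM', 200, True, False))  # 800
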
Code Example #4
File: wmseg_model.py  Project: yz-liu/WMSeg
class WMSeg(nn.Module):
    def __init__(self, word2id, gram2id, labelmap, hpara, args):
        super().__init__()
        self.spec = locals()
        self.spec.pop("self")
        self.spec.pop("__class__")
        self.spec.pop('args')

        self.word2id = word2id
        self.gram2id = gram2id
        self.labelmap = labelmap
        self.hpara = hpara
        self.num_labels = len(self.labelmap) + 1
        self.max_seq_length = self.hpara['max_seq_length']
        self.max_ngram_size = self.hpara['max_ngram_size']
        self.max_ngram_length = self.hpara['max_ngram_length']

        self.bert_tokenizer = None
        self.bert = None
        self.zen_tokenizer = None
        self.zen = None
        self.zen_ngram_dict = None

        if self.hpara['use_bert']:
            if args.do_train:
                cache_dir = args.cache_dir if args.cache_dir else os.path.join(
                    str(PYTORCH_PRETRAINED_BERT_CACHE),
                    'distributed_{}'.format(args.local_rank))
                self.bert_tokenizer = BertTokenizer.from_pretrained(
                    args.bert_model, do_lower_case=self.hpara['do_lower_case'])
                self.bert = BertModel.from_pretrained(args.bert_model,
                                                      cache_dir=cache_dir)
                self.hpara['bert_tokenizer'] = self.bert_tokenizer
                self.hpara['config'] = self.bert.config
            else:
                self.bert_tokenizer = self.hpara['bert_tokenizer']
                self.bert = BertModel(self.hpara['config'])
            hidden_size = self.bert.config.hidden_size
            self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        elif self.hpara['use_zen']:
            if args.do_train:
                cache_dir = args.cache_dir if args.cache_dir else os.path.join(
                    str(zen.PYTORCH_PRETRAINED_BERT_CACHE),
                    'distributed_{}'.format(args.local_rank))
                self.zen_tokenizer = zen.BertTokenizer.from_pretrained(
                    args.bert_model, do_lower_case=self.hpara['do_lower_case'])
                self.zen_ngram_dict = zen.ZenNgramDict(
                    args.bert_model, tokenizer=self.zen_tokenizer)
                self.zen = zen.modeling.ZenModel.from_pretrained(
                    args.bert_model, cache_dir=cache_dir)
                self.hpara['zen_tokenizer'] = self.zen_tokenizer
                self.hpara['zen_ngram_dict'] = self.zen_ngram_dict
                self.hpara['config'] = self.zen.config
            else:
                self.zen_tokenizer = self.hpara['zen_tokenizer']
                self.zen_ngram_dict = self.hpara['zen_ngram_dict']
                self.zen = zen.modeling.ZenModel(self.hpara['config'])
            hidden_size = self.zen.config.hidden_size
            self.dropout = nn.Dropout(self.zen.config.hidden_dropout_prob)
        else:
            raise ValueError()

        if self.hpara['use_memory']:
            self.kv_memory = WordKVMN(hidden_size, len(gram2id))
        else:
            self.kv_memory = None

        self.classifier = nn.Linear(hidden_size, self.num_labels, bias=False)

        if self.hpara['decoder'] == 'crf':
            self.crf = CRF(tagset_size=self.num_labels - 3, gpu=True)
        else:
            self.crf = None

        if args.do_train:
            self.spec['hpara'] = self.hpara

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                labels=None,
                valid_ids=None,
                attention_mask_label=None,
                word_seq=None,
                label_value_matrix=None,
                word_mask=None,
                input_ngram_ids=None,
                ngram_position_matrix=None):

        if self.bert is not None:
            sequence_output, _ = self.bert(input_ids,
                                           token_type_ids,
                                           attention_mask,
                                           output_all_encoded_layers=False)
        elif self.zen is not None:
            sequence_output, _ = self.zen(
                input_ids,
                input_ngram_ids=input_ngram_ids,
                ngram_position_matrix=ngram_position_matrix,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                output_all_encoded_layers=False)
        else:
            raise ValueError()

        if self.kv_memory is not None:
            sequence_output = self.kv_memory(word_seq, sequence_output,
                                             label_value_matrix, word_mask)

        sequence_output = self.dropout(sequence_output)

        logits = self.classifier(sequence_output)

        if self.crf is not None:
            # crf = CRF(tagset_size=number_of_labels+1, gpu=True)
            total_loss = self.crf.neg_log_likelihood_loss(
                logits, attention_mask, labels)
            scores, tag_seq = self.crf._viterbi_decode(logits, attention_mask)
        else:
            loss_fct = CrossEntropyLoss(ignore_index=0)
            total_loss = loss_fct(logits.view(-1, self.num_labels),
                                  labels.view(-1))
            tag_seq = torch.argmax(F.log_softmax(logits, dim=2), dim=2)

        return total_loss, tag_seq

    @staticmethod
    def init_hyper_parameters(args):
        hyper_parameters = DEFAULT_HPARA.copy()
        hyper_parameters['max_seq_length'] = args.max_seq_length
        hyper_parameters['max_ngram_size'] = args.max_ngram_size
        hyper_parameters['max_ngram_length'] = args.max_ngram_length
        hyper_parameters['use_bert'] = args.use_bert
        hyper_parameters['use_zen'] = args.use_zen
        hyper_parameters['do_lower_case'] = args.do_lower_case
        hyper_parameters['use_memory'] = args.use_memory
        hyper_parameters['decoder'] = args.decoder
        return hyper_parameters

    @property
    def model(self):
        return self.state_dict()

    @classmethod
    def from_spec(cls, spec, model, args):
        spec = spec.copy()
        res = cls(args=args, **spec)
        res.load_state_dict(model)
        return res

    def load_data(self, data_path, do_predict=False):

        if not do_predict:
            flag = data_path[data_path.rfind('/') + 1:data_path.rfind('.')]
            lines = readfile(data_path, flag=flag)
        else:
            flag = 'predict'
            lines = readsentence(data_path)

        data = []
        for sentence, label in lines:
            if self.kv_memory is not None:
                word_list = []
                matching_position = []
                for i in range(len(sentence)):
                    for j in range(self.max_ngram_length):
                        if i + j > len(sentence):
                            break
                        word = ''.join(sentence[i:i + j + 1])
                        if word in self.gram2id:
                            try:
                                index = word_list.index(word)
                            except ValueError:
                                word_list.append(word)
                                index = len(word_list) - 1
                            word_len = len(word)
                            for k in range(j + 1):
                                if word_len == 1:
                                    l = 'S'
                                elif k == 0:
                                    l = 'B'
                                elif k == j:
                                    l = 'E'
                                else:
                                    l = 'I'
                                matching_position.append((i + k, index, l))
            else:
                word_list = None
                matching_position = None
            data.append((sentence, label, word_list, matching_position))

        examples = []
        for i, (sentence, label, word_list,
                matching_position) in enumerate(data):
            guid = "%s-%s" % (flag, i)
            text_a = ' '.join(sentence)
            text_b = None
            if word_list is not None:
                word = ' '.join(word_list)
            else:
                word = None
            label = label
            examples.append(
                InputExample(guid=guid,
                             text_a=text_a,
                             text_b=text_b,
                             label=label,
                             word=word,
                             matrix=matching_position))
        return examples

    def convert_examples_to_features(self, examples):

        max_seq_length = min(
            int(max([len(e.text_a.split(' ')) for e in examples]) * 1.1 + 2),
            self.max_seq_length)

        if self.kv_memory is not None:
            max_word_size = max(
                min(max([len(e.word.split(' ')) for e in examples]),
                    self.max_ngram_size), 1)

        features = []

        tokenizer = self.bert_tokenizer if self.bert_tokenizer is not None else self.zen_tokenizer

        for (ex_index, example) in enumerate(examples):
            textlist = example.text_a.split(' ')
            labellist = example.label
            tokens = []
            labels = []
            valid = []
            label_mask = []

            for i, word in enumerate(textlist):
                token = tokenizer.tokenize(word)
                tokens.extend(token)
                label_1 = labellist[i]
                for m in range(len(token)):
                    if m == 0:
                        valid.append(1)
                        labels.append(label_1)
                        label_mask.append(1)
                    else:
                        valid.append(0)

            if len(tokens) >= max_seq_length - 1:
                tokens = tokens[0:(max_seq_length - 2)]
                labels = labels[0:(max_seq_length - 2)]
                valid = valid[0:(max_seq_length - 2)]
                label_mask = label_mask[0:(max_seq_length - 2)]

            ntokens = []
            segment_ids = []
            label_ids = []

            ntokens.append("[CLS]")
            segment_ids.append(0)

            valid.insert(0, 1)
            label_mask.insert(0, 1)
            label_ids.append(self.labelmap["[CLS]"])
            for i, token in enumerate(tokens):
                ntokens.append(token)
                segment_ids.append(0)
                if len(labels) > i:
                    label_ids.append(self.labelmap[labels[i]])
            ntokens.append("[SEP]")

            segment_ids.append(0)
            valid.append(1)
            label_mask.append(1)
            label_ids.append(self.labelmap["[SEP]"])

            input_ids = tokenizer.convert_tokens_to_ids(ntokens)
            input_mask = [1] * len(input_ids)
            label_mask = [1] * len(label_ids)
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
                label_ids.append(0)
                valid.append(1)
                label_mask.append(0)
            while len(label_ids) < max_seq_length:
                label_ids.append(0)
                label_mask.append(0)
            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
            assert len(label_ids) == max_seq_length
            assert len(valid) == max_seq_length
            assert len(label_mask) == max_seq_length

            if self.kv_memory is not None:
                wordlist = example.word
                wordlist = wordlist.split(' ') if len(wordlist) > 0 else []
                matching_position = example.matrix
                word_ids = []
                matching_matrix = np.zeros((max_seq_length, max_word_size),
                                           dtype=int)
                if len(wordlist) > max_word_size:
                    wordlist = wordlist[:max_word_size]
                for word in wordlist:
                    try:
                        word_ids.append(self.gram2id[word])
                    except KeyError:
                        print(word)
                        print(wordlist)
                        print(textlist)
                        raise KeyError()
                while len(word_ids) < max_word_size:
                    word_ids.append(0)
                for position in matching_position:
                    char_p = position[0] + 1
                    word_p = position[1]
                    if char_p > max_seq_length - 2 or word_p > max_word_size - 1:
                        continue
                    else:
                        matching_matrix[char_p][word_p] = self.labelmap[
                            position[2]]

                assert len(word_ids) == max_word_size
            else:
                word_ids = None
                matching_matrix = None

            if self.zen_ngram_dict is not None:
                ngram_matches = []
                #  Scan n-gram segments of length 2 to 7 and check whether each one is in the ZEN n-gram dictionary
                for p in range(2, 8):
                    for q in range(0, len(tokens) - p + 1):
                        character_segment = tokens[q:q + p]
                        # q is the starting position of the ngram
                        # p is the length of the current ngram
                        character_segment = tuple(character_segment)
                        if character_segment in self.zen_ngram_dict.ngram_to_id_dict:
                            ngram_index = self.zen_ngram_dict.ngram_to_id_dict[
                                character_segment]
                            ngram_matches.append(
                                [ngram_index, q, p, character_segment])

                # random.shuffle(ngram_matches)
                ngram_matches = sorted(ngram_matches, key=lambda s: s[0])

                max_ngram_in_seq_proportion = math.ceil(
                    (len(tokens) / max_seq_length) *
                    self.zen_ngram_dict.max_ngram_in_seq)
                if len(ngram_matches) > max_ngram_in_seq_proportion:
                    ngram_matches = ngram_matches[:max_ngram_in_seq_proportion]

                ngram_ids = [ngram[0] for ngram in ngram_matches]
                ngram_positions = [ngram[1] for ngram in ngram_matches]
                ngram_lengths = [ngram[2] for ngram in ngram_matches]
                ngram_tuples = [ngram[3] for ngram in ngram_matches]
                ngram_seg_ids = [
                    0 if position < (len(tokens) + 2) else 1
                    for position in ngram_positions
                ]

                ngram_mask_array = np.zeros(
                    self.zen_ngram_dict.max_ngram_in_seq, dtype=bool)
                ngram_mask_array[:len(ngram_ids)] = 1

                # mark, for each matched n-gram, the token positions it covers
                ngram_positions_matrix = np.zeros(
                    shape=(max_seq_length,
                           self.zen_ngram_dict.max_ngram_in_seq),
                    dtype=np.int32)
                for i in range(len(ngram_ids)):
                    ngram_positions_matrix[
                        ngram_positions[i]:ngram_positions[i] +
                        ngram_lengths[i], i] = 1.0

                # Zero-pad up to the max ngram in seq length.
                padding = [0] * (self.zen_ngram_dict.max_ngram_in_seq -
                                 len(ngram_ids))
                ngram_ids += padding
                ngram_lengths += padding
                ngram_seg_ids += padding
            else:
                ngram_ids = None
                ngram_positions_matrix = None
                ngram_lengths = None
                ngram_tuples = None
                ngram_seg_ids = None
                ngram_mask_array = None

            features.append(
                InputFeatures(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              label_id=label_ids,
                              valid_ids=valid,
                              label_mask=label_mask,
                              word_ids=word_ids,
                              matching_matrix=matching_matrix,
                              ngram_ids=ngram_ids,
                              ngram_positions=ngram_positions_matrix,
                              ngram_lengths=ngram_lengths,
                              ngram_tuples=ngram_tuples,
                              ngram_seg_ids=ngram_seg_ids,
                              ngram_masks=ngram_mask_array))
        return features

    def feature2input(self, device, feature):
        all_input_ids = torch.tensor([f.input_ids for f in feature],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in feature],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in feature],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in feature],
                                     dtype=torch.long)
        all_valid_ids = torch.tensor([f.valid_ids for f in feature],
                                     dtype=torch.long)
        all_lmask_ids = torch.tensor([f.label_mask for f in feature],
                                     dtype=torch.long)
        input_ids = all_input_ids.to(device)
        input_mask = all_input_mask.to(device)
        segment_ids = all_segment_ids.to(device)
        label_ids = all_label_ids.to(device)
        valid_ids = all_valid_ids.to(device)
        l_mask = all_lmask_ids.to(device)
        if self.hpara['use_memory']:
            all_word_ids = torch.tensor([f.word_ids for f in feature],
                                        dtype=torch.long)
            all_matching_matrix = torch.tensor(
                [f.matching_matrix for f in feature], dtype=torch.long)
            all_word_mask = torch.tensor([f.matching_matrix for f in feature],
                                         dtype=torch.float)

            word_ids = all_word_ids.to(device)
            matching_matrix = all_matching_matrix.to(device)
            word_mask = all_word_mask.to(device)
        else:
            word_ids = None
            matching_matrix = None
            word_mask = None
        if self.hpara['use_zen']:
            all_ngram_ids = torch.tensor([f.ngram_ids for f in feature],
                                         dtype=torch.long)
            all_ngram_positions = torch.tensor(
                [f.ngram_positions for f in feature], dtype=torch.long)
            # all_ngram_lengths = torch.tensor([f.ngram_lengths for f in train_features], dtype=torch.long)
            # all_ngram_seg_ids = torch.tensor([f.ngram_seg_ids for f in train_features], dtype=torch.long)
            # all_ngram_masks = torch.tensor([f.ngram_masks for f in train_features], dtype=torch.long)

            ngram_ids = all_ngram_ids.to(device)
            ngram_positions = all_ngram_positions.to(device)
        else:
            ngram_ids = None
            ngram_positions = None
        return input_ids, input_mask, l_mask, label_ids, matching_matrix, ngram_ids, ngram_positions, segment_ids, valid_ids, word_ids, word_mask
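
WMSeg.load_data above builds the key-value memory input by scanning every substring of up to max_ngram_length characters against the n-gram lexicon and tagging each covered character with B/I/E/S. A compact, runnable sketch of that matching step with a toy lexicon (the function name is illustrative and the boundary check is slightly simplified):

def match_ngrams(sentence, lexicon, max_ngram_length=5):
    # sentence is a list of characters; lexicon is a set of known words/n-grams
    word_list, matching_position = [], []
    for i in range(len(sentence)):
        for j in range(max_ngram_length):
            if i + j + 1 > len(sentence):
                break
            word = ''.join(sentence[i:i + j + 1])
            if word in lexicon:
                if word not in word_list:
                    word_list.append(word)
                index = word_list.index(word)
                for k in range(j + 1):
                    # single-character word -> S, otherwise B / I / E by position
                    tag = 'S' if j == 0 else 'B' if k == 0 else 'E' if k == j else 'I'
                    matching_position.append((i + k, index, tag))
    return word_list, matching_position

print(match_ngrams(list('abc'), {'ab', 'c'}))
# (['ab', 'c'], [(0, 0, 'B'), (1, 0, 'E'), (2, 1, 'S')])
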
Code Example #5
class HET(nn.Module):
    def __init__(self,
                 word2id,
                 label2id,
                 hpara,
                 model_path,
                 department2id=None,
                 disease2id=None):
        super().__init__()

        self.word2id = word2id
        self.department2id = None
        self.disease2id = None
        self.label2id = label2id
        self.party2id = None
        self.hpara = hpara
        self.num_labels = len(self.label2id)
        self.max_seq_length = self.hpara['max_seq_length']
        self.use_memory = self.hpara['use_memory']
        self.use_department = self.hpara['use_department']
        self.use_party = self.hpara['use_party']
        self.use_disease = self.hpara['use_disease']
        self.decoder = self.hpara['decoder']
        self.lstm_hidden_size = self.hpara['lstm_hidden_size']
        self.max_dialog_length = self.hpara['max_dialog_length']

        self.bert_tokenizer = None
        self.bert = None
        self.zen_tokenizer = None
        self.zen = None
        self.zen_ngram_dict = None

        if self.hpara['use_bert']:
            self.bert_tokenizer = BertTokenizer.from_pretrained(
                model_path, do_lower_case=self.hpara['do_lower_case'])
            self.bert = BertModel.from_pretrained(model_path, cache_dir='')
            hidden_size = self.bert.config.hidden_size
            self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        elif self.hpara['use_zen']:
            self.zen_tokenizer = zen.BertTokenizer.from_pretrained(
                model_path, do_lower_case=self.hpara['do_lower_case'])
            self.zen_ngram_dict = zen.ZenNgramDict(
                model_path, tokenizer=self.zen_tokenizer)
            self.zen = zen.modeling.ZenModel.from_pretrained(model_path,
                                                             cache_dir='')
            hidden_size = self.zen.config.hidden_size
            self.dropout = nn.Dropout(self.zen.config.hidden_dropout_prob)
        else:
            raise ValueError()

        ori_hidden_size = hidden_size

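        # Optional components widen the utterance representation: the memory module doubles hidden_size,
        # and the party (speaker) embedding adds another ori_hidden_size on top.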
        if self.use_memory:
            self.memory = Memory(hidden_size, len(word2id))
            hidden_size = hidden_size * 2
        else:
            self.memory = None

        if self.use_party:
            self.party_embedding = nn.Embedding(5, ori_hidden_size)
            hidden_size += ori_hidden_size
            self.party2id = {'<PAD>': 0, '<UNK>': 1, 'P': 2, 'D': 3}
        else:
            self.party_embedding = None

        utterance_hidden_size = hidden_size
        if self.hpara['utterance_encoder'] == 'LSTM':
            self.utterance_encoder = nn.LSTM(input_size=hidden_size,
                                             hidden_size=self.lstm_hidden_size,
                                             bidirectional=False,
                                             batch_first=True)
            utterance_hidden_size = self.lstm_hidden_size
        elif self.hpara['utterance_encoder'] == 'biLSTM':
            self.utterance_encoder = nn.LSTM(input_size=hidden_size,
                                             hidden_size=self.lstm_hidden_size,
                                             bidirectional=True,
                                             batch_first=True)
            utterance_hidden_size = self.lstm_hidden_size * 2
        else:
            self.utterance_encoder = None

        if self.use_department:
            self.department_embedding = nn.Embedding(len(department2id),
                                                     utterance_hidden_size)
            self.department2id = department2id
        else:
            self.department_embedding = None

        if self.use_disease:
            self.disease_embedding = nn.Embedding(len(disease2id),
                                                  utterance_hidden_size)
            self.disease2id = disease2id
        else:
            self.disease_embedding = None

        if self.use_department and self.use_disease:
            utterance_hidden_size = utterance_hidden_size * 3
        elif self.use_department or self.use_disease:
            utterance_hidden_size = utterance_hidden_size * 2

        self.classifier = nn.Linear(utterance_hidden_size, self.num_labels)

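        # Decoder over the utterance-level labels: plain softmax with cross-entropy, or a linear-chain CRF.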
        if self.decoder == 'softmax':
            self.loss_fct = CrossEntropyLoss(ignore_index=0)
        elif self.decoder == 'crf':
            self.crf = CRF(self.num_labels, batch_first=True)
        else:
            raise ValueError()

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                labels=None,
                valid_ids=None,
                attention_mask_label=None,
                label_mask=None,
                party_mask=None,
                party_ids=None,
                department_ids=None,
                disease_ids=None,
                input_ngram_ids=None,
                ngram_position_matrix=None):

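        # input_ids / token_type_ids / attention_mask: (batch_size, dialog_length, utterance_length)
        # labels, label_mask, party_ids, department_ids, disease_ids: (batch_size, dialog_length)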
        batch_size = input_ids.shape[0]
        dialog_length = input_ids.shape[1]
        utterance_length = input_ids.shape[2]

        input_ids = input_ids.view(batch_size * dialog_length,
                                   utterance_length)
        token_type_ids = token_type_ids.view(batch_size * dialog_length,
                                             utterance_length)
        attention_mask = attention_mask.view(batch_size * dialog_length,
                                             utterance_length)

        if self.bert is not None:
            sequence_output, _ = self.bert(input_ids,
                                           token_type_ids,
                                           attention_mask,
                                           output_all_encoded_layers=False)
        elif self.zen is not None:
            ngram_position_matrix = ngram_position_matrix.view(
                batch_size * dialog_length, utterance_length, -1)
            input_ngram_ids = input_ngram_ids.view(batch_size * dialog_length,
                                                   -1)
            sequence_output, _ = self.zen(
                input_ids,
                input_ngram_ids=input_ngram_ids,
                ngram_position_matrix=ngram_position_matrix,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                output_all_encoded_layers=False)
        else:
            raise ValueError()

        word_embedding_c = None
        if self.use_memory:
            word_embedding_c = self.memory.memory_embeddings(input_ids)

        tmp_sequence_output = sequence_output.view(batch_size, dialog_length,
                                                   utterance_length, -1)
        # word_embedding_a = word_embedding_a.view(batch_size, dialog_length, -1)

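        # Take the [CLS] position of each utterance as its representation and zero out padded utterances via label_mask.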
        sequence_output = tmp_sequence_output[:, :, 0]
        tmp_label_mask = torch.stack([label_mask] * sequence_output.shape[-1],
                                     2)
        sequence_output = torch.mul(sequence_output, tmp_label_mask)

        if self.use_memory:
            word_embedding_c = word_embedding_c.view(batch_size, dialog_length,
                                                     -1)
            memory_output = self.memory(word_embedding_c, sequence_output,
                                        party_mask)
            sequence_output = torch.cat((sequence_output, memory_output), 2)
        sequence_output = self.dropout(sequence_output)
        # Append speaker (party) embeddings to each utterance representation.
        if self.use_party:
            party_embeddings = self.party_embedding(party_ids)
            sequence_output = torch.cat((sequence_output, party_embeddings),
                                        dim=2)
        # Optionally encode the utterance sequence at the dialogue level with an LSTM / biLSTM.
        if self.utterance_encoder is not None:
            self.utterance_encoder.flatten_parameters()
            utterance_output, _ = self.utterance_encoder(sequence_output)
        else:
            utterance_output = sequence_output

        if self.use_department:
            department_embeddings = self.department_embedding(department_ids)
            utterance_output = torch.cat(
                (utterance_output, department_embeddings), dim=2)
        if self.use_disease:
            disease_embeddings = self.disease_embedding(disease_ids)
            utterance_output = torch.cat(
                (utterance_output, disease_embeddings), dim=2)

        tmp_label_mask = torch.stack([label_mask] * utterance_output.shape[-1],
                                     2)
        utterance_output = torch.mul(utterance_output, tmp_label_mask)
        logits = self.classifier(utterance_output)

        if labels is not None:
            if self.decoder == 'softmax':
                total_loss = self.loss_fct(logits.view(-1, self.num_labels),
                                           labels.view(-1))
            elif self.decoder == 'crf':
                total_loss = -1 * self.crf(
                    emissions=logits, tags=labels, mask=attention_mask_label)
            else:
                raise ValueError()
            return total_loss
        else:
            if self.decoder == 'softmax':
                scores = torch.argmax(nn.functional.log_softmax(logits, dim=2),
                                      dim=2)
            elif self.decoder == 'crf':
                scores = self.crf.decode(logits, attention_mask_label)[0]
            else:
                raise ValueError()
            return scores

    @staticmethod
    def init_hyper_parameters(args):
        hyper_parameters = DEFAULT_HPARA.copy()
        hyper_parameters['max_seq_length'] = args.max_seq_length
        hyper_parameters['max_ngram_size'] = args.max_ngram_size
        hyper_parameters['use_bert'] = args.use_bert
        hyper_parameters['use_zen'] = args.use_zen
        hyper_parameters['do_lower_case'] = args.do_lower_case
        hyper_parameters['use_memory'] = args.use_memory
        hyper_parameters['use_party'] = args.use_party
        hyper_parameters['use_department'] = args.use_department
        hyper_parameters['use_disease'] = args.use_disease
        hyper_parameters['utterance_encoder'] = args.utterance_encoder
        hyper_parameters['decoder'] = args.decoder
        hyper_parameters['lstm_hidden_size'] = args.lstm_hidden_size
        hyper_parameters['max_dialog_length'] = args.max_dialog_length
        return hyper_parameters

    @classmethod
    def load_model(cls, model_path):
        label2id = load_json(os.path.join(model_path, 'label2id.json'))
        hpara = load_json(os.path.join(model_path, 'hpara.json'))

        department2id_path = os.path.join(model_path, 'department2id.json')
        department2id = load_json(department2id_path) if os.path.exists(
            department2id_path) else None

        word2id_path = os.path.join(model_path, 'word2id.json')
        word2id = load_json(word2id_path) if os.path.exists(
            word2id_path) else None

        disease2id_path = os.path.join(model_path, 'disease2id.json')
        disease2id = load_json(disease2id_path) if os.path.exists(
            disease2id_path) else None

        res = cls(model_path=model_path,
                  label2id=label2id,
                  hpara=hpara,
                  department2id=department2id,
                  word2id=word2id,
                  disease2id=disease2id)

        res.load_state_dict(
            torch.load(os.path.join(model_path, 'pytorch_model.bin')))
        return res

    def save_model(self, output_dir, vocab_dir):
        output_model_path = os.path.join(output_dir, 'pytorch_model.bin')
        torch.save(self.state_dict(), output_model_path)

        label_map_file = os.path.join(output_dir, 'label2id.json')

        if not os.path.exists(label_map_file):
            save_json(label_map_file, self.label2id)

            save_json(os.path.join(output_dir, 'hpara.json'), self.hpara)
            if self.department2id is not None:
                save_json(os.path.join(output_dir, 'department2id.json'),
                          self.department2id)
            if self.word2id is not None:
                save_json(os.path.join(output_dir, 'word2id.json'),
                          self.word2id)
            if self.disease2id is not None:
                save_json(os.path.join(output_dir, 'disease2id.json'),
                          self.disease2id)

            output_config_file = os.path.join(output_dir, 'config.json')
            with open(output_config_file, "w", encoding='utf-8') as writer:
                if self.bert:
                    writer.write(self.bert.config.to_json_string())
                elif self.zen:
                    writer.write(self.zen.config.to_json_string())
                else:
                    raise ValueError()
            output_bert_config_file = os.path.join(output_dir,
                                                   'bert_config.json')
            command = 'cp ' + str(output_config_file) + ' ' + str(
                output_bert_config_file)
            subprocess.run(command, shell=True)

            if self.bert or self.zen:
                vocab_name = 'vocab.txt'
            else:
                raise ValueError()
            vocab_path = os.path.join(vocab_dir, vocab_name)
            command = 'cp ' + str(vocab_path) + ' ' + str(
                os.path.join(output_dir, vocab_name))
            subprocess.run(command, shell=True)

            if self.zen:
                ngram_name = 'ngram.txt'
                ngram_path = os.path.join(vocab_dir, ngram_name)
                command = 'cp ' + str(ngram_path) + ' ' + str(
                    os.path.join(output_dir, ngram_name))
                subprocess.run(command, shell=True)

    @staticmethod
    def data2example(data, flag=''):
        examples = []
        for i, (utterance, label, party, summary, max_utterance_len,
                party_mask, department, disease) in enumerate(data):
            guid = "%s-%s" % (flag, i)
            text_a = utterance
            text_b = None
            examples.append(
                InputExample(guid=guid,
                             text_a=text_a,
                             text_b=text_b,
                             label=label,
                             party=party,
                             summary=summary,
                             max_utterance_len=max_utterance_len,
                             party_mask=party_mask,
                             department=department,
                             disease=disease))
        return examples

    def convert_examples_to_features(self, examples):

        features = []

        tokenizer = self.zen_tokenizer if self.zen_tokenizer is not None else self.bert_tokenizer

        # -------- max utterance / dialog length (~10% headroom for subword expansion, +2 for [CLS]/[SEP]) --------
        max_utterance_length = min(
            int(
                max([example.max_utterance_len
                     for example in examples]) * 1.1 + 2), self.max_seq_length)
        max_seq_length = max_utterance_length
        max_dialog_length = min(
            max(max([len(example.text_a) for example in examples]), 1),
            self.max_dialog_length)
        # -------- max utterance / dialog length --------

        for (ex_index, example) in enumerate(examples):
            valid = [[] for _ in range(max_dialog_length)]
            tokens = [[] for _ in range(max_dialog_length)]
            segment_ids = [[] for _ in range(max_dialog_length)]
            input_ids = [[] for _ in range(max_dialog_length)]
            input_mask = [[] for _ in range(max_dialog_length)]
            input_id_len = [1 for _ in range(max_dialog_length)]
            party_mask = [[] for _ in range(max_dialog_length)]

            for i in range(max_dialog_length):
                if i < len(example.text_a):
                    utterance = example.text_a[i]
                    party = example.party[i]
                    if party == 'P':
                        party_mask[i] = example.party_mask['P']
                    elif party == 'D':
                        party_mask[i] = example.party_mask['D']
                    else:
                        raise ValueError()
                    if len(party_mask[i]) > max_dialog_length:
                        party_mask[i] = party_mask[i][:max_dialog_length]
                    while len(party_mask[i]) < max_dialog_length:
                        party_mask[i].append(0)
                    for word in utterance:
                        token = tokenizer.tokenize(word)
                        tokens[i].extend(token)
                        for m in range(len(token)):
                            if m == 0:
                                valid[i].append(1)
                            else:
                                valid[i].append(0)
                    if len(tokens[i]) >= max_utterance_length - 1:
                        tokens[i] = tokens[i][0:(max_utterance_length - 2)]
                        valid[i] = valid[i][0:(max_utterance_length - 2)]

                    ntokens = []

                    ntokens.append("[CLS]")
                    segment_ids[i].append(0)

                    valid[i].insert(0, 1)

                    for token in tokens[i]:
                        ntokens.append(token)
                        segment_ids[i].append(0)
                    ntokens.append("[SEP]")

                    segment_ids[i].append(0)
                    valid[i].append(1)

                    # ntokens: ['[CLS]', '我' ... , '人', '[SEP]'] length: 5 + 2
                    # valid: [1, ..., 1] length 5 + 2 (a 1 added at the front and the back)
                    # label_mask: [1, ..., 1] length 5 + 2 (a 1 added at the front and the back)
                    # label_ids: [6, 5, 5, 2, 3, 4, 7] (labels for [CLS] and [SEP] added at the front and the back) length 5 + 2
                    # segment_id: [0, 0, ..., 0] length 7

                    input_ids[i] = tokenizer.convert_tokens_to_ids(ntokens)
                    # input_ids: [1, 2, 3, .. , 7] length 7
                    for _ in range(len(input_ids[i])):
                        input_mask[i].append(1)

                input_id_len[i] = len(input_ids[i])
                while len(input_ids[i]) < max_utterance_length:
                    input_ids[i].append(0)
                    input_mask[i].append(0)
                    segment_ids[i].append(0)
                    valid[i].append(1)
                while len(party_mask[i]) < max_dialog_length:
                    party_mask[i].append(0)
                assert len(input_ids[i]) == len(input_mask[i])
                assert len(input_ids[i]) == len(segment_ids[i])
                assert len(input_ids[i]) == len(valid[i])

            assert len(input_ids) == max_dialog_length
            assert len(input_ids[-1]) == max_utterance_length

            labellist = example.label
            label_mask = []
            label_ids = []

            for label in labellist:
                label_id = self.label2id[
                    label] if label in self.label2id else self.label2id['<UNK>']
                label_ids.append(label_id)
                label_mask.append(1)
            if len(label_ids) > max_dialog_length:
                label_ids = label_ids[:max_dialog_length]
                label_mask = label_mask[:max_dialog_length]
            while len(label_ids) < max_dialog_length:
                label_ids.append(0)
                label_mask.append(0)

            partylist = example.party

            if self.party2id is not None:
                party_ids = []
                for party in partylist:
                    party_ids.append(self.party2id[party])
                if len(party_ids) > max_dialog_length:
                    party_ids = party_ids[:max_dialog_length]
                while len(party_ids) < max_dialog_length:
                    party_ids.append(0)
            else:
                party_ids = None

            if self.department2id is not None:
                department_ids = []
                if example.department in self.department2id:
                    department_id = self.department2id[example.department]
                else:
                    department_id = self.department2id['<UNK>']

                for _ in partylist:
                    department_ids.append(department_id)

                if len(department_ids) > max_dialog_length:
                    department_ids = department_ids[:max_dialog_length]

                while len(department_ids) < max_dialog_length:
                    department_ids.append(0)
            else:
                department_ids = None

            if self.disease2id is not None:
                disease_ids = []
                if example.disease in self.disease2id:
                    disease_id = self.disease2id[example.disease]
                else:
                    disease_id = self.disease2id['<UNK>']

                for _ in partylist:
                    disease_ids.append(disease_id)

                if len(disease_ids) > max_dialog_length:
                    disease_ids = disease_ids[:max_dialog_length]
                while len(disease_ids) < max_dialog_length:
                    disease_ids.append(0)
            else:
                disease_ids = None

            assert len(label_ids) == len(label_mask)
            assert len(label_ids) == max_dialog_length
            assert len(label_ids) == len(party_mask)
            assert len(label_ids) == len(party_mask[-1])

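            # For ZEN, collect the n-grams (length 2-7) of each utterance that appear in the ZEN n-gram
            # vocabulary and build a position matrix mapping every matched n-gram to the token positions it covers.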
            if self.zen_ngram_dict is not None:
                all_ngram_ids = []
                all_ngram_positions_matrix = []
                # all_ngram_lengths = []
                # all_ngram_tuples = []
                # all_ngram_seg_ids = []
                # all_ngram_mask_array = []

                for token_list in tokens:
                    ngram_matches = []
                    # Scan n-gram segments of length 2 to 7 and keep those found in the ZEN n-gram vocabulary
                    for p in range(2, 8):
                        for q in range(0, len(token_list) - p + 1):
                            character_segment = token_list[q:q + p]
                            # q is the starting position of the ngram
                            # p is the length of the current ngram
                            character_segment = tuple(character_segment)
                            if character_segment in self.zen_ngram_dict.ngram_to_id_dict:
                                ngram_index = self.zen_ngram_dict.ngram_to_id_dict[
                                    character_segment]
                                ngram_matches.append(
                                    [ngram_index, q, p, character_segment])

                    # random.shuffle(ngram_matches)
                    ngram_matches = sorted(ngram_matches, key=lambda s: s[0])

                    max_ngram_in_seq_proportion = math.ceil(
                        (len(token_list) / max_seq_length) *
                        self.zen_ngram_dict.max_ngram_in_seq)
                    if len(ngram_matches) > max_ngram_in_seq_proportion:
                        ngram_matches = ngram_matches[:
                                                      max_ngram_in_seq_proportion]

                    ngram_ids = [ngram[0] for ngram in ngram_matches]
                    ngram_positions = [ngram[1] for ngram in ngram_matches]
                    ngram_lengths = [ngram[2] for ngram in ngram_matches]
                    # ngram_tuples = [ngram[3] for ngram in ngram_matches]
                    # ngram_seg_ids = [0 if position < (len(tokens) + 2) else 1 for position in ngram_positions]

                    ngram_mask_array = np.zeros(
                        self.zen_ngram_dict.max_ngram_in_seq,
                        dtype=bool)  # np.bool was removed in newer NumPy; the builtin bool behaves the same
                    ngram_mask_array[:len(ngram_ids)] = 1

                    # record the masked positions
                    ngram_positions_matrix = np.zeros(
                        shape=(max_seq_length,
                               self.zen_ngram_dict.max_ngram_in_seq),
                        dtype=np.int32)
                    for i in range(len(ngram_ids)):
                        ngram_positions_matrix[
                            ngram_positions[i]:ngram_positions[i] +
                            ngram_lengths[i], i] = 1.0

                    # Zero-pad up to the max ngram in seq length.
                    padding = [0] * (self.zen_ngram_dict.max_ngram_in_seq -
                                     len(ngram_ids))
                    ngram_ids += padding
                    # ngram_lengths += padding
                    # ngram_seg_ids += padding

                    all_ngram_ids.append(ngram_ids)
                    all_ngram_positions_matrix.append(ngram_positions_matrix)
                    # all_ngram_lengths.append(ngram_lengths)
                    # all_ngram_tuples.append(ngram_tuples)
                    # all_ngram_seg_ids.append(ngram_seg_ids)
                    # all_ngram_mask_array.append(ngram_mask_array)
                while len(all_ngram_ids) < max_dialog_length:
                    all_ngram_ids.append([0] *
                                         self.zen_ngram_dict.max_ngram_in_seq)
                    all_ngram_positions_matrix.append(
                        np.zeros(shape=(max_seq_length,
                                        self.zen_ngram_dict.max_ngram_in_seq),
                                 dtype=np.int32))
            else:
                all_ngram_ids = None
                all_ngram_positions_matrix = None
                # all_ngram_lengths = None
                # all_ngram_tuples = None
                # all_ngram_seg_ids = None
                # all_ngram_mask_array = None

            features.append(
                InputFeatures(
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                    label_id=label_ids,
                    valid_ids=valid,
                    label_mask=label_mask,
                    input_id_len=input_id_len,
                    party_mask=party_mask,
                    party=party_ids,
                    department=department_ids,
                    disease=disease_ids,
                    ngram_ids=all_ngram_ids,
                    ngram_positions=all_ngram_positions_matrix,
                    # ngram_lengths=all_ngram_lengths,
                    # ngram_tuples=all_ngram_tuples,
                    # ngram_seg_ids=all_ngram_seg_ids,
                    # ngram_masks=all_ngram_mask_array
                ))
        return features

    def feature2input(self, device, feature):
        all_input_ids = torch.tensor([f.input_ids for f in feature],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in feature],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in feature],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in feature],
                                     dtype=torch.long)
        all_valid_ids = torch.tensor([f.valid_ids for f in feature],
                                     dtype=torch.long)
        all_lmask_ids = torch.tensor([f.label_mask for f in feature],
                                     dtype=torch.long)
        input_ids = all_input_ids.to(device)
        input_mask = all_input_mask.to(device)
        segment_ids = all_segment_ids.to(device)
        label_ids = all_label_ids.to(device)
        valid_ids = all_valid_ids.to(device)
        l_mask = all_lmask_ids.to(device)

        all_lmask = torch.tensor([f.label_mask for f in feature],
                                 dtype=torch.float)
        lmask = all_lmask.to(device)

        if self.memory is not None:
            all_party_mask = torch.tensor([f.party_mask for f in feature],
                                          dtype=torch.float)

            party_mask = all_party_mask.to(device)
        else:
            party_mask = None

        if self.use_party:
            all_party_ids = torch.tensor([f.party for f in feature],
                                         dtype=torch.long)
            party_ids = all_party_ids.to(device)
        else:
            party_ids = None

        if self.use_department:
            all_department_ids = torch.tensor([f.department for f in feature],
                                              dtype=torch.long)
            department_ids = all_department_ids.to(device)
        else:
            department_ids = None

        if self.use_disease:
            all_disease_ids = torch.tensor([f.disease for f in feature],
                                           dtype=torch.long)
            disease_ids = all_disease_ids.to(device)
        else:
            disease_ids = None

        if self.zen is not None:
            all_ngram_ids = torch.tensor([f.ngram_ids for f in feature],
                                         dtype=torch.long)
            all_ngram_positions = torch.tensor(
                [f.ngram_positions for f in feature], dtype=torch.long)
            # all_ngram_lengths = torch.tensor([f.ngram_lengths for f in train_features], dtype=torch.long)
            # all_ngram_seg_ids = torch.tensor([f.ngram_seg_ids for f in train_features], dtype=torch.long)
            # all_ngram_masks = torch.tensor([f.ngram_masks for f in train_features], dtype=torch.long)

            ngram_ids = all_ngram_ids.to(device)
            ngram_positions = all_ngram_positions.to(device)
        else:
            ngram_ids = None
            ngram_positions = None
        return input_ids, input_mask, l_mask, label_ids, ngram_ids, ngram_positions, segment_ids, valid_ids, \
               lmask, party_mask, party_ids, department_ids, disease_ids
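
A minimal, hypothetical usage sketch of the HET class above (not part of the original listing): it fills the hyper-parameters from an argparse-style namespace, instantiates the model from a made-up BERT checkpoint path, converts one toy dialogue into features, and decodes labels with a forward pass. The label map, checkpoint path, and toy data are assumptions; everything else (HET, InputExample, InputFeatures, DEFAULT_HPARA and the BERT wrappers) is assumed to come from the module shown above.

from argparse import Namespace

import torch

# Hyper-parameters mirror the attributes read in HET.init_hyper_parameters().
args = Namespace(
    max_seq_length=128, max_ngram_size=128,
    use_bert=True, use_zen=False, do_lower_case=True,
    use_memory=False, use_party=True, use_department=False, use_disease=False,
    utterance_encoder='biLSTM', decoder='softmax',
    lstm_hidden_size=150, max_dialog_length=20)
hpara = HET.init_hyper_parameters(args)

label2id = {'<PAD>': 0, '<UNK>': 1, 'O': 2, 'SUM': 3}      # hypothetical label map
model = HET(word2id=None, label2id=label2id, hpara=hpara,
            model_path='/path/to/chinese-bert-base')        # hypothetical checkpoint path

# One toy dialogue in the tuple layout unpacked by HET.data2example():
# (utterances, labels, parties, summary, max_utterance_len, party_mask, department, disease)
utterances = [['头', '疼'], ['吃', '点', '药']]
party_mask = {'P': [1, 0], 'D': [0, 1]}
data = [(utterances, ['O', 'SUM'], ['P', 'D'], None, 3, party_mask, None, None)]

examples = HET.data2example(data, flag='demo')
features = model.convert_examples_to_features(examples)
(input_ids, input_mask, l_mask, label_ids, ngram_ids, ngram_positions,
 segment_ids, valid_ids, lmask, party_mask_t, party_ids,
 department_ids, disease_ids) = model.feature2input(torch.device('cpu'), features)

model.eval()
with torch.no_grad():
    # labels=None, so forward() returns the predicted label id for each utterance.
    scores = model(input_ids, segment_ids, input_mask,
                   labels=None, valid_ids=valid_ids,
                   attention_mask_label=l_mask, label_mask=lmask,
                   party_mask=party_mask_t, party_ids=party_ids,
                   department_ids=department_ids, disease_ids=disease_ids,
                   input_ngram_ids=ngram_ids, ngram_position_matrix=ngram_positions)
print(scores)  # tensor of shape (1, dialog_length) with predicted label ids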