def update_arg_dict(self, arg_dict): super().update_arg_dict(arg_dict) if self.arg_dict['concatenate_input_for_gcn_hidden']: self.arg_dict['fully_scales'][0] += self.arg_dict['gcn_hidden_dim'] if self.arg_dict['repeat_train']: time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) if self.arg_dict['group_layer_limit_flag']: gl = self.arg_dict['group_layer_limit_list'] else: gl = self.arg_dict['gcn_layer'] model_dir = file_tool.connect_path( self.result_path, 'train', 'bs:{}-lr:{}-gl:{}--com_fun:{}'.format( self.arg_dict['batch_size'], self.arg_dict['learn_rate'], gl, self.arg_dict['semantic_compare_func']), time_str) else: model_dir = file_tool.connect_path(self.result_path, 'test') file_tool.makedir(model_dir) if not file_tool.check_dir(model_dir): raise RuntimeError self.arg_dict['model_path'] = model_dir
def create_framework(self): self.framework = self.get_framework() gpu_id = self.arg_dict['ues_gpu'] if gpu_id == -1: self.framework.cpu() else: self.framework.cuda(self.device) if self.arg_dict['optimizer'] == 'sgd': self.optimizer = torch.optim.SGD( self.framework.parameters(), lr=self.arg_dict['learn_rate'], momentum=self.arg_dict['sgd_momentum']) elif self.arg_dict['optimizer'] == 'adam': self.optimizer = torch.optim.Adam(self.framework.parameters(), lr=self.arg_dict['learn_rate']) else: raise ValueError # self.loser = fr.BiCELo(self.arg_dict, *self.framework.get_regular_parts()) checkpoint_path = file_tool.connect_path( self.framework.arg_dict['model_path'], 'checkpoint') file_tool.makedir(checkpoint_path) self.entire_model_state_dict_file = file_tool.connect_path( checkpoint_path, 'entire_model.pt') self.optimizer_state_dict_file = file_tool.connect_path( checkpoint_path, 'optimizer.pt') if not self.arg_dict['repeat_train']: if gpu_id == -1: self.framework.load_state_dict( torch.load( file_tool.PathManager.change_filename_by_append( self.entire_model_state_dict_file, 'cpu'))) else: self.framework.load_state_dict( torch.load(self.entire_model_state_dict_file)) self.optimizer.load_state_dict( torch.load(self.optimizer_state_dict_file)) self.logger = log_tool.get_logger( self.framework_logger_name, file_tool.connect_path(self.framework.arg_dict['model_path'], 'log.txt')) self.logger.info('{} was created!'.format(self.framework.name)) self.__print_framework_parameter__() self.__print_framework_arg_dict__()
def train_model(self): train_loader_tuple_list = self.data_loader_dict[ 'train_loader_tuple_list'] avg_result = np.array([0, 0], dtype=np.float) record_list = [] for tuple_index, train_loader_tuple in enumerate( train_loader_tuple_list, 1): # if tuple_index>2: # break #repeat create framework, when each fold train self.create_framework() self.logger.info('{} was created!'.format(self.framework.name)) train_loader, valid_loader = train_loader_tuple self.logger.info('train_loader:{} valid_loader:{}'.format( len(train_loader), len(valid_loader))) self.logger.info('begin train {}-th fold'.format(tuple_index)) result = self.__train_fold__(train_loader=train_loader, valid_loader=valid_loader) self.trial_step = self.arg_dict['epoch'] * tuple_index avg_result += np.array(result[0:2], dtype=np.float) record_list.append(result[3]) record_file = file_tool.connect_path( self.framework.arg_dict['model_path'], 'record_list.pkl') file_tool.save_data_pickle(record_list, record_file) avg_result = (avg_result / len(train_loader_tuple_list)).tolist() avg_result.append('finish') self.logger.info('avg_acc:{}'.format(avg_result[0])) return avg_result
def create_sentences(self): original_sentence_dict = self.__extra_sentences_from_org_file__( file_tool.connect_path(self.data_path, 'data.txt'), file_tool.connect_path(self.data_path, 'original_sentence_dict.pkl')) self.sentence_list = [] self.sentence_dict = {} for sent_id, o_sent in original_sentence_dict.items(): sent_id = int(sent_id) sent_obj = base_corpus.Sentence(id_=sent_id, original_sentence=o_sent) self.sentence_list.append(sent_obj) if str(sent_id) in self.sentence_dict: raise ValueError("sentence in corpus is repeated") self.sentence_dict[str(sent_id)] = sent_obj pass
def show_sent_len_distribute(): sent_len_table = file_tool.load_data_pickle( file_tool.connect_path(Qqp.data_path, 'sent_len_table.pkl')) plt.bar(range(1, len(sent_len_table) + 1), sent_len_table) plt.title("") plt.xlabel('sentence length') plt.ylabel('count') plt.xlim(0, 80) plt.show()
def update_arg_dict(self, arg_dict): super().update_arg_dict(arg_dict) if self.arg_dict['repeat_train']: time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) model_dir = file_tool.connect_path( self.result_path, 'train', 'bs:{}-lr:{}-com_fun:{}'.format( self.arg_dict['batch_size'], self.arg_dict['learn_rate'], self.arg_dict['semantic_compare_func']), time_str) else: model_dir = file_tool.connect_path(self.result_path, 'test') file_tool.makedir(model_dir) if not file_tool.check_dir(model_dir): raise RuntimeError self.arg_dict['model_path'] = model_dir
def test_model(self): def get_save_data(error_example_ids): save_data = [] for e_id in error_example_ids: e_id = str(e_id) example = example_dict[e_id] sentence1 = example.sentence1 sentence2 = example.sentence2 save_data.append(str([sentence1.id, sentence2.id])) save_data.append(sentence1.original) save_data.append(sentence2.original) return save_data if not self.arg_dict['repeat_train']: self.create_framework() test_loader = self.data_manager.test_loader( self.arg_dict['batch_size']) self.logger.info('test_loader length:{}'.format(len(test_loader))) with torch.no_grad(): evaluation_result = self.evaluation_calculation(test_loader) self.logger.info(evaluation_result['metric']) example_dict = test_loader.example_dict fn_error_example_ids = evaluation_result['error_example_ids_dict'][ 'FN'] fp_error_example_ids = evaluation_result['error_example_ids_dict'][ 'FP'] fn_sava_data = get_save_data(fn_error_example_ids) fp_sava_data = get_save_data(fp_error_example_ids) error_file_path = file_tool.connect_path( self.framework.arg_dict['model_path'], 'error_file') file_tool.makedir(error_file_path) file_tool.save_list_data( fn_sava_data, file_tool.connect_path(error_file_path, 'fn_error_sentence_pairs.txt'), 'w') file_tool.save_list_data( fp_sava_data, file_tool.connect_path(error_file_path, 'fp_error_sentence_pairs.txt'), 'w') return evaluation_result['metric'] pass
def parse_sentences(self): parsed_sentence_org_file = file_tool.connect_path( self.data_path, 'parsed_sentences.txt') parsed_sentence_dict_file = file_tool.connect_path( self.data_path, 'parsed_sentence_dict.pkl') if file_tool.check_file(parsed_sentence_dict_file): parsed_sentence_dict = file_tool.load_data_pickle( parsed_sentence_dict_file) else: parsed_sentence_dict = parser_tool.extra_parsed_sentence_dict_from_org_file( parsed_sentence_org_file) file_tool.save_data_pickle(parsed_sentence_dict, parsed_sentence_dict_file) if len(parsed_sentence_dict) != len(self.sentence_dict): raise ValueError("parsed_sentence_dict not march sentence_dict") if not general_tool.compare_two_dict_keys(self.sentence_dict.copy(), parsed_sentence_dict.copy()): raise ValueError("parsed_sentence_dict not march sentence_dict") for sent_id, info in parsed_sentence_dict.items(): if info['original'] != self.sentence_dict[sent_id].original: raise ValueError( "parsed_sentence_dict not march sentence_dict") for sent_id, parse_info in parsed_sentence_dict.items(): sent_id = str(sent_id) self.sentence_dict[sent_id].parse_info = parse_info self.parse_info = parser_tool.process_parsing_sentence_dict( parsed_sentence_dict, modify_dep_name=True) numeral_sentence_dict = self.parse_info.numeral_sentence_dict self.max_sent_len = self.parse_info.max_sent_len if not general_tool.compare_two_dict_keys( self.sentence_dict.copy(), numeral_sentence_dict.copy()): raise ValueError("numeral_sentence_dict not march sentence_dict") for sent_id in self.sentence_dict.keys(): self.sentence_dict[sent_id].syntax_info = numeral_sentence_dict[ sent_id] pass
def get_qqp_obj(force=False): global single_qqp_obj if force or (single_qqp_obj is None): single_qqp_obj_file = file_tool.connect_path("corpus/qqp", 'qqp_obj.pkl') if file_tool.check_file(single_qqp_obj_file): single_qqp_obj = file_tool.load_data_pickle(single_qqp_obj_file) else: single_qqp_obj = Qqp() file_tool.save_data_pickle(single_qqp_obj, single_qqp_obj_file) return single_qqp_obj
def create_examples(self): def create_examples_by_dicts(examples): example_obj_list = [] example_obj_dict = {} for e in examples: sentence1_id = str(e['sent_id1']) sentence2_id = str(e['sent_id2']) sentence1 = self.sentence_dict[sentence1_id] sentence2 = self.sentence_dict[sentence2_id] id_ = str(e['id']) label = int(e['label']) example_obj = base_corpus.Example(id_, sentence1=sentence1, sentence2=sentence2, label=label) example_obj_list.append(example_obj) if id_ in example_obj_list: raise ValueError("example in corpus is repeated") example_obj_dict[id_] = example_obj return example_obj_list, example_obj_dict train_dicts = self.__extra_examples_from_org_file__( file_tool.connect_path(self.data_path, 'train.txt'), file_tool.connect_path(self.data_path, 'train_dicts.pkl')) test_dicts = self.__extra_examples_from_org_file__( file_tool.connect_path(self.data_path, 'test.txt'), file_tool.connect_path(self.data_path, 'test_dicts.pkl')) self.train_example_list, self.train_example_dict = create_examples_by_dicts( train_dicts) self.test_example_list, self.test_example_dict = create_examples_by_dicts( test_dicts) pass
def visualize_model(self): self.create_framework() train_loader = self.data_manager.train_loader( self.arg_dict['batch_size']) batch = iter(train_loader).next() example_ids = batch['example_id'] input_data = self.framework.get_input_of_visualize_model( example_ids, train_loader.example_dict) visualization_path = file_tool.connect_path(self.framework.result_path, 'visualization') file_tool.makedir(visualization_path) filename = visualization_tool.create_filename(visualization_path) visualization_tool.log_graph(filename=filename, nn_model=self.framework, input_data=input_data)
def show_pared_info(self): print('the count of dep type:{}'.format( self.parse_info.dependency_count)) print( 'the max len of sentence_tokens:{}, correspond sent id:{}'.format( self.parse_info.max_sent_len, self.parse_info.max_sent_id)) print('the average len of sentence_tokens:{}'.format( self.parse_info.avg_sent_len)) sent_len_table = self.parse_info.sent_len_table file_tool.save_data_pickle( sent_len_table, file_tool.connect_path(self.data_path, "sent_len_table.pkl")) plt.bar(range(1, len(sent_len_table) + 1), sent_len_table) plt.title("sentence tokens length distribution") plt.show()
def create_examples(self): def create_examples_by_dicts(examples): example_obj_list = [] example_obj_dict = {} repeat_qes_exam_list = [] for e in examples: id_ = str(e['id']) label = int(e['label']) qes1_id = str(e['qes_id1']) qes2_id = str(e['qes_id2']) if (qes1_id not in self.sentence_dict) or ( qes2_id not in self.sentence_dict): print('example id {} q1 id {} q2 id {} is invalid'.format( id_, qes1_id, qes2_id)) continue sent_obj1 = self.sentence_dict[qes1_id] sent_obj2 = self.sentence_dict[qes2_id] if (e['qes1'] != sent_obj1.original_sentence()) or ( e['qes2'] != sent_obj2.original_sentence()): raise ValueError("sentence load error") example_obj = base_corpus.Example(id_, sentence1=sent_obj1, sentence2=sent_obj2, label=label) if id_ in example_obj_dict: raise ValueError("example in corpus is repeated") if example_obj.sentence1 == example_obj.sentence2: raise ValueError("example in corpus is repeated") if example_obj.sentence1.id == example_obj.sentence2.id: raise ValueError("example in corpus is repeated") if example_obj.sentence1.original == example_obj.sentence2.original: # raise ValueError("example in corpus is repeated") repeat_qes_exam_list.append(example_obj) example_obj_list.append(example_obj) example_obj_dict[id_] = example_obj if len(example_obj_list) != len(example_obj_dict): raise ValueError("example in corpus is repeated") print("repeat question example count:{}".format( len(repeat_qes_exam_list))) return example_obj_list, example_obj_dict example_dicts = self.__extra_examples_from_org_file__( file_tool.connect_path(self.data_path, 'data.tsv'), file_tool.connect_path(self.data_path, 'example_dicts.pkl')) train_dicts, test_dicts = self.__divide_example_dicts_to_train_and_test__( example_dicts) self.train_example_list, self.train_example_dict = create_examples_by_dicts( train_dicts) self.test_example_list, self.test_example_dict = create_examples_by_dicts( test_dicts) for e_id in self.test_example_dict.keys(): if e_id in self.train_example_dict: raise ValueError( "example {} in both test and train".format(e_id)) pass
class LE(fr.LSeE): name = "LE" result_path = file_tool.connect_path('result', name) def __init__(self, arg_dict): super().__init__(arg_dict) self.name = LE.name self.result_path = LE.result_path @classmethod def framework_name(cls): return cls.name def create_arg_dict(self): arg_dict = { 'semantic_compare_func': 'l2', 'fully_scales': [768, 2], # 'fully_regular': 1e-4, # 'bert_regular': 1e-4, 'bert_hidden_dim': 768, 'pad_on_right': True, 'sentence_max_len_for_bert': 128, 'dtype': torch.float32, } return arg_dict def create_models(self): self.bert = BertBase() self.semantic_layer = SemanticLayer(self.arg_dict) self.fully_connection = FullyConnection(self.arg_dict) def forward(self, *input_data, **kwargs): if len(kwargs) > 0: # common run or visualization data_batch = kwargs input_ids_batch = data_batch['input_ids_batch'] token_type_ids_batch = data_batch['token_type_ids_batch'] attention_mask_batch = data_batch['attention_mask_batch'] sep_index_batch = data_batch['sep_index_batch'] word_piece_flags_batch = data_batch['word_piece_flags_batch'] sent1_len_batch = data_batch['sent1_len_batch'] sent2_len_batch = data_batch['sent2_len_batch'] labels = data_batch['labels'] sent1_org_len_batch = data_batch['sent1_org_len_batch'] sent2_org_len_batch = data_batch['sent2_org_len_batch'] else: input_ids_batch, token_type_ids_batch, attention_mask_batch, sep_index_batch, sent1_len_batch, \ sent2_len_batch, labels = input_data last_hidden_states_batch, pooled_output = self.bert( input_ids_batch, token_type_ids_batch, attention_mask_batch) sent1_states_batch = [] sent2_states_batch = [] for i, hidden_states in enumerate(last_hidden_states_batch): sent1_word_piece_flags = word_piece_flags_batch[i][ 1:sep_index_batch[i]] sent1_states = hidden_states[1:sep_index_batch[i]] sent2_word_piece_flags = word_piece_flags_batch[ i][sep_index_batch[i] + 1:sep_index_batch[i] + 1 + sent2_len_batch[i]] sent2_states = hidden_states[sep_index_batch[i] + 1:sep_index_batch[i] + 1 + sent2_len_batch[i]] if len(sent1_states) != sent1_len_batch[i] or len( sent2_states) != sent2_len_batch[i]: raise ValueError if len(sent1_states) + len( sent2_states) + 3 != attention_mask_batch[i].sum(): raise ValueError if len(word_piece_flags_batch[i]) != attention_mask_batch[i].sum(): raise ValueError sent1_states = self.merge_reps_of_word_pieces( sent1_word_piece_flags, sent1_states) if len(sent1_states) != sent1_org_len_batch[i]: raise ValueError sent1_states = data_tool.padding_tensor( sent1_states, self.arg_dict['max_sentence_length'], align_dir='left', dim=0) sent2_states = self.merge_reps_of_word_pieces( sent2_word_piece_flags, sent2_states) if len(sent2_states) != sent2_org_len_batch[i]: raise ValueError sent2_states = data_tool.padding_tensor( sent2_states, self.arg_dict['max_sentence_length'], align_dir='left', dim=0) sent1_states_batch.append(sent1_states) sent2_states_batch.append(sent2_states) sent1_states_batch = torch.stack(sent1_states_batch, dim=0) sent2_states_batch = torch.stack(sent2_states_batch, dim=0) result = self.semantic_layer(sent1_states_batch, sent2_states_batch) result = self.fully_connection(result) loss = torch.nn.CrossEntropyLoss()(result.view(-1, 2), labels.view(-1)) predicts = np.array(result.detach().cpu().numpy()).argmax(axis=1) return loss, predicts
class LSeE(fr.Framework): name = "LSeE" result_path = file_tool.connect_path('result', name) def __init__(self, arg_dict): super().__init__(arg_dict) self.name = LSeE.name self.result_path = LSeE.result_path @classmethod def framework_name(cls): return cls.name def create_arg_dict(self): arg_dict = { 'semantic_compare_func': 'l2', 'fully_scales': [768 * 2, 2], # 'fully_regular': 1e-4, # 'bert_regular': 1e-4, 'bert_hidden_dim': 768, 'pad_on_right': True, 'sentence_max_len_for_bert': 128, 'dtype': torch.float32, } return arg_dict def update_arg_dict(self, arg_dict): super().update_arg_dict(arg_dict) if self.arg_dict['repeat_train']: time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) model_dir = file_tool.connect_path( self.result_path, 'train', 'bs:{}-lr:{}-com_fun:{}'.format( self.arg_dict['batch_size'], self.arg_dict['learn_rate'], self.arg_dict['semantic_compare_func']), time_str) else: model_dir = file_tool.connect_path(self.result_path, 'test') file_tool.makedir(model_dir) if not file_tool.check_dir(model_dir): raise RuntimeError self.arg_dict['model_path'] = model_dir def create_models(self): self.bert = BertBase() config = self.bert.config self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) self.semantic_layer = SemanticLayer(self.arg_dict) self.fully_connection = FullyConnection(self.arg_dict) def deal_with_example_batch(self, example_ids, example_dict): examples = [example_dict[str(e_id.item())] for e_id in example_ids] sentence_max_len = self.arg_dict['sentence_max_len_for_bert'] pad_on_right = self.arg_dict['pad_on_right'] pad_token = self.bert.tokenizer.convert_tokens_to_ids( [self.bert.tokenizer.pad_token])[0] sep_token = self.bert.tokenizer.convert_tokens_to_ids( [self.bert.tokenizer.sep_token])[0] mask_padding_with_zero = True pad_token_segment_id = 0 sentence1s = [e.sentence1 for e in examples] sentence2s = [e.sentence2 for e in examples] labels = torch.tensor([e.label for e in examples], dtype=torch.long, device=self.device) input_ids_batch = [] token_type_ids_batch = [] attention_mask_batch = [] sep_index_batch = [] sent1_len_batch = [] sent2_len_batch = [] sent1_org_len_batch = [] sent2_org_len_batch = [] word_piece_flags_batch = [] for s1, s2 in zip(sentence1s, sentence2s): inputs_ls_cased = self.bert.tokenizer.encode_plus( s1.sentence_with_root_head(), s2.sentence_with_root_head(), add_special_tokens=True, max_length=sentence_max_len, ) input_ids, token_type_ids = inputs_ls_cased[ "input_ids"], inputs_ls_cased["token_type_ids"] word_piece_flags_batch.append( general_tool.word_piece_flag_list( self.bert.tokenizer.convert_ids_to_tokens(input_ids), '##')) attention_mask = [1 if mask_padding_with_zero else 0 ] * len(input_ids) padding_length = sentence_max_len - len(input_ids) if not pad_on_right: input_ids = ([pad_token] * padding_length) + input_ids attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids else: input_ids = input_ids + ([pad_token] * padding_length) attention_mask = attention_mask + ( [0 if mask_padding_with_zero else 1] * padding_length) token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length) input_ids_batch.append(input_ids) token_type_ids_batch.append(token_type_ids) attention_mask_batch.append(attention_mask) sep_indexes = [] for sep_index, id_ in enumerate(input_ids.copy()): if id_ == sep_token: sep_indexes.append(sep_index) if len(sep_indexes) != 2: raise ValueError sep_index_batch.append(sep_indexes[0]) sent1_len_batch.append(sep_indexes[0] - 1) sent2_len_batch.append(sep_indexes[1] - sep_indexes[0] - 1) sent1_org_len_batch.append(s1.len_of_tokens()) sent2_org_len_batch.append(s2.len_of_tokens()) input_ids_batch = torch.tensor(input_ids_batch, device=self.device) token_type_ids_batch = torch.tensor(token_type_ids_batch, device=self.device) attention_mask_batch = torch.tensor(attention_mask_batch, device=self.device) result = { 'input_ids_batch': input_ids_batch, 'token_type_ids_batch': token_type_ids_batch, 'attention_mask_batch': attention_mask_batch, 'sep_index_batch': sep_index_batch, 'sent1_len_batch': sent1_len_batch, 'sent1_org_len_batch': sent1_org_len_batch, 'word_piece_flags_batch': word_piece_flags_batch, 'sent2_len_batch': sent2_len_batch, 'sent2_org_len_batch': sent2_org_len_batch, 'labels': labels } return result def merge_reps_of_word_pieces(self, word_piece_flags, token_reps): result_reps = [] word_piece_label = False word_piece_rep = 0 word_piece_count = 0 for i, flag in enumerate(word_piece_flags): if flag == 1: word_piece_rep += token_reps[i] word_piece_count += 1 word_piece_label = True else: if word_piece_label: if word_piece_count == 0: raise ValueError result_reps.append(word_piece_rep / word_piece_count) word_piece_rep = 0 word_piece_count = 0 result_reps.append(token_reps[i]) word_piece_label = False if word_piece_label and (word_piece_count > 0): result_reps.append(word_piece_rep / word_piece_count) result_reps = torch.stack(result_reps, dim=0) return result_reps def forward(self, *input_data, **kwargs): if len(kwargs) > 0: # common run or visualization data_batch = kwargs input_ids_batch = data_batch['input_ids_batch'] token_type_ids_batch = data_batch['token_type_ids_batch'] attention_mask_batch = data_batch['attention_mask_batch'] sep_index_batch = data_batch['sep_index_batch'] word_piece_flags_batch = data_batch['word_piece_flags_batch'] sent1_len_batch = data_batch['sent1_len_batch'] sent2_len_batch = data_batch['sent2_len_batch'] labels = data_batch['labels'] # sent1_org_len_batch = data_batch['sent1_org_len_batch'] # sent2_org_len_batch = data_batch['sent2_org_len_batch'] else: input_ids_batch, token_type_ids_batch, attention_mask_batch, sep_index_batch, sent1_len_batch, \ sent2_len_batch, labels = input_data last_hidden_states_batch, pooled_output = self.bert( input_ids_batch, token_type_ids_batch, attention_mask_batch) pooled_output = self.dropout(pooled_output) sent1_states_batch = [] sent2_states_batch = [] for i, hidden_states in enumerate(last_hidden_states_batch): sent1_word_piece_flags = word_piece_flags_batch[i][ 1:sep_index_batch[i]] sent1_states = hidden_states[1:sep_index_batch[i]] sent2_word_piece_flags = word_piece_flags_batch[ i][sep_index_batch[i] + 1:sep_index_batch[i] + 1 + sent2_len_batch[i]] sent2_states = hidden_states[sep_index_batch[i] + 1:sep_index_batch[i] + 1 + sent2_len_batch[i]] if len(sent1_states) != sent1_len_batch[i] or len( sent2_states) != sent2_len_batch[i]: raise ValueError if len(sent1_states) + len( sent2_states) + 3 != attention_mask_batch[i].sum(): raise ValueError if len(word_piece_flags_batch[i]) != attention_mask_batch[i].sum(): raise ValueError sent1_states = self.merge_reps_of_word_pieces( sent1_word_piece_flags, sent1_states) # if len(sent1_states) != sent1_org_len_batch[i]: # raise ValueError sent1_states = data_tool.padding_tensor( sent1_states, self.arg_dict['max_sentence_length'], align_dir='left', dim=0) sent2_states = self.merge_reps_of_word_pieces( sent2_word_piece_flags, sent2_states) # if len(sent2_states) != sent2_org_len_batch[i]: # raise ValueError sent2_states = data_tool.padding_tensor( sent2_states, self.arg_dict['max_sentence_length'], align_dir='left', dim=0) sent1_states_batch.append(sent1_states) sent2_states_batch.append(sent2_states) sent1_states_batch = torch.stack(sent1_states_batch, dim=0) sent2_states_batch = torch.stack(sent2_states_batch, dim=0) result = self.semantic_layer(sent1_states_batch, sent2_states_batch) result = torch.cat([pooled_output, result], dim=1) result = self.fully_connection(result) loss = torch.nn.CrossEntropyLoss()(result.view(-1, 2), labels.view(-1)) predicts = np.array(result.detach().cpu().numpy()).argmax(axis=1) return loss, predicts def get_regular_parts(self): regular_part_list = (self.fully_connection, self.bert) regular_factor_list = (self.arg_dict['fully_regular'], self.arg_dict['bert_regular']) return regular_part_list, regular_factor_list def get_input_of_visualize_model(self, example_ids, example_dict): data_batch = self.deal_with_example_batch(example_ids[0:1], example_dict) input_ids_batch = data_batch['input_ids_batch'] token_type_ids_batch = data_batch['token_type_ids_batch'] attention_mask_batch = data_batch['attention_mask_batch'] sep_index_batch = torch.tensor(data_batch['sep_index_batch'], device=self.device) sent1_len_batch = torch.tensor(data_batch['sent1_len_batch'], device=self.device) sent2_len_batch = torch.tensor(data_batch['sent2_len_batch'], device=self.device) labels = data_batch['labels'] input_data = input_ids_batch, token_type_ids_batch, attention_mask_batch, sep_index_batch, sent1_len_batch, \ sent2_len_batch, labels return input_data def count_of_parameter(self): with torch.no_grad(): self.cpu() model_list = [self, self.bert, self.fully_connection] parameter_counts = [] weight_counts = [] bias_counts = [] parameter_list = [] weights_list = [] bias_list = [] for model_ in model_list: parameters_temp = model_.named_parameters() weights_list.clear() parameter_list.clear() bias_list.clear() for name, p in parameters_temp: # print(name) parameter_list.append(p.reshape(-1)) if name.find('weight') != -1: weights_list.append(p.reshape(-1)) if name.find('bias') != -1: bias_list.append(p.reshape(-1)) parameters = torch.cat(parameter_list, dim=0) weights = torch.cat(weights_list, dim=0) biases = torch.cat(bias_list, dim=0) parameter_counts.append(len(parameters)) weight_counts.append(len(weights)) bias_counts.append(len(biases)) for p_count, w_count, b_count in zip(parameter_counts, weight_counts, bias_counts): if p_count != w_count + b_count: raise ValueError for kind in (parameter_counts, weight_counts, bias_counts): total = kind[0] others = kind[1:] count_temp = 0 for other in others: count_temp += other if total != count_temp: raise ValueError self.to(self.device) result = [ { 'name': 'entire', 'total': parameter_counts[0], 'weight': weight_counts[0], 'bias': bias_counts[0] }, { 'name': 'bert', 'total': parameter_counts[1], 'weight': weight_counts[1], 'bias': bias_counts[1] }, { 'name': 'fully', 'total': parameter_counts[2], 'weight': weight_counts[2], 'bias': bias_counts[2] }, ] return result
def save_data(self): test_file = file_tool.connect_path(self.data_path, 'test.tsv') train_file = file_tool.connect_path(self.data_path, 'train.tsv') self.__create_new_data_set__(self.train_example_list, train_file) self.__create_new_data_set__(self.test_example_list, test_file)
class LSyE(fr.LSSE): name = "LSyE" result_path = file_tool.connect_path('result', name) def __init__(self, arg_dict): super().__init__(arg_dict) self.name = LSyE.name self.result_path = LSyE.result_path @classmethod def framework_name(cls): return cls.name def create_arg_dict(self): arg_dict = { # 'sgd_momentum': 0.4, 'semantic_compare_func': 'l2', 'concatenate_input_for_gcn_hidden': True, 'fully_scales': [768, 2], 'position_encoding': True, # 'fully_regular': 1e-4, # 'gcn_regular': 1e-4, # 'bert_regular': 1e-4, 'gcn_layer': 2, 'group_layer_limit_flag': False, # 'group_layer_limit_list': [2, 3, 4, 5, 6], 'gcn_gate_flag': True, 'gcn_norm_item': 0.5, 'gcn_self_loop_flag': True, 'gcn_hidden_dim': 768, 'bert_hidden_dim': 768, 'pad_on_right': True, 'sentence_max_len_for_bert': 128, 'dtype': torch.float32, } return arg_dict def create_models(self): self.bert = BertBase() self.gcn = GCN(self.arg_dict) self.semantic_layer = SemanticLayer(self.arg_dict) self.fully_connection = FullyConnection(self.arg_dict) self.gcn.apply(self.init_weights) self.fully_connection.apply(self.init_weights) def forward(self, *input_data, **kwargs): if len(kwargs) > 0: # common run or visualization data_batch = kwargs input_ids_batch = data_batch['input_ids_batch'] token_type_ids_batch = data_batch['token_type_ids_batch'] attention_mask_batch = data_batch['attention_mask_batch'] sep_index_batch = data_batch['sep_index_batch'] word_piece_flags_batch = data_batch['word_piece_flags_batch'] sent1_len_batch = data_batch['sent1_len_batch'] adj_matrix1_batch = data_batch['adj_matrix1_batch'] sent2_len_batch = data_batch['sent2_len_batch'] adj_matrix2_batch = data_batch['adj_matrix2_batch'] labels = data_batch['labels'] sent1_org_len_batch = data_batch['sent1_org_len_batch'] sent2_org_len_batch = data_batch['sent2_org_len_batch'] else: input_ids_batch, token_type_ids_batch, attention_mask_batch, sep_index_batch, sent1_len_batch, \ adj_matrix1_batch, sent2_len_batch, adj_matrix2_batch, labels = input_data last_hidden_states_batch, pooled_output = self.bert( input_ids_batch, token_type_ids_batch, attention_mask_batch) sent1_states_batch = [] sent2_states_batch = [] for i, hidden_states in enumerate(last_hidden_states_batch): sent1_word_piece_flags = word_piece_flags_batch[i][ 1:sep_index_batch[i]] sent1_states = hidden_states[1:sep_index_batch[i]] sent2_word_piece_flags = word_piece_flags_batch[ i][sep_index_batch[i] + 1:sep_index_batch[i] + 1 + sent2_len_batch[i]] sent2_states = hidden_states[sep_index_batch[i] + 1:sep_index_batch[i] + 1 + sent2_len_batch[i]] if len(sent1_states) != sent1_len_batch[i] or len( sent2_states) != sent2_len_batch[i]: raise ValueError if len(sent1_states) + len( sent2_states) + 3 != attention_mask_batch[i].sum(): raise ValueError if len(word_piece_flags_batch[i]) != attention_mask_batch[i].sum(): raise ValueError sent1_states = self.merge_reps_of_word_pieces( sent1_word_piece_flags, sent1_states) if len(sent1_states) != sent1_org_len_batch[i]: raise ValueError sent1_states = data_tool.padding_tensor( sent1_states, self.arg_dict['max_sentence_length'], align_dir='left', dim=0) sent2_states = self.merge_reps_of_word_pieces( sent2_word_piece_flags, sent2_states) if len(sent2_states) != sent2_org_len_batch[i]: raise ValueError sent2_states = data_tool.padding_tensor( sent2_states, self.arg_dict['max_sentence_length'], align_dir='left', dim=0) sent1_states_batch.append(sent1_states) sent2_states_batch.append(sent2_states) sent1_states_batch = torch.stack(sent1_states_batch, dim=0) sent2_states_batch = torch.stack(sent2_states_batch, dim=0) def get_position_es(shape): position_encodings = general_tool.get_global_position_encodings( length=self.arg_dict['max_sentence_length'], dimension=self.arg_dict['bert_hidden_dim']) position_encodings = position_encodings[:shape[1]] position_encodings = torch.tensor(position_encodings, dtype=self.data_type, device=self.device).expand( [shape[0], -1, -1]) return position_encodings if self.arg_dict['position_encoding']: shape1 = sent1_states_batch.size() position_es1 = get_position_es(shape1) shape2 = sent2_states_batch.size() position_es2 = get_position_es(shape2) sent1_states_batch += position_es1 sent2_states_batch += position_es2 # star_time = time.time() gcn_out1 = self.gcn(sent1_states_batch, adj_matrix1_batch) gcn_out2 = self.gcn(sent2_states_batch, adj_matrix2_batch) if self.arg_dict['concatenate_input_for_gcn_hidden']: gcn_out1 = torch.cat([gcn_out1, sent1_states_batch], dim=2) gcn_out2 = torch.cat([gcn_out2, sent2_states_batch], dim=2) result = self.semantic_layer(gcn_out1, gcn_out2) result = self.fully_connection(result) loss = torch.nn.CrossEntropyLoss()(result.view(-1, 2), labels.view(-1)) predicts = np.array(result.detach().cpu().numpy()).argmax(axis=1) return loss, predicts
class LSSE(fr.Framework): name = "LSSE" result_path = file_tool.connect_path('result', name) def __init__(self, arg_dict): super().__init__(arg_dict) self.name = LSSE.name self.result_path = LSSE.result_path @classmethod def framework_name(cls): return cls.name def create_arg_dict(self): arg_dict = { # 'sgd_momentum': 0.4, 'semantic_compare_func': 'l2', 'concatenate_input_for_gcn_hidden': True, 'fully_scales': [768 * 2, 2], 'position_encoding': True, # 'fully_regular': 1e-4, # 'gcn_regular': 1e-4, # 'bert_regular': 1e-4, 'gcn_layer': 2, 'group_layer_limit_flag': False, # 'group_layer_limit_list': [2, 3, 4, 5, 6], 'gcn_gate_flag': True, 'gcn_norm_item': 0.5, 'gcn_self_loop_flag': True, 'gcn_hidden_dim': 768, 'bert_hidden_dim': 768, 'pad_on_right': True, 'sentence_max_len_for_bert': 128, 'dtype': torch.float32, } return arg_dict def update_arg_dict(self, arg_dict): super().update_arg_dict(arg_dict) if self.arg_dict['concatenate_input_for_gcn_hidden']: self.arg_dict['fully_scales'][0] += self.arg_dict['gcn_hidden_dim'] if self.arg_dict['repeat_train']: time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) if self.arg_dict['group_layer_limit_flag']: gl = self.arg_dict['group_layer_limit_list'] else: gl = self.arg_dict['gcn_layer'] model_dir = file_tool.connect_path( self.result_path, 'train', 'bs:{}-lr:{}-gl:{}--com_fun:{}'.format( self.arg_dict['batch_size'], self.arg_dict['learn_rate'], gl, self.arg_dict['semantic_compare_func']), time_str) else: model_dir = file_tool.connect_path(self.result_path, 'test') file_tool.makedir(model_dir) if not file_tool.check_dir(model_dir): raise RuntimeError self.arg_dict['model_path'] = model_dir def init_weights(self, module): """ Initialize the weights """ if isinstance(module, (torch.nn.Linear)): module.weight.data.normal_(mean=0.0, std=self.bert.config.initializer_range) if isinstance(module, torch.nn.Linear) and module.bias is not None: module.bias.data.zero_() def create_models(self): self.bert = BertBase() self.dropout = torch.nn.Dropout(self.bert.config.hidden_dropout_prob) self.gcn = GCN(self.arg_dict) self.semantic_layer = SemanticLayer(self.arg_dict) self.fully_connection = FullyConnection(self.arg_dict) self.gcn.apply(self.init_weights) self.fully_connection.apply(self.init_weights) def deal_with_example_batch(self, example_ids, example_dict): examples = [example_dict[str(e_id.item())] for e_id in example_ids] sentence_max_len = self.arg_dict['sentence_max_len_for_bert'] pad_on_right = self.arg_dict['pad_on_right'] pad_token = self.bert.tokenizer.convert_tokens_to_ids( [self.bert.tokenizer.pad_token])[0] sep_token = self.bert.tokenizer.convert_tokens_to_ids( [self.bert.tokenizer.sep_token])[0] mask_padding_with_zero = True pad_token_segment_id = 0 sentence1s = [e.sentence1 for e in examples] sentence2s = [e.sentence2 for e in examples] def get_adj_matrix_batch(sentences): adj_matrixs = [] for s in sentences: adj_matrixs.append( parser_tool.dependencies2adj_matrix( s.syntax_info['dependencies'], self.arg_dict['dep_kind_count'], self.arg_dict['max_sentence_length'])) return torch.from_numpy(np.array(adj_matrixs)).to( device=self.device, dtype=self.data_type) adj_matrix1_batch = get_adj_matrix_batch(sentence1s) adj_matrix2_batch = get_adj_matrix_batch(sentence2s) labels = torch.tensor([e.label for e in examples], dtype=torch.long, device=self.device) input_ids_batch = [] token_type_ids_batch = [] attention_mask_batch = [] sep_index_batch = [] sent1_len_batch = [] sent2_len_batch = [] word_piece_flags_batch = [] sent1_org_len_batch = [] sent2_org_len_batch = [] sent1_id_batch = [] for s1, s2 in zip(sentence1s, sentence2s): inputs_ls_cased = self.bert.tokenizer.encode_plus( s1.sentence_with_root_head(), s2.sentence_with_root_head(), add_special_tokens=True, max_length=sentence_max_len, ) input_ids, token_type_ids = inputs_ls_cased[ "input_ids"], inputs_ls_cased["token_type_ids"] word_piece_flags_batch.append( general_tool.word_piece_flag_list( self.bert.tokenizer.convert_ids_to_tokens(input_ids), '##')) attention_mask = [1 if mask_padding_with_zero else 0 ] * len(input_ids) padding_length = sentence_max_len - len(input_ids) if not pad_on_right: input_ids = ([pad_token] * padding_length) + input_ids attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids else: input_ids = input_ids + ([pad_token] * padding_length) attention_mask = attention_mask + ( [0 if mask_padding_with_zero else 1] * padding_length) token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length) input_ids_batch.append(input_ids) token_type_ids_batch.append(token_type_ids) attention_mask_batch.append(attention_mask) sep_indexes = [] for sep_index, id_ in enumerate(input_ids.copy()): if id_ == sep_token: sep_indexes.append(sep_index) if len(sep_indexes) != 2: raise ValueError sep_index_batch.append(sep_indexes[0]) sent1_len_batch.append(sep_indexes[0] - 1) sent2_len_batch.append(sep_indexes[1] - sep_indexes[0] - 1) sent1_org_len_batch.append(s1.len_of_tokens()) sent2_org_len_batch.append(s2.len_of_tokens()) sent1_id_batch.append(s1.id) input_ids_batch = torch.tensor(input_ids_batch, device=self.device) token_type_ids_batch = torch.tensor(token_type_ids_batch, device=self.device) attention_mask_batch = torch.tensor(attention_mask_batch, device=self.device) result = { 'input_ids_batch': input_ids_batch, 'token_type_ids_batch': token_type_ids_batch, 'attention_mask_batch': attention_mask_batch, 'sep_index_batch': sep_index_batch, 'word_piece_flags_batch': word_piece_flags_batch, 'sent1_org_len_batch': sent1_org_len_batch, 'sent1_len_batch': sent1_len_batch, 'adj_matrix1_batch': adj_matrix1_batch, 'sent2_org_len_batch': sent2_org_len_batch, 'sent2_len_batch': sent2_len_batch, 'adj_matrix2_batch': adj_matrix2_batch, 'labels': labels, 'sent1_id_batch': sent1_id_batch } return result def merge_reps_of_word_pieces(self, word_piece_flags, token_reps): result_reps = [] word_piece_label = False word_piece_rep = 0 word_piece_count = 0 for i, flag in enumerate(word_piece_flags): if flag == 1: word_piece_rep += token_reps[i] word_piece_count += 1 word_piece_label = True else: if word_piece_label: if word_piece_count == 0: raise ValueError result_reps.append(word_piece_rep / word_piece_count) word_piece_rep = 0 word_piece_count = 0 result_reps.append(token_reps[i]) word_piece_label = False if word_piece_label and (word_piece_count > 0): result_reps.append(word_piece_rep / word_piece_count) result_reps = torch.stack(result_reps, dim=0) return result_reps def forward(self, *input_data, **kwargs): if len(kwargs) > 0: # common run or visualization data_batch = kwargs input_ids_batch = data_batch['input_ids_batch'] token_type_ids_batch = data_batch['token_type_ids_batch'] attention_mask_batch = data_batch['attention_mask_batch'] sep_index_batch = data_batch['sep_index_batch'] word_piece_flags_batch = data_batch['word_piece_flags_batch'] sent1_len_batch = data_batch['sent1_len_batch'] adj_matrix1_batch = data_batch['adj_matrix1_batch'] sent2_len_batch = data_batch['sent2_len_batch'] adj_matrix2_batch = data_batch['adj_matrix2_batch'] labels = data_batch['labels'] # sent1_org_len_batch = data_batch['sent1_org_len_batch'] # sent2_org_len_batch = data_batch['sent2_org_len_batch'] else: input_ids_batch, token_type_ids_batch, attention_mask_batch, sep_index_batch, sent1_len_batch, \ adj_matrix1_batch, sent2_len_batch, adj_matrix2_batch, labels = input_data last_hidden_states_batch, pooled_output = self.bert( input_ids_batch, token_type_ids_batch, attention_mask_batch) pooled_output = self.dropout(pooled_output) sent1_states_batch = [] sent2_states_batch = [] for i, hidden_states in enumerate(last_hidden_states_batch): sent1_word_piece_flags = word_piece_flags_batch[i][ 1:sep_index_batch[i]] sent1_states = hidden_states[1:sep_index_batch[i]] sent2_word_piece_flags = word_piece_flags_batch[ i][sep_index_batch[i] + 1:sep_index_batch[i] + 1 + sent2_len_batch[i]] sent2_states = hidden_states[sep_index_batch[i] + 1:sep_index_batch[i] + 1 + sent2_len_batch[i]] if len(sent1_states) != sent1_len_batch[i] or len( sent2_states) != sent2_len_batch[i]: raise ValueError if len(sent1_states) + len( sent2_states) + 3 != attention_mask_batch[i].sum(): raise ValueError if len(word_piece_flags_batch[i]) != attention_mask_batch[i].sum(): raise ValueError # sent1_states_temp = torch.tensor(sent1_states) sent1_states = self.merge_reps_of_word_pieces( sent1_word_piece_flags, sent1_states) # if len(sent1_states) != sent1_org_len_batch[i]: # raise ValueError sent1_states = data_tool.padding_tensor( sent1_states, self.arg_dict['max_sentence_length'], align_dir='left', dim=0) sent2_states = self.merge_reps_of_word_pieces( sent2_word_piece_flags, sent2_states) # if len(sent2_states) != sent2_org_len_batch[i]: # raise ValueError sent2_states = data_tool.padding_tensor( sent2_states, self.arg_dict['max_sentence_length'], align_dir='left', dim=0) sent1_states_batch.append(sent1_states) sent2_states_batch.append(sent2_states) sent1_states_batch = torch.stack(sent1_states_batch, dim=0) sent2_states_batch = torch.stack(sent2_states_batch, dim=0) def get_position_es(shape): position_encodings = general_tool.get_global_position_encodings( length=self.arg_dict['max_sentence_length'], dimension=self.arg_dict['bert_hidden_dim']) position_encodings = position_encodings[:shape[1]] position_encodings = torch.tensor(position_encodings, dtype=self.data_type, device=self.device).expand( [shape[0], -1, -1]) return position_encodings if self.arg_dict['position_encoding']: shape1 = sent1_states_batch.size() position_es1 = get_position_es(shape1) shape2 = sent2_states_batch.size() position_es2 = get_position_es(shape2) sent1_states_batch += position_es1 sent2_states_batch += position_es2 # star_time = time.time() gcn_out1 = self.gcn(sent1_states_batch, adj_matrix1_batch) gcn_out2 = self.gcn(sent2_states_batch, adj_matrix2_batch) if self.arg_dict['concatenate_input_for_gcn_hidden']: gcn_out1 = torch.cat([gcn_out1, sent1_states_batch], dim=2) gcn_out2 = torch.cat([gcn_out2, sent2_states_batch], dim=2) result = self.semantic_layer(gcn_out1, gcn_out2) result = torch.cat([pooled_output, result], dim=1) result = self.fully_connection(result) loss = torch.nn.CrossEntropyLoss()(result.view(-1, 2), labels.view(-1)) predicts = np.array(result.detach().cpu().numpy()).argmax(axis=1) return loss, predicts def get_regular_parts(self): regular_part_list = (self.gcn, self.fully_connection, self.bert) regular_factor_list = (self.arg_dict['gcn_regular'], self.arg_dict['fully_regular'], self.arg_dict['bert_regular']) return regular_part_list, regular_factor_list def get_input_of_visualize_model(self, example_ids, example_dict): data_batch = self.deal_with_example_batch(example_ids[0:1], example_dict) input_ids_batch = data_batch['input_ids_batch'] token_type_ids_batch = data_batch['token_type_ids_batch'] attention_mask_batch = data_batch['attention_mask_batch'] sep_index_batch = torch.tensor(data_batch['sep_index_batch'], device=self.device) sent1_len_batch = torch.tensor(data_batch['sent1_len_batch'], device=self.device) adj_matrix1_batch = data_batch['adj_matrix1_batch'] sent2_len_batch = torch.tensor(data_batch['sent2_len_batch'], device=self.device) adj_matrix2_batch = data_batch['adj_matrix2_batch'] labels = data_batch['labels'] input_data = (input_ids_batch, token_type_ids_batch, attention_mask_batch, sep_index_batch, sent1_len_batch, adj_matrix1_batch, sent2_len_batch, adj_matrix2_batch, labels) return input_data def count_of_parameter(self): with torch.no_grad(): self.cpu() model_list = [self, self.bert, self.gcn, self.fully_connection] parameter_counts = [] weight_counts = [] bias_counts = [] parameter_list = [] weights_list = [] bias_list = [] for model_ in model_list: parameters_temp = model_.named_parameters() weights_list.clear() parameter_list.clear() bias_list.clear() for name, p in parameters_temp: # print(name) parameter_list.append(p.reshape(-1)) if name.find('weight') != -1: weights_list.append(p.reshape(-1)) if name.find('bias') != -1: bias_list.append(p.reshape(-1)) parameters = torch.cat(parameter_list, dim=0) weights = torch.cat(weights_list, dim=0) biases = torch.cat(bias_list, dim=0) parameter_counts.append(len(parameters)) weight_counts.append(len(weights)) bias_counts.append(len(biases)) for p_count, w_count, b_count in zip(parameter_counts, weight_counts, bias_counts): if p_count != w_count + b_count: raise ValueError for kind in (parameter_counts, weight_counts, bias_counts): total = kind[0] others = kind[1:] count_temp = 0 for other in others: count_temp += other if total != count_temp: raise ValueError self.to(self.device) result = [ { 'name': 'entire', 'total': parameter_counts[0], 'weight': weight_counts[0], 'bias': bias_counts[0] }, { 'name': 'bert', 'total': parameter_counts[1], 'weight': weight_counts[1], 'bias': bias_counts[1] }, { 'name': 'gcn', 'total': parameter_counts[2], 'weight': weight_counts[2], 'bias': bias_counts[2] }, { 'name': 'fully', 'total': parameter_counts[3], 'weight': weight_counts[3], 'bias': bias_counts[3] }, ] return result