def main():
    """Pre-train the knowledge-enhanced ``BertForPreTraining`` model on CPU.

    Steps:
      1. Parse the command-line arguments declared below.
      2. Validate the flags, seed ``random`` / ``numpy`` / ``torch`` and create
         ``--output_dir`` (it must be missing or empty).
      3. Load entity vectors from ``kg_embed/entity2vec.vec`` (one
         tab-separated vector per line) into a frozen ``torch.nn.Embedding``.
         Row 0 is an all-zero vector; entity ids are shifted by ``+ 1`` before
         every lookup so that id ``-1`` ("no entity") maps to that row.
      4. With ``--do_train``, iterate over
         ``indexed_dataset.IndexedDataset(args.data_dir)`` in random batches
         and train with ``BertAdam`` plus a manual linear-warmup learning rate,
         writing "<loss> <original_loss>" per step to
         ``<output_dir>/loss.<timestamp>``.
      5. Save the model ``state_dict`` to ``<output_dir>/pytorch_model.bin``.

    Raises:
        ValueError: if ``--gradient_accumulation_steps`` < 1, if neither
            ``--do_train`` nor ``--do_eval`` is given, or if ``--output_dir``
            already exists and is not empty.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--task_name", default=None, type=str, required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train", default=False, action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", default=False, action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", default=False, action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=8, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    args = parser.parse_args()

    # This script is hard-wired to a single CPU process: no distributed
    # training, no GPUs, no fp16.
    args.local_rank = -1
    device = torch.device("cpu")
    n_gpu = 0
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), 'false'))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    # From here on train_batch_size is the per-forward micro-batch; the
    # gradients of gradient_accumulation_steps micro-batches are summed before
    # each optimizer step.
    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    # Entity embedding table. Row 0 is the all-zero placeholder ("CLS"); all
    # lookups below use `ids + 1`, so entity id -1 maps to this row.
    vecs = [[0] * 100]
    with open("kg_embed/entity2vec.vec", 'r') as fin:
        for line in fin:
            vecs.append([float(x) for x in line.strip().split('\t')])
    embed = torch.nn.Embedding.from_pretrained(torch.FloatTensor(vecs))
    logger.info("Shape of entity embedding: " + str(embed.weight.size()))
    del vecs

    train_data = None
    num_train_steps = None
    if args.do_train:
        import indexed_dataset
        import iterators
        from torch.utils.data import BatchSampler, RandomSampler

        train_data = indexed_dataset.IndexedDataset(args.data_dir, fix_lua_indexing=True)
        train_sampler = RandomSampler(train_data)
        # drop_last=True: every batch has exactly train_batch_size rows.
        train_sampler = BatchSampler(train_sampler, args.train_batch_size, True)

        def collate_fn(samples):
            """Split a batch of flat rows into the model's nine inputs.

            Each row is a concatenation of ``max_seq_length``-long segments
            (L = max_seq_length): [0:L] input_ids, [L:2L] input_mask,
            [2L:3L] segment_ids, [3L:4L] masked_lm_labels, [4L:5L] entity
            ids (-1 = none), [6L:] next_sentence_label. Segment [5L:6L] is
            not used here.
            """
            seq_len = args.max_seq_length
            x = torch.LongTensor(list(samples))
            # A view into x: the in-place corruption below also changes x,
            # but that column range is not returned.
            entity_idx = x[:, 4 * seq_len:5 * seq_len]

            # Candidate set: the distinct entity ids of this batch (sorted by np.unique).
            uniq_idx = np.unique(entity_idx.numpy())
            ent_candidate = embed(torch.LongTensor(uniq_idx + 1))
            # The repeat count follows n_gpu (presumably one copy per GPU
            # replica). On this CPU-only path n_gpu == 0, and repeat([0, 1])
            # would return an empty (0, dim) tensor, leaving the model with no
            # candidates; repeat at least once.
            ent_candidate = ent_candidate.repeat([max(n_gpu, 1), 1])

            # Entity labels: position of each entity id inside ent_candidate.
            id_to_row = {idx: i for i, idx in enumerate(uniq_idx)}
            row_to_id = list(uniq_idx)
            ent_size = len(uniq_idx) - 1

            def corrupt_entity(ent_id):
                # Leave "no entity" (-1) alone. Otherwise, with probability
                # 0.05, replace it by an id drawn from row_to_id[1:]; with
                # probability 0.15, mask it to -1; else keep it.
                if ent_id == -1:
                    return -1
                rnd = random.uniform(0, 1)
                if rnd < 0.05:
                    return row_to_id[random.randint(1, ent_size)]
                if rnd < 0.2:
                    return -1
                return ent_id

            # Labels are taken from the ids *before* corruption.
            ent_labels = entity_idx.clone()
            id_to_row[-1] = -1
            ent_labels = ent_labels.apply_(lambda v: id_to_row[v])
            entity_idx.apply_(corrupt_entity)
            ent_emb = embed(entity_idx + 1)
            mask = entity_idx.clone()
            mask.apply_(lambda v: 0 if v == -1 else 1)
            mask[:, 0] = 1  # position 0 is always unmasked
            return (x[:, :seq_len],
                    x[:, seq_len:2 * seq_len],
                    x[:, 2 * seq_len:3 * seq_len],
                    x[:, 3 * seq_len:4 * seq_len],
                    ent_emb,
                    mask,
                    x[:, 6 * seq_len:],
                    ent_candidate,
                    ent_labels)

        train_iterator = iterators.EpochBatchIterator(train_data, collate_fn, train_sampler)
        num_train_steps = int(
            len(train_data) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    model, _missing_keys = BertForPreTraining.from_pretrained(
        args.bert_model, cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(-1))
    model.to(device)

    # Prepare optimizer: the layer-11 entity-specific parameters listed here
    # are excluded from optimization.
    no_linear = ['layer.2.output.dense_ent',
                 'layer.2.intermediate.dense_1',
                 'bert.encoder.layer.2.intermediate.dense_1_ent',
                 'layer.2.output.LayerNorm_ent']
    no_linear = [x.replace('2', '11') for x in no_linear]
    param_optimizer = [(n, p) for n, p in model.named_parameters()
                       if not any(nl in n for nl in no_linear)]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'LayerNorm_ent.bias', 'LayerNorm_ent.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    t_total = num_train_steps
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=t_total)

    global_step = 0
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info(" Num examples = %d", len(train_data))
        logger.info(" Batch size = %d", args.train_batch_size)
        logger.info(" Num steps = %d", num_train_steps)
        model.train()
        import datetime
        loss_path = os.path.join(args.output_dir, "loss.{}".format(datetime.datetime.now()))
        # Context manager: the loss log is closed even if training raises.
        with open(loss_path, 'w') as fout:
            for _ in trange(int(args.num_train_epochs), desc="Epoch"):
                for step, batch in enumerate(tqdm(train_iterator.next_epoch_itr(), desc="Iteration")):
                    batch = tuple(t.to(device) for t in batch)
                    (input_ids, input_mask, segment_ids, masked_lm_labels, input_ent, ent_mask,
                     next_sentence_label, ent_candidate, ent_labels) = batch
                    loss, original_loss = model(input_ids, segment_ids, input_mask, masked_lm_labels,
                                                input_ent, ent_mask, next_sentence_label,
                                                ent_candidate, ent_labels)
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps
                    loss.backward()
                    fout.write("{} {}\n".format(loss.item() * args.gradient_accumulation_steps,
                                                original_loss.item()))
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        # modify learning rate with special warm up BERT uses
                        lr_this_step = args.learning_rate * warmup_linear(global_step / t_total,
                                                                          args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                        optimizer.step()
                        optimizer.zero_grad()
                        global_step += 1

    # Save a trained model (unwrap a `.module` wrapper if present)
    model_to_save = model.module if hasattr(model, 'module') else model
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    torch.save(model_to_save.state_dict(), output_model_file)
def train(data_obj, dname, args, embed, model):
    """Train the sequential (global) entity-linking model on dataset ``dname``.

    Features are built through ``data_obj.process_global_data``,
    ``get_local_feature_input`` and ``get_global_feature_input``. Each batch
    row holds a window of ``args.seq_len`` mentions that are linked one after
    another. After mention ``i`` is scored, the arg-max candidate's entity id
    is written into the context entity sequence at that mention's position,
    and its entity mask is set to 1. Later mentions in the same window
    therefore see the earlier decisions.

    Every 100 optimizer steps the model is evaluated with ``predict`` on
    several datasets. The weights are saved to
    ``<output_dir>/pytorch_model_nolocal_<step>.bin`` whenever the ``aida-A``
    F1 improves. The process exits (``SystemExit(0)``) after more than 3
    consecutive evaluations without improvement.

    Per-step losses are written to ``<output_dir>/loss`` and F1 scores to
    ``<output_dir>/result_f1``.

    Relies on module-level ``device``, ``n_gpu``, ``BertAdam``,
    ``warmup_linear``, ``predict`` and ``logger`` defined elsewhere.

    Args:
        data_obj: feature builder exposing the three methods named above.
        dname: name of the training dataset; used to build input file paths.
        args: parsed arguments (data_dir, local_rep_dir, group_path, seq_len,
            candidate_entity_num, train_batch_size, learning_rate,
            warmup_proportion, gradient_accumulation_steps,
            num_train_epochs, output_dir).
        embed: entity embedding module, looked up with ``ids + 1`` on CPU
            tensors.
        model: the linking model; called once per mention and returns
            ``(loss, scores)``.
    """
    data_path = args.data_dir + dname + '_mention_rank'
    local_rep_path = args.local_rep_dir + dname + '_local_rep_mention_rank.npy'
    local_fea_path = args.local_rep_dir + dname + '_local_fea_mention_rank.npy'
    group_path = args.group_path
    (mentions, entities, local_feas, ment_names, ment_sents, ment_offsets,
     ent_ids, mtypes, etypes, pems, labels) = data_obj.process_global_data(
         dname, data_path, local_rep_path, group_path, local_fea_path,
         args.seq_len, args.candidate_entity_num)
    # The 8th output (per-candidate labels) was only used by a BCE-loss
    # variant; the hinge-loss path below uses seq_label_np instead.
    (mention_seq_np, entity_seq_np, local_fea_np, entid_seq_np, pem_seq_np,
     mtype_seq_np, etype_seq_np, _label_seq_np) = data_obj.get_local_feature_input(
         mentions, entities, local_feas, ent_ids, mtypes, etypes, pems, labels,
         args.seq_len, args.candidate_entity_num)
    (seq_tokens_np, seq_tokens_mask_np, seq_tokens_segment_np, seq_ents_np,
     seq_ents_mask_np, seq_ents_index_np, seq_label_np) = data_obj.get_global_feature_input(
         ment_names, ment_sents, ment_offsets, ent_ids, labels,
         args.seq_len, args.candidate_entity_num)

    # Shape notes follow the original author's annotations, which were written
    # for 3 mentions per window and (presumably) 6 candidates per mention.
    # Window-level context, (num_example, ctx_len):
    all_seq_input_id = torch.tensor(seq_tokens_np, dtype=torch.long)
    all_seq_input_mask = torch.tensor(seq_tokens_mask_np, dtype=torch.long)
    all_seq_segment_id = torch.tensor(seq_tokens_segment_np, dtype=torch.long)
    all_seq_input_ent = torch.tensor(seq_ents_np, dtype=torch.long)
    all_seq_ent_mask = torch.tensor(seq_ents_mask_np, dtype=torch.long)
    # Per-mention targets for the hinge loss, (num_example, seq_len):
    all_seq_label = torch.tensor(seq_label_np, dtype=torch.long)
    # Mention / candidate representations and features:
    all_seq_mention_rep = torch.tensor(mention_seq_np, dtype=torch.float)   # (num_example, seq_len, 768)
    all_seq_entity_rep = torch.tensor(entity_seq_np, dtype=torch.float)     # (num_example, seq_len, cand, 768)
    all_seq_entid = torch.tensor(entid_seq_np, dtype=torch.long)            # (num_example, seq_len, cand): candidate entity ids
    all_seq_ent_index = torch.tensor(seq_ents_index_np, dtype=torch.long)   # (num_example, seq_len): mention positions in the context
    all_seq_pem = torch.tensor(pem_seq_np, dtype=torch.float)               # (num_example, seq_len, cand)
    all_seq_mtype = torch.tensor(mtype_seq_np, dtype=torch.float)           # (num_example, seq_len, cand, 4)
    all_seq_etype = torch.tensor(etype_seq_np, dtype=torch.float)           # (num_example, seq_len, cand, 4)
    all_seq_local_fea = torch.tensor(local_fea_np, dtype=torch.float)

    # The tuple order matters: index 3 (input_ent) is kept on CPU below.
    train_data = TensorDataset(all_seq_input_id, all_seq_input_mask, all_seq_segment_id, all_seq_input_ent,
                               all_seq_ent_mask, all_seq_ent_index, all_seq_label,
                               all_seq_mention_rep, all_seq_entity_rep, all_seq_entid,
                               all_seq_pem, all_seq_mtype, all_seq_etype, all_seq_local_fea)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

    # Prepare optimizer (layer-11 entity output projection and LayerNorm are not trained)
    no_grad = [
        'bert.encoder.layer.11.output.dense_ent',
        'bert.encoder.layer.11.output.LayerNorm_ent'
    ]
    param_optimizer = [(n, p) for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_grad)]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]

    # Schedule horizon. A backward pass happens once per *mention*
    # (args.seq_len per batch), and an optimizer step every
    # gradient_accumulation_steps of those. The previous batch-based count
    # let global_step / t_total run far past 1. The classic BERT
    # warmup_linear returns `1 - x` there, i.e. a negative learning rate.
    steps_per_epoch = len(train_dataloader) * args.seq_len // args.gradient_accumulation_steps
    num_train_steps = steps_per_epoch * int(args.num_train_epochs)
    t_total = num_train_steps
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=t_total)

    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(seq_tokens_np))
    logger.info(" Batch size = %d", args.train_batch_size)
    logger.info(" Num steps = %d", num_train_steps)

    output_loss_file = os.path.join(args.output_dir, "loss")
    output_f1_file = os.path.join(args.output_dir, "result_f1")
    # Context managers: both logs are flushed and closed on normal return,
    # on errors, and on the early-stopping SystemExit below.
    with open(output_loss_file, 'w') as loss_fout, open(output_f1_file, 'w') as f1_fout:
        model.train()
        global_step = 0
        best_f1 = -1
        not_better_count = 0
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss, nb_tr_steps = 0, 0
            for batch in tqdm(train_dataloader, desc="Iteration"):
                # Everything except the context entity ids (index 3) goes to
                # `device`; those ids stay on CPU for the `embed` lookups.
                batch = tuple(t.to(device) if i != 3 else t for i, t in enumerate(batch))
                (seq_input_id, seq_input_mask, seq_segment_id, seq_input_ent,
                 seq_ent_mask, seq_ent_index, seq_label,
                 seq_mention_rep, seq_entity_rep, seq_entid,
                 seq_pem, seq_mtype, seq_etype, seq_local_fea) = batch

                # Window context shared by all its mentions. The entity ids
                # and mask are updated with every prediction below.
                current_input_id_batch = seq_input_id           # (b, ctx_len)
                current_input_mask_batch = seq_input_mask       # (b, ctx_len)
                current_segment_id_batch = seq_segment_id       # (b, ctx_len)
                current_input_ent_batch = seq_input_ent         # (b, ctx_len), CPU
                current_input_ent_embed_batch = embed(seq_input_ent + 1).to(device)  # (b, ctx_len, dim)
                current_ent_mask_batch = seq_ent_mask           # (b, ctx_len)
                batch_size = current_input_id_batch.size(0)

                # Link the mentions of the window one after another.
                for mention_index in range(args.seq_len):
                    current_label_batch = seq_label[:, mention_index]                 # (b,)
                    current_mention_rep_batch = seq_mention_rep[:, mention_index, :]  # (b, 768)
                    current_entity_rep_batch = seq_entity_rep[:, mention_index, :, :] # (b, cand, 768)
                    current_pem_batch = seq_pem[:, mention_index, :]                  # (b, cand)
                    current_mtype_batch = seq_mtype[:, mention_index, :, :]           # (b, cand, 4)
                    current_etype_batch = seq_etype[:, mention_index, :, :]           # (b, cand, 4)
                    current_local_fea_batch = seq_local_fea[:, mention_index, :]
                    current_entid_batch = seq_entid[:, mention_index, :]              # (b, cand)
                    current_ent_index_batch = seq_ent_index[:, mention_index]         # (b,)
                    current_entid_embed_batch = embed(current_entid_batch.cpu() + 1).to(device)  # (b, cand, dim)

                    loss, scores = model(current_input_id_batch, current_segment_id_batch,
                                         current_input_mask_batch, current_input_ent_embed_batch,
                                         current_ent_mask_batch, current_entid_embed_batch,
                                         current_label_batch, current_mention_rep_batch,
                                         current_entity_rep_batch, current_pem_batch,
                                         current_mtype_batch, current_etype_batch,
                                         current_local_fea_batch)
                    if n_gpu > 1:
                        loss = loss.mean()  # mean() to average on multi-gpu.
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    # Feed the predicted entity back into the context: put its
                    # id at the mention's position and unmask that position.
                    # Keep the updated ids (not only the mask) so that later
                    # mentions see *all* earlier predictions, as intended.
                    pred_ids = torch.argmax(scores, dim=1).reshape(batch_size, 1)  # (b, 1)
                    pred_entid = torch.gather(current_entid_batch, 1, pred_ids)    # (b, 1)
                    pred_entmask = torch.ones_like(pred_entid)                     # (b, 1)
                    ent_pos = current_ent_index_batch.reshape(batch_size, 1)
                    current_input_ent_batch = current_input_ent_batch.scatter(
                        1, ent_pos.cpu(), pred_entid.cpu())
                    current_input_ent_embed_batch = embed(current_input_ent_batch + 1).to(device)
                    current_ent_mask_batch.scatter_(1, ent_pos, pred_entmask)

                    loss.backward()
                    loss_fout.write("{}\n".format(loss.item() * args.gradient_accumulation_steps))
                    tr_loss += loss.item()
                    nb_tr_steps += 1
                    # nb_tr_steps is already incremented, so step after every
                    # gradient_accumulation_steps backward passes.
                    if nb_tr_steps % args.gradient_accumulation_steps == 0:
                        # modify learning rate with special warm up BERT uses
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / t_total, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                        optimizer.step()
                        optimizer.zero_grad()
                        global_step += 1

                        if global_step % 100 == 0:
                            print('global_step: ', global_step, 'global_step loss: ', tr_loss / nb_tr_steps)
                            dev_f1 = 0
                            eval_dnames = ['aida-A', 'aida-B', 'msnbc', 'aquaint',
                                           'ace2004', 'clueweb', 'wikipedia']
                            for eval_dname in eval_dnames:
                                f1 = predict(data_obj, eval_dname, args, embed, model)
                                # green-coloured console output
                                print(eval_dname, '\033[92m' + 'micro F1: ' + str(f1) + '\033[0m')
                                f1_fout.write("{}, f1: {}, step: {}\n".format(eval_dname, f1, global_step))
                                if eval_dname == 'aida-A':
                                    dev_f1 = f1
                            # NOTE(review): predict() is not visible here and may
                            # switch the model to eval mode; restore training mode
                            # (a no-op if it did not).
                            model.train()

                            if best_f1 < dev_f1:
                                not_better_count = 0
                                best_f1 = dev_f1
                                print('save best model ...')
                                output_model_file = os.path.join(
                                    args.output_dir, "pytorch_model_nolocal_{}.bin".format(global_step))
                                torch.save(model.state_dict(), output_model_file)
                            else:
                                not_better_count += 1
                                if not_better_count > 3:
                                    # Early stopping: terminate the process
                                    # (the `with` blocks close both logs).
                                    raise SystemExit(0)