Example #1
# Module-level imports assumed by this example. The BERT-specific names used
# below (BertForPreTraining, BertAdam, warmup_linear, PYTORCH_PRETRAINED_BERT_CACHE)
# come from the project's modified pytorch_pretrained_bert-style package, whose
# exact import path depends on the repository, so they are not imported here.
import argparse
import logging
import os
import random

import numpy as np
import torch
from tqdm import tqdm, trange

logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        default=False,
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")

    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")

    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()
    args.local_rank = -1
    device = torch.device("cpu")
    n_gpu = 0
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), 'false'))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    task_name = args.task_name.lower()
    
    vecs = []
    vecs.append([0] * 100)  # row 0: zero placeholder ([CLS]); entity ids are shifted by +1 at lookup time
    with open("kg_embed/entity2vec.vec", 'r') as fin:
        for line in fin:
            vec = line.strip().split('\t')
            vec = [float(x) for x in vec]
            vecs.append(vec)
    embed = torch.FloatTensor(vecs)
    embed = torch.nn.Embedding.from_pretrained(embed)
    #embed = torch.nn.Embedding(5041175, 100)

    logger.info("Shape of entity embedding: "+str(embed.weight.size()))
    del vecs

    train_data = None
    num_train_steps = None
    if args.do_train:
        import indexed_dataset
        from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler,BatchSampler
        import iterators
        #train_data = indexed_dataset.IndexedCachedDataset(args.data_dir)
        train_data = indexed_dataset.IndexedDataset(args.data_dir, fix_lua_indexing=True)

        train_sampler = RandomSampler(train_data)
        train_sampler = BatchSampler(train_sampler, args.train_batch_size, True)  # drop_last=True
        def collate_fn(x):
            x = torch.LongTensor([xx for xx in x])
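            # Each example arrives as one flat LongTensor laid out in consecutive
            # max_seq_length-sized segments: input_ids, input_mask, segment_ids,
            # masked_lm_labels, aligned entity ids, one segment this function does
            # not use, and finally the next-sentence labels (cf. the unpacking in
            # the training loop below).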

            entity_idx = x[:, 4*args.max_seq_length:5*args.max_seq_length]
            # Build candidate
            uniq_idx = np.unique(entity_idx.numpy())
            ent_candidate = embed(torch.LongTensor(uniq_idx+1))
            ent_candidate = ent_candidate.repeat([max(n_gpu, 1), 1])  # replicate per GPU; keep one copy when n_gpu == 0 (CPU)
            # build entity labels
            d = {}
            dd = []
            for i, idx in enumerate(uniq_idx):
                d[idx] = i
                dd.append(idx)
            ent_size = len(uniq_idx)-1
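            # Randomly corrupt the aligned entities for the denoising objective:
            # with probability 0.05 replace an entity with a random candidate,
            # with a further 0.15 drop it (-1), and otherwise keep it unchanged.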
            def corrupt_entity(x):
                if x == -1:
                    return -1
                else:
                    rnd = random.uniform(0, 1)
                    if rnd < 0.05:
                        return dd[random.randint(1, ent_size)]
                    elif rnd < 0.2:
                        return -1
                    else:
                        return x
            ent_labels = entity_idx.clone()
            d[-1] = -1
            ent_labels = ent_labels.apply_(lambda x: d[x])

            entity_idx.apply_(corrupt_entity)
            ent_emb = embed(entity_idx+1)
            mask = entity_idx.clone()
            mask.apply_(lambda x: 0 if x == -1 else 1)
            mask[:,0] = 1

            return (x[:, :args.max_seq_length],                          # input_ids
                    x[:, args.max_seq_length:2*args.max_seq_length],     # input_mask
                    x[:, 2*args.max_seq_length:3*args.max_seq_length],   # segment_ids
                    x[:, 3*args.max_seq_length:4*args.max_seq_length],   # masked_lm_labels
                    ent_emb, mask,
                    x[:, 6*args.max_seq_length:],                        # next_sentence_label
                    ent_candidate, ent_labels)
        train_iterator = iterators.EpochBatchIterator(train_data, collate_fn, train_sampler)
        num_train_steps = int(
            len(train_data) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
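    # NOTE: unlike the stock pytorch_pretrained_bert API, this from_pretrained
    # variant also returns the names of parameters missing from the checkpoint.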
    model, missing_keys = BertForPreTraining.from_pretrained(args.bert_model,
              cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(-1))

    model.to(device)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_linear = ['layer.2.output.dense_ent', 'layer.2.intermediate.dense_1', 'bert.encoder.layer.2.intermediate.dense_1_ent', 'layer.2.output.LayerNorm_ent']
    no_linear = [x.replace('2', '11') for x in no_linear]
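    # Exclude the entity-specific sublayers of encoder layer 11 (the '2' in the
    # names above is rewritten to '11') from the optimizer.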
    param_optimizer = [(n, p) for n, p in param_optimizer if not any(nl in n for nl in no_linear)]
    #param_optimizer = [(n, p) for n, p in param_optimizer if not any(nl in n for nl in missing_keys)]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'LayerNorm_ent.bias', 'LayerNorm_ent.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    t_total = num_train_steps
    optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    global_step = 0
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_data))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)
        model.train()
        import datetime
        fout = open(os.path.join(args.output_dir, "loss.{}".format(datetime.datetime.now())), 'w')
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_iterator.next_epoch_itr(), desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)

                input_ids, input_mask, segment_ids, masked_lm_labels, input_ent, ent_mask, next_sentence_label, ent_candidate, ent_labels = batch

                loss, original_loss = model(input_ids, segment_ids, input_mask, masked_lm_labels, input_ent, ent_mask, next_sentence_label, ent_candidate, ent_labels)

                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps


                loss.backward()

                fout.write("{} {}\n".format(loss.item()*args.gradient_accumulation_steps, original_loss.item()))
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses
                    lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
        fout.close()

    # Save a trained model
    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    torch.save(model_to_save.state_dict(), output_model_file)
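
Both examples adjust the learning rate with a warmup_linear helper that is not shown in the listings. Below is a minimal sketch of a schedule consistent with how it is called here (linear warmup followed by linear decay, in the style of early pytorch_pretrained_bert releases); the authors' exact implementation may differ:

def warmup_linear(x, warmup=0.002):
    # x is the fraction of training completed (global_step / t_total):
    # ramp up linearly for the first `warmup` fraction, then decay linearly.
    if x < warmup:
        return x / warmup
    return 1.0 - x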
Example #2
# In addition to the imports from Example #1, this function assumes several
# module-level names defined elsewhere in the script: device, n_gpu, logger,
# TensorDataset / DataLoader / RandomSampler, BertAdam, warmup_linear, and a
# predict() evaluation helper.
def train(data_obj, dname, args, embed, model):

    data_path = args.data_dir + dname + '_mention_rank'
    local_rep_path = args.local_rep_dir + dname + '_local_rep_mention_rank.npy'
    local_fea_path = args.local_rep_dir + dname + '_local_fea_mention_rank.npy'
    group_path = args.group_path

    mentions, entities, local_feas, ment_names, ment_sents, ment_offsets, ent_ids, mtypes, etypes, pems, labels = \
        data_obj.process_global_data(dname, data_path, local_rep_path, group_path, local_fea_path, args.seq_len, args.candidate_entity_num)

    mention_seq_np, entity_seq_np, local_fea_np, entid_seq_np, pem_seq_np, mtype_seq_np, etype_seq_np, label_seq_np = \
        data_obj.get_local_feature_input(mentions, entities, local_feas, ent_ids, mtypes, etypes, pems, labels, args.seq_len, args.candidate_entity_num)

    seq_tokens_np, seq_tokens_mask_np, seq_tokens_segment_np, seq_ents_np, seq_ents_mask_np, seq_ents_index_np, seq_label_np = \
        data_obj.get_global_feature_input(ment_names, ment_sents, ment_offsets, ent_ids, labels, args.seq_len, args.candidate_entity_num)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    no_grad = [
        'bert.encoder.layer.11.output.dense_ent',
        'bert.encoder.layer.11.output.LayerNorm_ent'
    ]
    param_optimizer = [(n, p) for n, p in param_optimizer
                       if not any(nd in n for nd in no_grad)]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    num_train_steps = int(
        len(seq_tokens_np) / args.train_batch_size /
        args.gradient_accumulation_steps * args.num_train_epochs)
    t_total = num_train_steps
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=t_total)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(seq_tokens_np))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_steps)

    all_seq_input_id = torch.tensor(seq_tokens_np,
                                    dtype=torch.long)  # (num_example, 256)
    all_seq_input_mask = torch.tensor(seq_tokens_mask_np,
                                      dtype=torch.long)  # (num_example, 256)
    all_seq_segment_id = torch.tensor(seq_tokens_segment_np,
                                      dtype=torch.long)  # (num_example, 256)
    all_seq_input_ent = torch.tensor(seq_ents_np,
                                     dtype=torch.long)  # (num_example, 256)
    all_seq_ent_mask = torch.tensor(seq_ents_mask_np,
                                    dtype=torch.long)  # (num_example, 256)

    all_seq_label = torch.tensor(
        seq_label_np, dtype=torch.long)  # (num_example, 3), used with the hinge loss
    # all_seq_label = torch.tensor(label_seq_np, dtype=torch.long)  # (num_example, 3, 6), used with the BCE loss

    all_seq_mention_rep = torch.tensor(
        mention_seq_np, dtype=torch.float)  # (num_example, 3, 768)
    all_seq_entity_rep = torch.tensor(
        entity_seq_np, dtype=torch.float)  # (num_example, 3, 6, 768)
    all_seq_entid = torch.tensor(
        entid_seq_np, dtype=torch.long)  # (num_example, 3, 6), eids of the candidate entities
    all_seq_ent_index = torch.tensor(
        seq_ents_index_np, dtype=torch.long)  # (num_example, 3), e.g. [[1, 81, 141], ...]

    all_seq_pem = torch.tensor(pem_seq_np,
                               dtype=torch.float)  # (num_example, 3, 6)
    all_seq_mtype = torch.tensor(mtype_seq_np,
                                 dtype=torch.float)  #(num_example, 3, 6, 4)
    all_seq_etype = torch.tensor(etype_seq_np,
                                 dtype=torch.float)  # (num_example, 3, 6, 4)
    all_seq_local_fea = torch.tensor(local_fea_np, dtype=torch.float)

    train_data = TensorDataset(all_seq_input_id, all_seq_input_mask, all_seq_segment_id, all_seq_input_ent, \
        all_seq_ent_mask, all_seq_ent_index, all_seq_label, \
        all_seq_mention_rep, all_seq_entity_rep, all_seq_entid, all_seq_pem, all_seq_mtype, all_seq_etype, all_seq_local_fea)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    output_loss_file = os.path.join(args.output_dir, "loss")
    loss_fout = open(output_loss_file, 'w')

    output_f1_file = os.path.join(args.output_dir, "result_f1")
    f1_fout = open(output_f1_file, 'w')
    model.train()

    global_step = 0
    best_f1 = -1
    not_better_count = 0
    for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
        tr_loss, nb_tr_examples, nb_tr_steps = 0, 0, 0
        for batch in tqdm(train_dataloader, desc="Iteration"):
            # keep seq_input_ent (index 3) on CPU for the embedding lookup below
            batch = tuple(
                t.to(device) if i != 3 else t for i, t in enumerate(batch))
            seq_input_id, seq_input_mask, seq_segment_id, seq_input_ent, \
                seq_ent_mask, seq_ent_index, seq_label, \
                    seq_mention_rep, seq_entity_rep, seq_entid, \
                        seq_pem, seq_mtype, seq_etype, seq_local_fea = batch
            seq_input_ent_embed = embed(seq_input_ent + 1).to(device)

            # Extra inner loop over the mention sequence; each batch corresponds
            # to one sampled window of mentions.
            current_input_id_batch = seq_input_id  # shape(batch, ctx_len)
            current_input_mask_batch = seq_input_mask  # shape(b, c)
            current_segment_id_batch = seq_segment_id  # shape(b, c)
            current_input_ent_embed_batch = seq_input_ent_embed  # shape(b, c, dim)
            current_input_ent_batch = seq_input_ent  # shape(b, c)
            current_ent_mask_batch = seq_ent_mask  # shape(b, c)

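            # Resolve the mentions in this window one at a time: after each mention
            # is linked, the predicted entity is written back into the input entity
            # sequence (see the scatter calls below) so that later mentions are
            # conditioned on earlier linking decisions.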
            for mention_index in range(args.seq_len):
                current_label_batch = seq_label[:, mention_index]  # shape(b,)
                # current_label_batch = seq_label[:, mention_index, :]               # shape(b, 6)
                current_mention_rep_batch = seq_mention_rep[:, mention_index, :]    # shape(b, 768)
                current_entity_rep_batch = seq_entity_rep[:, mention_index, :, :]   # shape(b, 6, 768)

                current_pem_batch = seq_pem[:, mention_index, :]                    # shape(b, 6)
                current_mtype_batch = seq_mtype[:, mention_index, :, :]             # shape(b, 6, 4)
                current_etype_batch = seq_etype[:, mention_index, :, :]             # shape(b, 6, 4)
                current_local_fea_batch = seq_local_fea[:, mention_index, :]

                current_entid_batch = seq_entid[:, mention_index, :]                # shape(b, 6)
                current_ent_index_batch = seq_ent_index[:, mention_index]           # shape(b,)
                current_entid_embed_batch = embed(current_entid_batch.cpu() + 1).to(device)  # shape(b, 6, dim)

                # Training step for the current mention
                loss, scores = \
                    model(current_input_id_batch, current_segment_id_batch, current_input_mask_batch,\
                         current_input_ent_embed_batch, current_ent_mask_batch, current_entid_embed_batch,\
                         current_label_batch, current_mention_rep_batch, current_entity_rep_batch, \
                         current_pem_batch, current_mtype_batch, current_etype_batch, current_local_fea_batch)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                # Select the predicted entity from the model scores and update current_input_ent and current_ent_mask accordingly
                current_batch_size = current_input_id_batch.size(0)
                pred_ids = torch.argmax(scores, dim=1)                   # shape(b); scores: shape(b, 6)
                pred_ids = pred_ids.reshape(current_batch_size, 1)       # shape(b, 1)

                pred_entid = torch.gather(current_entid_batch, 1, pred_ids)  # shape(b, 1)
                pred_entmask = torch.ones_like(pred_entid)                   # shape(b, 1)

                alter_input_ent_batch = current_input_ent_batch.scatter(
                    1, current_ent_index_batch.reshape(current_batch_size, 1).cpu(), pred_entid.cpu())
                current_input_ent_embed_batch = embed(alter_input_ent_batch + 1).to(device)
                current_ent_mask_batch.scatter_(
                    1, current_ent_index_batch.reshape(current_batch_size, 1), pred_entmask)

                loss.backward()
                loss_fout.write("{}\n".format(
                    loss.item() * args.gradient_accumulation_steps))

                tr_loss += loss.item()
                nb_tr_examples += current_input_id_batch.size(0)
                nb_tr_steps += 1

                if (nb_tr_steps + 1) % args.gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / t_total, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                if global_step % 100 == 0:
                    print('global_step: ', global_step, 'global_step loss: ',
                          tr_loss / nb_tr_steps)
                    dev_f1 = 0
                    dname_list = [
                        'aida-A', 'aida-B', 'msnbc', 'aquaint', 'ace2004',
                        'clueweb', 'wikipedia'
                    ]

                    for di, dname in enumerate(dname_list):
                        # test model
                        f1 = predict(data_obj, dname, args, embed, model)
                        print(dname, '\033[92m' + 'micro F1: ' + str(f1) + '\033[0m')  # colored console output
                        f1_fout.write("{}, f1: {}, step: {}\n".format(
                            dname, f1, global_step))

                        if dname == 'aida-A':
                            dev_f1 = f1
                    if best_f1 < dev_f1:
                        not_better_count = 0
                        best_f1 = dev_f1
                        print('save best model ...')
                        output_model_file = os.path.join(
                            args.output_dir,
                            "pytorch_model_nolocal_{}.bin".format(global_step))
                        torch.save(model.state_dict(), output_model_file)
                    else:
                        not_better_count += 1
                if not_better_count > 3:  # early stopping: 4 consecutive evaluations without improvement on aida-A
                    exit(0)
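
Example #2 receives the entity embedding table through its embed argument but never constructs it. Below is a minimal sketch of building that table, following the same pattern as Example #1 (the kg_embed/entity2vec.vec file, 100-dimensional vectors, and a reserved zero row so entities are looked up as id + 1); the helper name is illustrative:

import torch

def load_entity_embedding(path="kg_embed/entity2vec.vec", dim=100):
    # Row 0 is a zero placeholder so that entity ids can be looked up as id + 1,
    # matching the indexing convention used in both examples.
    vecs = [[0.0] * dim]
    with open(path, "r") as fin:
        for line in fin:
            vecs.append([float(x) for x in line.strip().split("\t")])
    return torch.nn.Embedding.from_pretrained(torch.FloatTensor(vecs))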