# Notebook-style scratch cell: run one forward pass, inspect the prediction,
# compute a cross-entropy loss, then loop over the same batch repeatedly.
# `model`, `inputs`, `labels` and `tokenizer` are defined outside this snippet.
# tokenizer.decode(labels.squeeze(0))

# Single forward pass; the bare `pred.shape` only displays a value in a REPL.
pred = model(inputs)
pred.shape

# Decode the arg-max token id at every position to eyeball the prediction.
tokenizer.decode(torch.argmax(pred, dim=-1).squeeze(0))

loss_fn = nn.CrossEntropyLoss()  # default settings (no explicit ignore_index)

# Flatten predictions to rows of `vocab_size` scores and labels to one id per
# row, which is the layout CrossEntropyLoss is called with here.
masked_lm_loss = loss_fn(pred.view(-1, tokenizer.vocab_size), labels.view(-1))
masked_lm_loss

# Use the first GPU when available, otherwise the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
total_loss = 0.0
model.train()

# Move the model and the single batch to the selected device.
model.to(device)
inputs = inputs.to(device)
labels = labels.to(device)

loss = []  # per-iteration loss history
optimizer = AdamW(params=model.parameters())

# Repeatedly evaluate the same (inputs, labels) batch and record the loss.
# NOTE(review): no backward()/optimizer.step() call is visible in this loop,
# so `optimizer` is unused as shown - the snippet may be truncated; confirm.
for _ in tqdm(range(100000)):
    pred = model(inputs)
    mlm_loss = loss_fn(pred.view(-1, tokenizer.vocab_size), labels.view(-1))

    total_loss += mlm_loss.item()
    loss.append(mlm_loss.item())
# Example #2
0
def _sts_dump_pickle(obj, path):
    """Pickle ``obj`` to ``path`` and close the file handle right away."""
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle)


def _sts_latest_checkpoint_id(checkpoint_dir):
    """Return the highest integer-named entry in ``checkpoint_dir``, or 0."""
    try:
        names = os.listdir(checkpoint_dir)
    except OSError:
        # No checkpoint directory yet: training starts from scratch.
        return 0
    # Entries that are not plain integers are ignored instead of aborting the scan.
    return max((int(name) for name in names if name.isdecimal()), default=0)


def _sts_save_checkpoint(save_dir, checkpoint_id, step, model,
                         linear_regressor, optimizer):
    """Save encoder, classifier head and optimizer state as ``sts_model.pt``."""
    path = os.path.join(save_dir, 'model', str(checkpoint_id), 'sts_model.pt')
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(
        {
            'step': step,
            'model_state_dict': model.state_dict(),
            'linear_state_dict': linear_regressor.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, path)


def _sts_classify(encoder, head, src):
    """Return class logits computed from the first-position embedding of ``src``."""
    # True where the token is real, False on padding.
    input_mask = src != PAD_IDX
    enc_keys = encoder(src, return_loss=False, input_mask=input_mask)
    return head(enc_keys[:, 0, :])


def _sts_validate(encoder, head, testloader, src_device, trg_device,
                  compute_metrics=True):
    """Evaluate on ``testloader``; return (losses, accuracies, MCCs) per batch."""
    val_loss, acc_list, mcc_list = [], [], []
    encoder.eval()
    head.eval()
    with torch.no_grad():
        for pair in tqdm(testloader):
            ts_src, ts_trg = pair[0], pair[1]

            # Debug dumps of the last validation batch (kept from the original).
            _sts_dump_pickle(ts_src, 'ts_src.pkl')
            _sts_dump_pickle(ts_trg, 'ts_trg.pkl')

            ts_src = ts_src.to(src_device)
            ts_trg = ts_trg.to(trg_device)

            ts_logits = _sts_classify(encoder, head, ts_src)
            # cross_entropy applies log-softmax itself, so it gets raw logits.
            val_loss.append(F.cross_entropy(ts_logits, ts_trg).item())

            if compute_metrics:
                ts_pred = torch.softmax(ts_logits, dim=1)
                acc, mcc = compute_simple_metrics(ts_pred, ts_trg)
                acc_list.append(acc)
                mcc_list.append(mcc)
    return val_loss, acc_list, mcc_list


def _sts_log_step(log_file, step, train_loss, val_loss=None):
    """Append one training (and optionally validation) record to the log."""
    if val_loss is None:
        log_file.write('Step: {}\tTraining Loss:{}\n'.format(step, train_loss))
    else:
        log_file.write(
            'Step: {}\tTraining Loss:{}\t Validation LOSS: AVG: {}| MEDIAN: {}| STD: {}\n'
            .format(step, train_loss, np.mean(val_loss), np.median(val_loss),
                    np.std(val_loss)))
    log_file.flush()


def _sts_train_deepspeed(cmd_args, model, linear_regressor, train_dataset,
                         test_dataset, use_config_optimizer, epochs,
                         validate_every, save_every, save_dir, model_ckp_max,
                         log_file):
    """Run the DeepSpeed training loop for the encoder and classifier head."""
    model_init = dict(
        args=cmd_args,
        model=model,
        model_parameters=filter(lambda p: p.requires_grad,
                                model.parameters()),
        training_data=train_dataset)
    linear_init = dict(
        args=cmd_args,
        model=linear_regressor,
        model_parameters=filter(lambda p: p.requires_grad,
                                linear_regressor.parameters()))

    if use_config_optimizer:
        print('Found optimizer in the DeepSpeed configurations. Using it.')
    else:
        print('No DeepSpeed optimizer found. Using RangerLars.')
        model_init['optimizer'] = RangerLars(model.parameters())
        linear_init['optimizer'] = RangerLars(linear_regressor.parameters())

    model_engine, _, trainloader, _ = deepspeed.initialize(**model_init)
    linear_engine, _, _, _ = deepspeed.initialize(**linear_init)

    checkpoint_dir = os.path.join(save_dir, 'model')
    model_engine.load_checkpoint(checkpoint_dir, model_ckp_max)

    testloader = model_engine.deepspeed_io(test_dataset)

    tr_step = 0  # stays 0 when no batch is processed, so the final save works
    for eph in range(epochs):
        print('Starting Epoch: {}'.format(eph))
        for i, pair in enumerate(tqdm(trainloader)):
            tr_step = eph * len(trainloader) + i + 1

            src, trg = pair[0], pair[1]

            # Debug dumps of the last training batch (kept from the original).
            _sts_dump_pickle(src, 'src.pkl')
            _sts_dump_pickle(trg, 'trg.pkl')

            model_engine.train()
            linear_engine.train()

            src = src.to(model_engine.local_rank)
            trg = trg.to(linear_engine.local_rank)

            print("Sample:", src)
            print("Target:", trg)
            print("Target Shape:", trg.shape)
            print("len Samples:", len(src))

            enc_input_mask = src != PAD_IDX
            enc_keys = model_engine(src,
                                    return_loss=False,
                                    input_mask=enc_input_mask)
            print('enc_keys shape', enc_keys.shape)

            # Classify from the first position, as the non-DeepSpeed path does.
            logits = linear_engine(enc_keys[:, 0, :])
            print('preds shape', logits.shape)

            # cross_entropy expects raw logits; it applies log-softmax itself.
            loss = F.cross_entropy(logits, trg)
            # NOTE(review): DeepSpeed normally expects engine.backward(loss);
            # with two engines sharing one loss the plain backward is kept -
            # confirm the intended setup.
            loss.backward()

            model_engine.step()
            linear_engine.step()

            train_loss = loss.item()
            print('Training Loss:', train_loss)
            if tr_step % validate_every == 0:
                # NOTE(review): the original author doubted that the parallel
                # test loader can be used for validation here - confirm.
                val_loss, _, _ = _sts_validate(model_engine,
                                               linear_engine,
                                               testloader,
                                               model_engine.local_rank,
                                               linear_engine.local_rank,
                                               compute_metrics=False)
                print(
                    f'\tValidation Loss: AVG: {np.mean(val_loss)}, MEDIAN: {np.median(val_loss)}, STD: {np.std(val_loss)} '
                )
                _sts_log_step(log_file, i, train_loss, val_loss)
            else:
                _sts_log_step(log_file, i, train_loss)

            if tr_step % save_every == 0:
                print('\tSaving Checkpoint')
                model_engine.save_checkpoint(
                    checkpoint_dir, str(model_ckp_max + tr_step + 1))

    print('\tSaving Final Checkpoint')
    model_engine.save_checkpoint(checkpoint_dir,
                                 str(model_ckp_max + tr_step + 1))


def _sts_train_torch(model, linear_regressor, train_dataset, test_dataset,
                     train_batch_size, run_device, epochs, validate_every,
                     save_every, save_dir, model_ckp_max, log_file):
    """Run the plain-PyTorch training loop with one Adam optimizer for both modules."""
    model_optimizer = torch.optim.Adam(
        list(model.parameters()) + list(linear_regressor.parameters()))

    resume_path = os.path.join(save_dir, 'model', str(model_ckp_max),
                               'sts_model.pt')
    if os.path.exists(resume_path):
        print('********** Found Checkpoint. Loading:', resume_path)
        # map_location lets a GPU-saved checkpoint resume on another device.
        checkpoint = torch.load(resume_path, map_location=run_device)

        model.load_state_dict(checkpoint['model_state_dict'])
        linear_regressor.load_state_dict(checkpoint['linear_state_dict'])
        model_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    trainloader = DataLoader(train_dataset,
                             batch_size=train_batch_size,
                             shuffle=False)
    testloader = DataLoader(test_dataset,
                            batch_size=train_batch_size,
                            shuffle=False)

    train_loss_list = []
    tr_step = 0  # stays 0 when no batch is processed, so the final save works
    for eph in range(epochs):
        print('Starting Epoch: {}'.format(eph))
        for i, pair in enumerate(tqdm(trainloader)):
            tr_step = eph * len(trainloader) + i + 1

            src, trg = pair[0], pair[1]

            # Debug dumps of the last training batch (kept from the original).
            _sts_dump_pickle(src, 'src.pkl')
            _sts_dump_pickle(trg, 'trg.pkl')

            model.train()
            linear_regressor.train()

            src = src.to(run_device)
            trg = trg.to(run_device)

            logits = _sts_classify(model, linear_regressor, src)
            # cross_entropy expects raw logits; it applies log-softmax itself.
            loss = F.cross_entropy(logits, trg)

            # Clear stale gradients so each step uses only this batch.
            model_optimizer.zero_grad()
            loss.backward()
            model_optimizer.step()

            # Captured now so validation cannot overwrite the logged value.
            train_loss = loss.item()
            train_loss_list.append(train_loss)

            if tr_step % validate_every == 0:
                val_loss, acc_list, mcc_list = _sts_validate(
                    model, linear_regressor, testloader, run_device,
                    run_device)

                print(
                    f'\tTrain Loss: LAST: {train_loss_list[-1]}, AVG: {np.mean(train_loss_list)}, MEDIAN: {np.median(train_loss_list)}, STD: {np.std(train_loss_list)} '
                )
                print(
                    f'\tValidation Loss: AVG: {np.mean(val_loss)}, MEDIAN: {np.median(val_loss)}, STD: {np.std(val_loss)} '
                )
                print(
                    f'\tValidation ACC: AVG: {np.mean(acc_list)}, MEDIAN: {np.median(acc_list)}, STD: {np.std(acc_list)} '
                )
                print(
                    f'\tValidation MCC: AVG: {np.mean(mcc_list)}, MEDIAN: {np.median(mcc_list)}, STD: {np.std(mcc_list)} '
                )
                _sts_log_step(log_file, i, train_loss, val_loss)
            else:
                _sts_log_step(log_file, i, train_loss)

            if tr_step % save_every == 0:
                print('\tSaving Checkpoint')
                _sts_save_checkpoint(save_dir, model_ckp_max + tr_step + 1,
                                     tr_step, model, linear_regressor,
                                     model_optimizer)

    print('\tSaving Final Checkpoint')
    _sts_save_checkpoint(save_dir, model_ckp_max + tr_step + 1, tr_step,
                         model, linear_regressor, model_optimizer)


def main():
    """Train a Reformer encoder plus a 2-class linear head on molecule pairs.

    Reads the command-line arguments via ``add_argument()`` and the DeepSpeed
    JSON config named by ``ds_conf``, loads the train/test samples, builds the
    model, then trains either through DeepSpeed or through a plain PyTorch
    loop. Losses are appended to ``training_log.log`` and checkpoints are
    written under ``<output_folder>/saved_model``.

    The module-level ``device`` and ``PAD_IDX`` names are read from the
    enclosing module.
    """
    cmd_args = add_argument()

    # Parse the DeepSpeed config once and close the file.
    with open(cmd_args.ds_conf) as conf_file:
        ds_conf = json.load(conf_file)
    train_batch_size = ds_conf['train_batch_size']
    gradient_accumulation_steps = ds_conf['gradient_accumulation_steps']
    use_config_optimizer = ds_conf.get('optimizer') is not None

    dim = cmd_args.dim
    epochs = cmd_args.epochs
    validate_every = cmd_args.validate_every
    save_every = cmd_args.save_every
    output_folder = cmd_args.output_folder

    os.makedirs(output_folder, exist_ok=True)
    _sts_dump_pickle(cmd_args,
                     os.path.join(output_folder, 'training_conf.pkl'))

    mol_lang_path = os.path.join(output_folder, 'mol_lang.pkl')

    read_samples = readMRPC if cmd_args.mrpc_test else readMolecules
    mol_lang, tr_samples, ts_samples = read_samples(
        molecule_file_tr=cmd_args.path_to_file_tr,
        molecule_file_ts=cmd_args.path_to_file_ts,
        saved_molecule_lang=mol_lang_path,
        num_examples_tr=cmd_args.num_examples_tr,
        num_examples_ts=cmd_args.num_examples_ts,
        min_len_molecule=cmd_args.min_len_mol,
        max_len_molecule=cmd_args.max_len_mol,
        shuffle=True)

    _sts_dump_pickle(mol_lang, mol_lang_path)

    dataset_batch_size = train_batch_size if device == 'cuda' else 1
    train_dataset = MolecularSimilarityDataset(tr_samples, mol_lang,
                                               dataset_batch_size)
    test_dataset = MolecularSimilarityDataset(ts_samples, mol_lang,
                                              dataset_batch_size)

    # The model sequence length is twice the maximum molecule length.
    max_seq_len = cmd_args.max_len_mol * 2
    axial_shape = compute_axial_position_shape(max_seq_len)
    print('Axial Embedding shape:', axial_shape)
    model = ReformerLM(
        num_tokens=mol_lang.n_words,
        dim=dim,
        bucket_size=cmd_args.bucket_size,
        depth=cmd_args.depth,
        heads=cmd_args.heads,
        n_hashes=cmd_args.n_hashes,
        max_seq_len=max_seq_len,
        ff_chunks=cmd_args.ff_chunks,
        attn_chunks=cmd_args.attn_chunks,
        weight_tie=True,
        weight_tie_embedding=True,
        axial_position_emb=True,
        axial_position_shape=axial_shape,
        axial_position_dims=(dim // 2, dim // 2),
        return_embeddings=True,
        use_full_attn=cmd_args.use_full_attn).to(device)

    # NOTE(review): the input width is hard-coded to 512; presumably it must
    # equal `dim` - confirm before training with another dimension.
    linear_regressor = Linear(512, 2).to(device)

    model = TrainingWrapper(model, ignore_index=PAD_IDX,
                            pad_value=PAD_IDX).to(device)

    save_dir = os.path.join(output_folder, 'saved_model')
    os.makedirs(save_dir, exist_ok=True)

    model_ckp_max = _sts_latest_checkpoint_id(os.path.join(save_dir, 'model'))

    # device_count() is 0 on CPU-only hosts; avoid dividing by zero.
    gpus_mini_batch = (train_batch_size // gradient_accumulation_steps
                       ) // max(torch.cuda.device_count(), 1)
    print('gpus_mini_batch:', gpus_mini_batch,
          'with gradient_accumulation_steps:', gradient_accumulation_steps)

    log_path = os.path.join(output_folder, 'training_log.log')
    with open(log_path, 'a') as log_file:
        log_file.write(
            "\n\n\n{}\tStarting new training from chekpoint: EncoderDecoder-{}\n".
            format(datetime.datetime.now(), model_ckp_max))
        log_file.flush()

        if cmd_args.use_deepspeed:
            _sts_train_deepspeed(cmd_args, model, linear_regressor,
                                 train_dataset, test_dataset,
                                 use_config_optimizer, epochs, validate_every,
                                 save_every, save_dir, model_ckp_max, log_file)
        else:
            _sts_train_torch(model, linear_regressor, train_dataset,
                             test_dataset, train_batch_size, device, epochs,
                             validate_every, save_every, save_dir,
                             model_ckp_max, log_file)
def _lm_build_arg_parser():
    """Build the command-line parser for Reformer language-model training."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='cuda', type=str, required=False, help='设置使用哪些显卡')
    parser.add_argument('--tokenizer_path', default='cache/vocab_small_terry_ai.txt', type=str, required=False, help='选择词库')
    parser.add_argument('--raw_data_path', default='data/train.json', type=str, required=False, help='原始训练语料')
    parser.add_argument('--tokenized_data_path', default='data/tokenized/', type=str, required=False,
                        help='tokenized语料存放位置')
    parser.add_argument('--raw', action='store_true', help='是否先做tokenize')
    parser.add_argument('--epochs', default=5, type=int, required=False, help='训练循环')
    parser.add_argument('--batch_size', default=2, type=int, required=False, help='训练batch size')
    parser.add_argument('--lr', default=1e-8, type=float, required=False, help='学习率')
    parser.add_argument('--warmup_steps', default=2000, type=int, required=False, help='warm up步数')
    parser.add_argument('--log_step', default=1, type=int, required=False, help='多少步汇报一次loss')
    parser.add_argument('--stride', default=500, type=int, required=False, help=' 向前跨越的长度')
    parser.add_argument('--dim', default=1024, type=int, required=False, help='训练时取训练数据的窗口步长单个样本长度')
    parser.add_argument('--gradient_accumulation', default=5, type=int, required=False, help='梯度积累')
    parser.add_argument('--fp16', action='store_true', help='混合精度')
    parser.add_argument('--fp16_opt_level', default='O1', type=str, required=False)
    parser.add_argument('--max_grad_norm', default=1.0, type=float, required=False)
    parser.add_argument('--num_pieces', default=10, type=int, required=False, help='将训练语料分成多少份')
    parser.add_argument('--min_length', default=64, type=int, required=False, help='最短收录文章长度')
    parser.add_argument('--output_dir', default='model/', type=str, required=False, help='模型输出路径')
    parser.add_argument('--pretrained_model', default='', type=str, required=False, help='模型训练起点路径')
    parser.add_argument('--segment', action='store_true', help='中文以词为单位')
    parser.add_argument('--bpe_token', action='store_true', help='subword')

    parser.add_argument('--depth', default=12, type=int, required=False, help='depth')
    parser.add_argument('--full_attn_thres', default=1024, type=int, required=False, help='full_attn_thres')
    parser.add_argument('--max_seq_len', default=4096, type=int, required=False, help='max_seq_len')
    return parser


def _lm_sliding_windows(tokens, window, stride):
    """Cut ``tokens`` into ``window``-long samples taken every ``stride`` tokens."""
    samples = []
    start_point = 0
    while start_point < len(tokens) - window:
        samples.append(tokens[start_point:start_point + window])
        start_point += stride
    # Add one last window aligned to the end so the tail is not dropped.
    if start_point < len(tokens):
        samples.append(tokens[len(tokens) - window:])
    return samples


def main():
    """Train a causal ReformerLM on pre-tokenized text pieces.

    Parses the command line, writes the model config and vocabulary into the
    output directory, optionally builds the tokenized pieces, then trains with
    AdamW, a linear warm-up schedule and gradient accumulation. Model,
    optimizer and scheduler state are saved after every piece; a CPU copy of
    the model weights is saved at the end.
    """
    args = _lm_build_arg_parser().parse_args()

    output_dir = args.output_dir
    # Created first: the config and vocabulary below are written into it.
    os.makedirs(output_dir, exist_ok=True)

    full_tokenizer = tokenizer_plus(args.tokenizer_path)
    Config = tkitJson.Config(os.path.join(output_dir, 'config.json'))
    new_conf = {
        'num_tokens': full_tokenizer.vocab_size,
        'dim': args.dim,  # same as the window length
        'depth': args.depth,
        'max_seq_len': args.max_seq_len,
        'lsh_dropout': 0.1,
        'causal': True,
        'full_attn_thres': args.full_attn_thres,
        'stride': args.stride,  # sliding-window step
    }
    print("new_conf:", new_conf)
    Config.save(new_conf)
    # Keep a copy of the vocabulary next to the model.
    shutil.copy(args.tokenizer_path, os.path.join(output_dir, 'vocab.txt'))

    print('args:\n' + args.__repr__())

    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'  # GPUs visible to this program

    device = args.device
    if device.startswith('cuda') and not torch.cuda.is_available():
        # Fall back instead of crashing on hosts without a usable GPU.
        print('CUDA requested but not available; falling back to cpu')
        device = 'cpu'
    print('using device:', device)

    tokenized_data_path = args.tokenized_data_path
    pretrained_model = args.pretrained_model
    epochs = args.epochs
    batch_size = args.batch_size
    num_pieces = args.num_pieces
    gradient_accumulation = args.gradient_accumulation
    dim = args.dim
    stride = args.stride
    if stride >= dim:
        # Integer division keeps the stride usable as a slice index; the floor
        # of 1 guarantees the windowing loop always advances.
        stride = max(dim // 2 - 2, 1)

    # Locations of a previous run to resume from.
    model_path = os.path.join(pretrained_model, 'model.pt')
    optimizer_path = os.path.join(pretrained_model, 'optimizer.pt')
    scheduler_path = os.path.join(pretrained_model, 'scheduler.pt')
    # Locations this run writes to.
    output_model_path = os.path.join(output_dir, 'model.pt')
    output_optimizer_path = os.path.join(output_dir, 'optimizer.pt')
    output_scheduler_path = os.path.join(output_dir, 'scheduler.pt')

    def piece_path(index):
        """Return the path of tokenized training piece ``index``."""
        return os.path.join(tokenized_data_path,
                            'tokenized_train_{}.txt'.format(index))

    if args.raw:
        print('building files')
        build_files(data_path=args.raw_data_path, tokenized_data_path=tokenized_data_path, num_pieces=num_pieces,
                    full_tokenizer=full_tokenizer, min_length=args.min_length)
        print('files built')

    model = ReformerLM(
        num_tokens=full_tokenizer.vocab_size,
        dim=dim,
        depth=args.depth,
        max_seq_len=args.max_seq_len,
        lsh_dropout=0.1,
        causal=True,
        full_attn_thres=args.full_attn_thres
    )

    # 0 is used for padding and no loss is calculated on it.
    model = TrainingWrapper(model, ignore_index=0, pad_value=0).to(device)

    if os.path.isfile(model_path):
        # map_location lets GPU-saved weights load on the current device.
        model.load_state_dict(torch.load(model_path, map_location=device))
    model.train()

    weight_decay = 0.0
    adam_epsilon = 1e-8
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            'weight_decay': weight_decay
        },
        {
            'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }
    ]

    full_len = 0
    print('calculating total steps')
    for i in tqdm(range(num_pieces)):
        with open(piece_path(i), 'r') as f:
            full_len += len(f.read().split())
    total_steps = int(full_len / stride * epochs / batch_size / gradient_accumulation)
    print('total steps = {}'.format(total_steps))

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=total_steps)

    # Resume optimizer/scheduler state when both files exist.
    if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
        optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))
        scheduler.load_state_dict(torch.load(scheduler_path, map_location=device))

    print("optimizer", optimizer)

    print('starting training')
    overall_step = 0  # counts micro-batches; also drives gradient accumulation
    for epoch in range(epochs):
        print('epoch {}'.format(epoch + 1))
        now = datetime.now()
        print('time: {}'.format(now))
        piece_order = list(range(num_pieces))
        random.shuffle(piece_order)

        for piece_num, i in enumerate(piece_order):
            with open(piece_path(i), 'r') as f:
                tokens = [int(token) for token in f.read().split()]

            samples = _lm_sliding_windows(tokens, dim, stride)
            # Shuffle so consecutive batches are not adjacent windows.
            random.shuffle(samples)

            for step in range(len(samples) // batch_size):  # drop last
                batch = samples[step * batch_size:(step + 1) * batch_size]
                batch_inputs = torch.tensor(batch).long().to(device)

                loss = model(batch_inputs, return_loss=True)
                # Scale so the accumulated gradient is an average.
                loss = loss / gradient_accumulation
                loss.backward()

                if (overall_step + 1) % gradient_accumulation == 0:
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    end = datetime.now()
                    # max(..., 1) avoids dividing by zero on tiny datasets.
                    print("epoch:",epoch + 1," piece_num:",piece_num,'/',num_pieces," step:",overall_step+1,'/',total_steps," step完成比例:",(overall_step+1)/max(total_steps, 1)," loss:",loss.item(),'Time',end-now)
                overall_step += 1

            # Checkpoint after every piece.
            torch.save(model.state_dict(), output_model_path)
            torch.save(optimizer.state_dict(), output_optimizer_path)
            torch.save(scheduler.state_dict(), output_scheduler_path)

    model_cpu_path = os.path.join(output_dir, 'model_cpu.pt')
    torch.save(model.cpu().state_dict(), model_cpu_path)
def _load_state_if_present(target, path, device):
    """Best-effort restore of ``target`` from the checkpoint stored at ``path``.

    Args:
        target: Any object exposing ``load_state_dict`` (model or optimizer).
        path: Checkpoint file to read.
        device: Passed to ``torch.load`` as ``map_location`` so that a
            checkpoint written on one device can be restored on another.

    Returns:
        True when the state was restored, False when the file is missing or
        could not be applied. Failures are reported, never raised, so a fresh
        run can still start from scratch.
    """
    if not os.path.isfile(path):
        return False
    try:
        target.load_state_dict(torch.load(path, map_location=device))
    except Exception as err:  # resuming is optional; report and start fresh
        print('could not restore', path, ':', err)
        return False
    return True


def train(device='cpu',
          output_dir='model',
          epochs=5,
          save_step=5,
          batch_size=4):
    """Train a causal ReformerLM on ``data/train.json`` and checkpoint it.

    The model and optimizer are restored from ``output_dir`` when matching
    checkpoint files exist. Each epoch shuffles the data, feeds it to the
    model in batches of ``batch_size`` samples (the last batch of an epoch
    may be smaller) and minimises the cross-entropy between the model output
    and the ``text`` ids of each sample.

    Args:
        device: Torch device string the model and batches are moved to.
        output_dir: Directory holding ``model_cpu.pt`` and ``optimizer.pt``;
            created if it does not exist.
        epochs: Number of passes over the data.
        save_step: A checkpoint is written whenever the sample index within
            an epoch is a non-zero multiple of this value, and again at the
            end of every epoch.
        batch_size: Number of samples per optimisation step; must be >= 1.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1')

    model = ReformerLM(num_tokens=13137,
                       dim=128,
                       depth=12,
                       max_seq_len=4096,
                       lsh_dropout=0.1,
                       causal=True,
                       full_attn_thres=128)
    model = TrainingWrapper(model, ignore_index=0, pad_value=0).to(device)

    # torch.save does not create missing directories.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    model_cpu_path = os.path.join(output_dir, 'model_cpu.pt')
    optimizer_path = os.path.join(output_dir, 'optimizer.pt')

    _load_state_if_present(model, model_cpu_path, device)

    model.train()
    optimizer = AdamW(params=model.parameters())
    _load_state_if_present(optimizer, optimizer_path, device)
    print(optimizer)

    total_loss = 0.0
    loss = []

    data = list(get_data("data/train.json", tokenizer))

    # CrossEntropyLoss skips targets equal to its default ignore_index (-100).
    # NOTE(review): whether the dataset actually pads labels with -100 is not
    # visible here - confirm against get_data.
    loss_fn = nn.CrossEntropyLoss()

    def save_checkpoint():
        """Write the current model and optimizer state to ``output_dir``."""
        torch.save(model.state_dict(), model_cpu_path)
        torch.save(optimizer.state_dict(), optimizer_path)

    def run_batch(batch_inputs, batch_labels):
        """Run one forward/backward/optimizer step and return the loss value."""
        # assumes every sample holds equal-length lists of token ids,
        # otherwise torch.tensor rejects the ragged batch - TODO confirm
        inputs_batch = torch.tensor(batch_inputs).long().to(device)
        labels_batch = torch.tensor(batch_labels).long().to(device)

        pred = model(inputs_batch)
        mlm_loss = loss_fn(pred.view(-1, tokenizer.vocab_size),
                           labels_batch.view(-1))
        value = mlm_loss.item()
        print('loss', value)

        mlm_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        return value

    for _ in tqdm(range(epochs)):
        random.shuffle(data)
        inputs = []
        labels = []
        for i, it in enumerate(data):
            inputs.append(it['keywords'])
            labels.append(it['text'])
            # Trigger on the accumulated size rather than on the sample index
            # so that every batch, including the first, has batch_size items.
            if len(inputs) == batch_size:
                value = run_batch(inputs, labels)
                total_loss += value
                loss.append(value)
                inputs = []
                labels = []
            if i % save_step == 0 and i != 0:
                save_checkpoint()

        # Train on the leftover samples instead of discarding them.
        if inputs:
            value = run_batch(inputs, labels)
            total_loss += value
            loss.append(value)

        save_checkpoint()