Example no. 1
    def __init__(self, dim, ignore_index=-100, pad_value=0, **kwargs):
        super().__init__()
        enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)

        assert (
            "return_embedding" not in enc_kwargs
        ), "you cannot manually set the return embeddings flag for the encoder"
        assert ("dim" not in dec_kwargs and "dim" not in enc_kwargs
                ), "you must set the dim for both encoder and decoder"

        enc_kwargs["dim"] = dec_kwargs["dim"] = dim
        enc_kwargs["return_embeddings"] = True
        dec_kwargs["causal"] = True

        enc_kwargs.setdefault("bucket_size", 64)
        dec_kwargs.setdefault("bucket_size", enc_kwargs["bucket_size"] * 2)

        enc = ReformerLM(**enc_kwargs)
        dec = ReformerLM(**dec_kwargs)

        self.enc = TrainingWrapper(enc,
                                   ignore_index=ignore_index,
                                   pad_value=pad_value)
        self.dec = TrainingWrapper(dec,
                                   ignore_index=ignore_index,
                                   pad_value=pad_value)
Example no. 2
    def __init__(self, dim, ignore_index=-100, pad_value=0, **kwargs):
        super().__init__()
        enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)

        assert 'return_embedding' not in enc_kwargs, 'you cannot manually set the return embeddings flag for the encoder'
        assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, 'you must set the dim for both encoder and decoder'

        enc_kwargs['dim'] = dec_kwargs['dim'] = dim
        enc_kwargs['return_embeddings'] = True

        enc_kwargs.setdefault('bucket_size', 64)
        dec_kwargs.setdefault('bucket_size', enc_kwargs['bucket_size'] * 2)

        enc = ReformerLM(**enc_kwargs)
        dec = ReformerLM(**dec_kwargs)

        self.enc = TrainingWrapper(enc,
                                   ignore_index=ignore_index,
                                   pad_value=pad_value)
        self.dec = TrainingWrapper(dec,
                                   ignore_index=ignore_index,
                                   pad_value=pad_value)
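The two snippets above only build the wrapped encoder/decoder pair; the way such a pair is normally driven (the same pattern appears in Examples 4 and 10 below) is to run the encoder, which returns embeddings because return_embeddings=True, and feed them to the causal decoder through the keys argument. A minimal self-contained sketch, assuming the standard reformer_pytorch imports and toy hyperparameters:

import torch
from reformer_pytorch import ReformerLM
from reformer_pytorch.generative_tools import TrainingWrapper

enc = TrainingWrapper(ReformerLM(num_tokens=256, dim=512, depth=1, max_seq_len=1024,
                                 bucket_size=64, return_embeddings=True))
dec = TrainingWrapper(ReformerLM(num_tokens=256, dim=512, depth=1, max_seq_len=1024,
                                 bucket_size=128, causal=True))

src = torch.randint(0, 256, (1, 1024))
trg = torch.randint(0, 256, (1, 1024))

enc_keys = enc(src)                               # (1, 1024, 512) encoder embeddings
loss = dec(trg, keys=enc_keys, return_loss=True)  # decoder attends over the encoder keys
loss.backward()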
Example no. 3
def gen(text):
    model = ReformerLM(num_tokens=13137,
                       dim=128,
                       depth=12,
                       max_seq_len=4096,
                       lsh_dropout=0.1,
                       causal=True,
                       full_attn_thres=128)
    model = TrainingWrapper(model, ignore_index=0, pad_value=0).cpu()
    output_dir = "model"
    model_cpu_path = os.path.join(output_dir, 'model_cpu.pt')
    model.load_state_dict(torch.load(model_cpu_path))
    initial = auto_encode(text)
    #   print(initial)
    sample = model.generate(
        initial, 10, temperature=1., filter_thres=0.9, eos_token=1
    )  # assume the end token is 1, or omit it and it will sample up to the given length (10 here)
    #   print(sample)
    # print(sample.shape) # (1, <=10) token ids
    text = tokenizer.convert_ids_to_tokens(sample.tolist()[0])
    print(text)
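gen() relies on two module-level helpers that are not shown here: tokenizer and auto_encode. A rough sketch of what they might look like for a BERT-style vocabulary (the helper names and the vocab path are assumptions, not part of the original snippet):

import torch
from transformers import BertTokenizer

tokenizer = BertTokenizer(vocab_file='model/vocab.txt')  # hypothetical vocab path

def auto_encode(text):
    # turn raw text into a (1, seq_len) LongTensor of token ids for model.generate
    ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    return torch.tensor([ids]).long()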
Example no. 4
def train_encdec_v1(
    input_lang, target_lang, dim, bucket_size, depth, heads, n_hashes,
    vir_seq_len, ff_chunks, attn_chunks, mol_seq_len, cmd_args, train_dataset,
    test_dataset, output_folder, train_batch_size, epochs, validate_every,
    save_every, deepspeed_optimizer, use_full_attn, gradient_accumulation_steps
):  #zero_optimization, #unused for now. Use this flag to create IF statement for Zero Compatibility if needed
    print('Axial Embedding shape:', compute_axial_position_shape(vir_seq_len))
    encoder = ReformerLM(
        num_tokens=input_lang.n_words,
        dim=dim,
        bucket_size=bucket_size,
        depth=depth,
        heads=heads,
        n_hashes=n_hashes,
        max_seq_len=vir_seq_len,
        ff_chunks=ff_chunks,
        attn_chunks=attn_chunks,
        weight_tie=True,
        weight_tie_embedding=True,
        axial_position_emb=True,
        axial_position_shape=compute_axial_position_shape(vir_seq_len),
        axial_position_dims=(dim // 2, dim // 2),
        return_embeddings=True,
        use_full_attn=use_full_attn).to(device)

    decoder = ReformerLM(
        num_tokens=target_lang.n_words,
        dim=dim,
        bucket_size=bucket_size,
        depth=depth,
        heads=heads,
        n_hashes=n_hashes,
        ff_chunks=ff_chunks,
        attn_chunks=attn_chunks,
        max_seq_len=mol_seq_len,
        axial_position_emb=True,
        axial_position_shape=compute_axial_position_shape(mol_seq_len),
        axial_position_dims=(dim // 2, dim // 2),
        weight_tie=True,
        weight_tie_embedding=True,
        causal=True,
        use_full_attn=use_full_attn).to(device)

    encoder = TrainingWrapper(encoder, ignore_index=PAD_IDX,
                              pad_value=PAD_IDX).to(device)
    decoder = TrainingWrapper(decoder, ignore_index=PAD_IDX,
                              pad_value=PAD_IDX).to(device)

    encoder_params = filter(lambda p: p.requires_grad, encoder.parameters())
    decoder_params = filter(lambda p: p.requires_grad, decoder.parameters())

    if not deepspeed_optimizer:
        print('No DeepSpeed optimizer found. Using RangerLars.')
        encoder_optimizer = RangerLars(encoder.parameters())
        decoder_optimizer = RangerLars(decoder.parameters())

        encoder_engine, encoder_optimizer, trainloader, _ = deepspeed.initialize(
            args=cmd_args,
            model=encoder,
            optimizer=encoder_optimizer,
            model_parameters=encoder_params,
            training_data=train_dataset,
            dist_init_required=True)

        decoder_engine, decoder_optimizer, testloader, _ = deepspeed.initialize(
            args=cmd_args,
            model=decoder,
            optimizer=decoder_optimizer,
            model_parameters=decoder_params,
            training_data=test_dataset,
            dist_init_required=False)
    else:
        print('Found optimizer in the DeepSpeed configurations. Using it.')
        encoder_engine, encoder_optimizer, trainloader, _ = deepspeed.initialize(
            args=cmd_args,
            model=encoder,
            model_parameters=encoder_params,
            training_data=train_dataset,
            dist_init_required=True)
        decoder_engine, decoder_optimizer, testloader, _ = deepspeed.initialize(
            args=cmd_args,
            model=decoder,
            model_parameters=decoder_params,
            training_data=test_dataset,
            dist_init_required=False)

    SAVE_DIR = os.sep.join([output_folder, 'saved_model'])
    os.makedirs(SAVE_DIR, exist_ok=True)

    try:
        enc_ckp_max = np.max([
            int(ckp) for ckp in os.listdir(os.sep.join([SAVE_DIR, 'encoder']))
        ])
    except Exception as e:
        print('Exception:', e)
        enc_ckp_max = 0

    try:
        dec_ckp_max = np.max([
            int(ckp) for ckp in os.listdir(os.sep.join([SAVE_DIR, 'decoder']))
        ])
    except Exception:
        dec_ckp_max = 0

    _, encoder_client_sd = encoder_engine.load_checkpoint(
        os.sep.join([SAVE_DIR, 'encoder']), enc_ckp_max)
    _, decoder_client_sd = decoder_engine.load_checkpoint(
        os.sep.join([SAVE_DIR, 'decoder']), dec_ckp_max)

    gpus_mini_batch = (train_batch_size // gradient_accumulation_steps
                       ) // torch.cuda.device_count()
    print('gpus_mini_batch:', gpus_mini_batch,
          'with gradient_accumulation_steps:', gradient_accumulation_steps)

    log_file = open(os.sep.join([output_folder, 'training_log.log']), 'a')
    log_file.write(
        "\n\n\n{}\tStarting new training from chekpoint: Encoder-{} | Decoder-{}\n"
        .format(datetime.datetime.now(), enc_ckp_max, dec_ckp_max))
    log_file.flush()

    for eph in range(epochs):
        print('Starting Epoch: {}'.format(eph))
        for i, pair in enumerate(tqdm(trainloader)):
            tr_step = ((eph * len(trainloader)) + i) + 1

            src = pair[0]
            trg = pair[1]
            encoder_engine.train()
            decoder_engine.train()
            src = src.to(encoder_engine.local_rank)
            trg = trg.to(decoder_engine.local_rank)

            enc_keys = encoder_engine(src)
            loss = decoder_engine(trg, keys=enc_keys, return_loss=True)
            loss.backward()

            decoder_engine.step()
            encoder_engine.step()

            print('Training Loss:', loss.item())
            if tr_step % validate_every == 0:
                val_loss = []
                for pair in tqdm(testloader):
                    encoder_engine.eval()
                    decoder_engine.eval()
                    with torch.no_grad():
                        ts_src = pair[0]
                        ts_trg = pair[1]

                        ts_src = ts_src.to(encoder_engine.local_rank)
                        ts_trg = ts_trg.to(decoder_engine.local_rank)

                        enc_keys = encoder_engine(ts_src)
                        loss = decoder_engine(ts_trg,
                                              keys=enc_keys,
                                              return_loss=True)
                        val_loss.append(loss.item())

                print(
                    f'\tValidation Loss: AVG: {np.mean(val_loss)}, MEDIAN: {np.median(val_loss)}, STD: {np.std(val_loss)} '
                )
                log_file.write(
                    'Step: {}\tTraining Loss:{}\t Validation LOSS: AVG: {}| MEDIAN: {}| STD: {}\n'
                    .format(i, loss.item(), np.mean(val_loss),
                            np.median(val_loss), np.std(val_loss)))
            else:
                log_file.write('Step: {}\tTraining Loss:{}\n'.format(
                    i, loss.item()))

            log_file.flush()

            if tr_step % save_every == 0:
                print('\tSaving Checkpoint')
                enc_ckpt_id = str(enc_ckp_max + tr_step + 1)
                dec_ckpt_id = str(dec_ckp_max + tr_step + 1)
                encoder_engine.save_checkpoint(
                    os.sep.join([SAVE_DIR, 'encoder']), enc_ckpt_id)
                decoder_engine.save_checkpoint(
                    os.sep.join([SAVE_DIR, 'decoder']), dec_ckpt_id)

    log_file.close()
    print('\tSaving Final Checkpoint')
    enc_ckpt_id = str(enc_ckp_max + tr_step + 1)
    dec_ckpt_id = str(dec_ckp_max + tr_step + 1)
    encoder_engine.save_checkpoint(os.sep.join([SAVE_DIR, 'encoder']),
                                   enc_ckpt_id)
    decoder_engine.save_checkpoint(os.sep.join([SAVE_DIR, 'decoder']),
                                   dec_ckpt_id)
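The DeepSpeed JSON passed via cmd_args.ds_conf is expected to define at least train_batch_size and gradient_accumulation_steps, plus an optional optimizer block (Example 8 below reads exactly these keys to decide whether RangerLars is needed). A minimal sketch of such a config, with placeholder values:

import json

ds_conf = {
    "train_batch_size": 8,
    "gradient_accumulation_steps": 4,
    "optimizer": {                 # optional; if present, deepspeed_optimizer is True
        "type": "Adam",
        "params": {"lr": 1e-4}
    }
}

with open("ds_config.json", "w") as f:  # hypothetical file name
    json.dump(ds_conf, f, indent=2)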
Example no. 5
model = ReformerLM(
    dim=512,
    depth=6,
    max_seq_len=SEQ_LEN,
    num_tokens=256,
    heads=8,
    bucket_size=64,
    n_hashes=4,
    ff_chunks=10,
    lsh_dropout=0.1,
    weight_tie=True,
    causal=True,
    use_full_attn=False  # set this to true for comparison with full attention
)

model = TrainingWrapper(model)
model.cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()  # np.fromstring is deprecated
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)


class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len
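The TextSamplerDataset is truncated here. In the canonical reformer-pytorch enwik8 script it continues with a random-crop __getitem__ (seq_len + 1 bytes, so the wrapper can shift inputs and targets) and a __len__; a sketch of the full class under that assumption:

import torch
from torch.utils.data import Dataset

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # sample a random window of seq_len + 1 tokens for next-token prediction
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()  # drop .cuda() when training on CPU

    def __len__(self):
        return self.data.size(0) // self.seq_len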
Example no. 6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='cuda', type=str, required=False, help='which GPU(s) to use')
    # parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False,
    #                     help='model config to use')
    parser.add_argument('--tokenizer_path', default='cache/vocab_small_terry_ai.txt', type=str, required=False, help='vocabulary file')
    parser.add_argument('--raw_data_path', default='data/train.json', type=str, required=False, help='raw training corpus')
    parser.add_argument('--tokenized_data_path', default='data/tokenized/', type=str, required=False,
                        help='where to store the tokenized corpus')
    parser.add_argument('--raw', action='store_true', help='tokenize the raw data first')
    parser.add_argument('--epochs', default=5, type=int, required=False, help='number of training epochs')
    parser.add_argument('--batch_size', default=2, type=int, required=False, help='training batch size')
    parser.add_argument('--lr', default=1e-8, type=float, required=False, help='learning rate')
    parser.add_argument('--warmup_steps', default=2000, type=int, required=False, help='number of warmup steps')
    parser.add_argument('--log_step', default=1, type=int, required=False, help='report loss every N steps')
    parser.add_argument('--stride', default=500, type=int, required=False, help='stride of the sliding window')
    parser.add_argument('--dim', default=1024, type=int, required=False, help='length of a single training sample window (also used as the model dim)')
    parser.add_argument('--gradient_accumulation', default=5, type=int, required=False, help='gradient accumulation steps')
    parser.add_argument('--fp16', action='store_true', help='mixed-precision training')
    parser.add_argument('--fp16_opt_level', default='O1', type=str, required=False)
    parser.add_argument('--max_grad_norm', default=1.0, type=float, required=False)
    parser.add_argument('--num_pieces', default=10, type=int, required=False, help='number of pieces to split the training corpus into')
    parser.add_argument('--min_length', default=64, type=int, required=False, help='minimum article length to include')
    parser.add_argument('--output_dir', default='model/', type=str, required=False, help='model output path')
    parser.add_argument('--pretrained_model', default='', type=str, required=False, help='path of a pretrained model to resume from')
    # parser.add_argument('--writer_dir', default='tensorboard_summary/', type=str, required=False, help='TensorBoard log path')
    parser.add_argument('--segment', action='store_true', help='segment Chinese text by word')
    parser.add_argument('--bpe_token', action='store_true', help='subword')

    # parser.add_argument('--dim', default=1024, type=int, required=False, help='dim')
    parser.add_argument('--depth', default=12, type=int, required=False, help='depth')
    parser.add_argument('--full_attn_thres', default=1024, type=int, required=False, help='full_attn_thres')
    parser.add_argument('--max_seq_len', default=4096, type=int, required=False, help='max_seq_len')
    # parser.add_argument('--encoder_json', default="tokenizations/encoder.json", type=str, help="encoder.json")
    # parser.add_argument('--vocab_bpe', default="tokenizations/vocab.bpe", type=str, help="vocab.bpe")

    args = parser.parse_args()
    full_tokenizer=tokenizer_plus(args.tokenizer_path)
    config_file=os.path.join(args.output_dir,'config.json')
    Config=tkitJson.Config(config_file)
    new_conf={'num_tokens':full_tokenizer.vocab_size,
    'dim': args.dim, # same as the data window length
    'depth' : args.depth,
    'max_seq_len' :  args.max_seq_len,
    'lsh_dropout' : 0.1,
    'causal' : True,
    'full_attn_thres' : args.full_attn_thres,
    'stride': args.stride,  # sliding window stride
    }
    print("new_conf:",new_conf)
    Config.save(new_conf)
    # copy the vocabulary file next to the model
    shutil.copy(args.tokenizer_path, os.path.join(args.output_dir, 'vocab.txt'))
    
    print('args:\n' + args.__repr__())

    # if args.segment:
    #     from tokenizations import tokenization_bert_word_level as tokenization_bert
    # else:
    #     from tokenizations import tokenization_bert

    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3' # 此处设置程序使用哪些显卡

    # model_config = transformers.modeling_gpt2.GPT2Config.from_json_file(args.model_config)
    # print('config:\n' + model_config.to_json_string())

    # dim = model_config.dim
    # if args.bpe_token:
    #     full_tokenizer = get_encoder(args.encoder_json, args.vocab_bpe)
    # else:
    # full_tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
    # full_tokenizer = BertTokenizer.from_pretrained(args.tokenizer_path)

    # full_tokenizer.max_len = dim

    # if args.device==''
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # override with the device requested on the command line (can force CPU)
    device = args.device

    print('using device:', device)

    raw_data_path = args.raw_data_path
    tokenized_data_path = args.tokenized_data_path
    raw = args.raw  # whether to rebuild the tokenized dataset from scratch
    pretrained_model = args.pretrained_model
    epochs = args.epochs
    batch_size = args.batch_size
    lr = args.lr
    warmup_steps = args.warmup_steps
    log_step = args.log_step
    stride = args.stride
    dim=args.dim
    if stride >= dim:
        stride = dim // 2 - 2  # keep the stride an integer and smaller than the window
    gradient_accumulation = args.gradient_accumulation
    
    # fp16 = args.fp16  # do not enable on GPUs without fp16 support
    # fp16_opt_level = args.fp16_opt_level
    max_grad_norm = args.max_grad_norm
    num_pieces = args.num_pieces
    min_length = args.min_length
    output_dir = args.output_dir
    # tb_writer = SummaryWriter(log_dir=args.writer_dir)

    # paths of a previously saved model to resume from
    model_path = os.path.join(pretrained_model, 'model.pt')
    optimizer_path = os.path.join(pretrained_model, 'optimizer.pt')
    scheduler_path = os.path.join(pretrained_model, 'scheduler.pt')
    # output paths
    output_model_path = os.path.join(output_dir, 'model.pt')
    output_optimizer_path = os.path.join(output_dir, 'optimizer.pt')
    output_scheduler_path = os.path.join(output_dir, 'scheduler.pt')
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    if raw:
        print('building files')
        build_files(data_path=raw_data_path, tokenized_data_path=tokenized_data_path, num_pieces=num_pieces,
                    full_tokenizer=full_tokenizer, min_length=min_length)
        print('files built')

    model = ReformerLM(
        num_tokens= full_tokenizer.vocab_size,
        dim = dim,  # data window length
        depth = args.depth,
        max_seq_len =  args.max_seq_len,
        lsh_dropout = 0.1,
        causal = True,
        full_attn_thres = args.full_attn_thres
    )

    # 0 is used for padding and no loss to be calculated on it
    if device=='cuda':
        model = TrainingWrapper(model, ignore_index = 0, pad_value = 0).to('cuda')
    else:
        model = TrainingWrapper(model, ignore_index = 0, pad_value = 0)

    if os.path.isfile(model_path):
        # if so, load them
        model.load_state_dict(torch.load(model_path))
    else:   
        # pass
        model.train()

    weight_decay=0.0
    # learning_rate=5e-5
    adam_epsilon=1e-8
    # warmup_steps=0
    max_grad_norm=1.0
    max_steps=-1
    # gradient_accumulation_steps=10
    logging_steps=1000
    save_steps=10000
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            'weight_decay': weight_decay
        },
        {
            'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }
    ]


    full_len = 0
    print('calculating total steps')
    for i in tqdm(range(num_pieces)):
        with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
            full_len += len([int(item) for item in f.read().strip().split()])
    total_steps = int(full_len / stride * epochs / batch_size / gradient_accumulation)
    print('total steps = {}'.format(total_steps))



    # total_steps = len(x_train_text)/gradient_accumulation_steps * num_train_epochs
    # t_total=3/1*3
    # optimizer = AdamW(model.parameters(), lr=lr, correct_bias=True)
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=adam_epsilon)
    scheduler = get_linear_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=warmup_steps,num_training_steps=total_steps)

    # # checking if another optimizer/scheduler exists
    if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
        # if so, load them
        optimizer.load_state_dict(torch.load(optimizer_path))
        scheduler.load_state_dict(torch.load(scheduler_path))

    print("optimizer",optimizer)
    loss_fn=nn.CrossEntropyLoss()

    
    print('starting training')
    overall_step = 0
    running_loss = 0
    gradient_accumulation_run=0
    for epoch in range(epochs):
        print('epoch {}'.format(epoch + 1))
        now = datetime.now()
        print('time: {}'.format(now))
        x = np.linspace(0, num_pieces - 1, num_pieces, dtype=np.int32)
        random.shuffle(x)
        # piece_num = 0

        # model.zero_grad()   # reset gradient
        # for piece_num, i in tqdm(enumerate( x)):
        for piece_num, i in enumerate( x):
            with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
                line = f.read().strip()
            tokens = line.split()
            tokens = [int(token) for token in tokens]
            # print(len(tokens))
            start_point = 0
            samples = []
            # slice the data with a sliding window
            while start_point < len(tokens) - dim:
                samples.append(tokens[start_point: start_point + dim])
                # print(start_point, start_point + dim)
                start_point += stride
            if start_point < len(tokens):
                samples.append(tokens[len(tokens)-dim:])
            # shuffle the samples to reduce overfitting
            random.shuffle(samples)
            for step in range(len(samples) // batch_size):  # drop last
                # print(step)
                #  prepare data
                batch = samples[step * batch_size: (step + 1) * batch_size]
                # batch_labels = []
                batch_inputs = []
                for ids in batch:
                    # int_ids_for_labels = [int(x) for x in ids]
                    int_ids_for_inputs = [int(x) for x in ids]
                    # batch_labels.append(int_ids_for_labels)
                    batch_inputs.append(int_ids_for_inputs)
                if device=='cuda':
                    batch_inputs = torch.tensor(batch_inputs).long().to("cuda")
                    # batch_labels = torch.tensor(batch_labels).long().to("cuda")
                else:
                    batch_inputs = torch.tensor(batch_inputs).long()
                    # batch_labels = torch.tensor(batch_labels).long()
                # batch_inputs = torch.tensor(batch_inputs).long().to(device)
                # print(batch_labels)

                # print(len(batch_inputs))
                # print(batch_inputs)
                # print(len(batch_inputs))

                loss = model(batch_inputs, return_loss = True)
                loss = loss/gradient_accumulation   
                loss.backward()
                # print(loss.sum())
                if ((gradient_accumulation_run + 1) % gradient_accumulation) == 0:
                    # step the optimizer once the accumulated gradients are ready
                    optimizer.step()
                    scheduler.step()        # advance the learning-rate schedule
                    optimizer.zero_grad()   # reset the accumulated gradients
                    # scheduler.zero_grad()
                    # model.zero_grad()   # reset gradient
                    end = datetime.now()
                    print("epoch:", epoch + 1, " piece_num:", piece_num, '/', num_pieces,
                          " step:", overall_step + 1, '/', total_steps,
                          " progress:", (overall_step + 1) / total_steps,
                          " loss:", loss.item(), 'Time', end - now)
                overall_step+=1
                gradient_accumulation_run=gradient_accumulation_run+1

                # scheduler.step()
                # model.zero_grad()
            # end = datetime.now()
            # print("one piece:",end-now," s")

            torch.save(model.state_dict(),  output_model_path)
            torch.save(optimizer.state_dict(), output_optimizer_path)
            torch.save(scheduler.state_dict(),  output_scheduler_path)
    model_cpu_path=os.path.join(output_dir, 'model_cpu.pt')
    torch.save(model.cpu().state_dict(), model_cpu_path)
Example no. 7
conf = Config.read()

# tokenizer = BertTokenizer.from_pretrained(pretrained_weights)
tokenizer = tokenizer_plus(pretrained_weights)
model = ReformerLM(num_tokens=conf['num_tokens'],
                   dim=conf['dim'],
                   depth=conf['depth'],
                   max_seq_len=conf['max_seq_len'],
                   lsh_dropout=conf['lsh_dropout'],
                   causal=conf['causal'],
                   full_attn_thres=conf['full_attn_thres'])

model_path = os.path.join(output_dir, 'model.pt')

if device == 'cuda':
    model = TrainingWrapper(model, ignore_index=0, pad_value=0).cuda()
    if os.path.isfile(model_path):
        # if so, load them
        # print('++++'*20)
        model.load_state_dict(torch.load(model_path))
else:
    model = TrainingWrapper(model, ignore_index=0, pad_value=0).cpu()
    # print(model)
    # print(model.cpu().state_dict())

    # print('++++'*20)
    if os.path.isfile(model_path):
        # if so, load them
        # print('++++'*20)
        print("加载模型")
        model.load_state_dict(torch.load(model_path))
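Once loaded, text generation with this wrapper follows the same pattern as Example 3: encode a prompt into token ids, call generate with an eos token, then decode. A short sketch (the prompt, the length of 64 and the tokenizer helper calls are assumptions about tokenizer_plus):

prompt_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("some prompt text"))
initial = torch.tensor([prompt_ids]).long().to(device)

sample = model.generate(initial, 64, temperature=1., filter_thres=0.9, eos_token=1)
print(tokenizer.convert_ids_to_tokens(sample.tolist()[0]))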
Example no. 8
def main():
    cmd_args = add_argument()

    path_to_file_tr = cmd_args.path_to_file_tr
    path_to_file_ts = cmd_args.path_to_file_ts

    min_len_mol = cmd_args.min_len_mol
    max_len_mol = cmd_args.max_len_mol

    num_examples_tr = cmd_args.num_examples_tr
    num_examples_ts = cmd_args.num_examples_ts

    train_batch_size = json.load(open(cmd_args.ds_conf))['train_batch_size']
    gradient_accumulation_steps = json.load(open(
        cmd_args.ds_conf))['gradient_accumulation_steps']

    deepspeed_optimizer = True if json.load(open(cmd_args.ds_conf)).get(
        'optimizer', None) is not None else False

    epochs = cmd_args.epochs
    emb_dim = cmd_args.emb_dim
    dim = cmd_args.dim
    bucket_size = cmd_args.bucket_size
    depth = cmd_args.depth
    heads = cmd_args.heads
    n_hashes = cmd_args.n_hashes
    ff_chunks = cmd_args.ff_chunks
    attn_chunks = cmd_args.attn_chunks
    validate_every = cmd_args.validate_every
    save_every = cmd_args.save_every
    output_folder = cmd_args.output_folder

    use_full_attn = cmd_args.use_full_attn
    mrpc_test = cmd_args.mrpc_test
    use_deepspeed = cmd_args.use_deepspeed

    os.makedirs(output_folder, exist_ok=True)

    pickle.dump(cmd_args,
                open(os.sep.join([output_folder, 'training_conf.pkl']), 'wb'))

    MIN_LENGTH_MOL = min_len_mol
    MAX_LENGTH_MOL = max_len_mol  # 2048
    NUM_EXAMPLES_TR = num_examples_tr  # 1024
    NUM_EXAMPLES_TS = num_examples_ts  # 1024
    N_EPOCHS = epochs  # 10
    VALIDATE_EVERY = validate_every
    SAVE_EVERY = save_every

    MOL_SEQ_LEN = MAX_LENGTH_MOL  # output_lang.max_len if (output_lang.max_len % 2) == 0  else output_lang.max_len + 1 # ??

    saved_mol_lang = os.sep.join([output_folder, 'mol_lang.pkl'])

    MAX_LENGTH_MOL = cmd_args.max_len_mol

    saved_target_lang = os.sep.join([output_folder, 'mol_lang.pkl'])

    if mrpc_test:
        mol_lang, tr_samples, ts_samples = readMRPC(
            molecule_file_tr=path_to_file_tr,
            molecule_file_ts=path_to_file_ts,
            saved_molecule_lang=saved_target_lang,
            num_examples_tr=NUM_EXAMPLES_TR,
            num_examples_ts=NUM_EXAMPLES_TS,
            min_len_molecule=MIN_LENGTH_MOL,
            max_len_molecule=MAX_LENGTH_MOL,
            shuffle=True)
    else:
        mol_lang, tr_samples, ts_samples = readMolecules(
            molecule_file_tr=path_to_file_tr,
            molecule_file_ts=path_to_file_ts,
            saved_molecule_lang=saved_target_lang,
            num_examples_tr=NUM_EXAMPLES_TR,
            num_examples_ts=NUM_EXAMPLES_TS,
            min_len_molecule=MIN_LENGTH_MOL,
            max_len_molecule=MAX_LENGTH_MOL,
            shuffle=True)

    pickle.dump(mol_lang, open(saved_mol_lang, 'wb'))

    train_dataset = MolecularSimilarityDataset(
        tr_samples, mol_lang, train_batch_size if device == 'cuda' else 1)
    test_dataset = MolecularSimilarityDataset(
        ts_samples, mol_lang, train_batch_size if device == 'cuda' else 1)

    MAX_SEQ_LEN = MOL_SEQ_LEN * 2
    print('Axial Embedding shape:', compute_axial_position_shape(MAX_SEQ_LEN))
    model = ReformerLM(
        num_tokens=mol_lang.n_words,
        dim=dim,
        bucket_size=bucket_size,
        depth=depth,
        heads=heads,
        n_hashes=n_hashes,
        max_seq_len=MAX_SEQ_LEN,
        ff_chunks=ff_chunks,
        attn_chunks=attn_chunks,
        weight_tie=True,
        weight_tie_embedding=True,
        axial_position_emb=True,
        axial_position_shape=compute_axial_position_shape(MAX_SEQ_LEN),
        axial_position_dims=(dim // 2, dim // 2),
        return_embeddings=True,
        use_full_attn=use_full_attn).to(device)

    linear_regressor = Linear(512, 2).to(device)

    model = TrainingWrapper(model, ignore_index=PAD_IDX,
                            pad_value=PAD_IDX).to(device)

    model_params = filter(lambda p: p.requires_grad, model.parameters())
    linear_params = filter(lambda p: p.requires_grad,
                           linear_regressor.parameters())

    SAVE_DIR = os.sep.join([output_folder, 'saved_model'])
    os.makedirs(SAVE_DIR, exist_ok=True)

    try:
        model_ckp_max = np.max(
            [int(ckp) for ckp in os.listdir(os.sep.join([SAVE_DIR, 'model']))])
    except Exception:
        model_ckp_max = 0

    gpus_mini_batch = (train_batch_size // gradient_accumulation_steps
                       ) // torch.cuda.device_count()
    print('gpus_mini_batch:', gpus_mini_batch,
          'with gradient_accumulation_steps:', gradient_accumulation_steps)
    log_file = open(os.sep.join([output_folder, 'training_log.log']), 'a')
    log_file.write(
        "\n\n\n{}\tStarting new training from chekpoint: EncoderDecoder-{}\n".
        format(datetime.datetime.now(), model_ckp_max))
    log_file.flush()

    if use_deepspeed:
        if not deepspeed_optimizer:
            print('No DeepSpeed optimizer found. Using RangerLars.')
            model_optimizer = RangerLars(model.parameters())
            linear_optimizer = RangerLars(linear_regressor.parameters())

            model_engine, model_optimizer, trainloader, _ = deepspeed.initialize(
                args=cmd_args,
                model=model,
                optimizer=model_optimizer,
                model_parameters=model_params,
                training_data=train_dataset)

            linear_engine, linear_optimizer, _, _ = deepspeed.initialize(
                args=cmd_args,
                model=linear_regressor,
                optimizer=linear_optimizer,
                model_parameters=linear_params)

        else:
            print('Found optimizer in the DeepSpeed configurations. Using it.')
            model_engine, model_optimizer, trainloader, _ = deepspeed.initialize(
                args=cmd_args,
                model=model,
                model_parameters=model_params,
                training_data=train_dataset)
            linear_engine, linear_optimizer, _, _ = deepspeed.initialize(
                args=cmd_args,
                model=linear_regressor,
                model_parameters=linear_params)

        _, model_client_sd = model_engine.load_checkpoint(
            os.sep.join([SAVE_DIR, 'model']), model_ckp_max)

        testloader = model_engine.deepspeed_io(test_dataset)

        ######TO DO
        for eph in range(epochs):
            print('Starting Epoch: {}'.format(eph))
            for i, pair in enumerate(tqdm(trainloader)):
                tr_step = ((eph * len(trainloader)) + i) + 1

                src = pair[0]
                trg = pair[1]

                pickle.dump(src, open('src.pkl', 'wb'))
                pickle.dump(trg, open('trg.pkl', 'wb'))

                model_engine.train()
                linear_engine.train()
                #enc_dec.train()

                src = src.to(model_engine.local_rank)
                trg = trg.to(linear_engine.local_rank)

                print("Sample:", src)
                print("Target:", trg)
                print("Target Shape:", trg.shape)
                print("len Samples:", len(src))

                ## Need to learn how to use masks correctly
                enc_input_mask = torch.tensor(
                    [[1 if idx != PAD_IDX else 0 for idx in smpl]
                     for smpl in src]).bool().to(model_engine.local_rank)

                # context_mask = torch.tensor([[1 for idx in smpl if idx != PAD_IDX] for smpl in trg]).bool().to(device)
                #################

                enc_keys = model_engine(
                    src, return_loss=False, input_mask=enc_input_mask
                )  #enc_input_mask)#, context_mask=context_mask)
                #loss = enc_dec(src, trg, return_loss = True, enc_input_mask = None)#enc_input_mask)#, context_mask=context_mask)

                print('enc_keys shape', enc_keys.shape)
                #enc_keys_cls = enc_keys[:,0:1,:].to(linear_engine.local_rank)#torch.tensor([s[0] for s in enc_keys]).to(linear_engine.local_rank)
                #print('enc_keys_cls shape', enc_keys_cls.shape)
                preds = torch.softmax(linear_engine(enc_keys),
                                      dim=1).to(linear_engine.local_rank)

                print('preds shape', preds.shape)
                #preds = np.array([r[0] for r in results])
                #print('Pred:', preds.shape)
                loss = F.cross_entropy(preds, trg).to(linear_engine.local_rank)
                loss.backward()

                model_engine.step()
                linear_engine.step()

                print('Training Loss:', loss.item())
                if tr_step % validate_every == 0:
                    val_loss = []
                    for pair in tqdm(
                            testloader
                    ):  #Can't use the testloader or I will mess up with the model assignment and it won't learn during training, need to use normal validation instead of parallel one
                        model_engine.eval()
                        linear_engine.eval()
                        with torch.no_grad():
                            ts_src = pair[0]
                            ts_trg = pair[1]

                            pickle.dump(ts_src, open('ts_src.pkl', 'wb'))
                            pickle.dump(ts_trg, open('ts_trg.pkl', 'wb'))

                            ts_src = ts_src.to(model_engine.local_rank)
                            ts_trg = ts_trg.to(linear_engine.local_rank)

                            #ts_src = torch.tensor(np.array([pair[0].numpy()])).to(device)
                            #ts_trg = torch.tensor(np.array([pair[1].numpy()])).to(device)

                            ## Need to learn how to use masks correctly
                            ts_enc_input_mask = torch.tensor([
                                [1 if idx != PAD_IDX else 0 for idx in smpl]
                                for smpl in ts_src
                            ]).bool().to(model_engine.local_rank)
                            #ts_context_mask = torch.tensor([[1 for idx in smpl if idx != PAD_IDX] for smpl in ts_trg]).bool().to(device)

                            # loss = model_engine(
                            #     ts_src,
                            #     ts_trg,
                            #     return_loss=True,
                            #     enc_input_mask=ts_enc_input_mask
                            # )  #ts_enc_input_mask)#, context_mask=ts_context_mask)
                            # #loss = enc_dec(ts_src, ts_trg, return_loss = True, enc_input_mask = None)

                            ts_enc_keys = model_engine(
                                ts_src,
                                return_loss=False,
                                input_mask=ts_enc_input_mask)
                            ts_pred = torch.softmax(
                                linear_engine(ts_enc_keys),
                                dim=1).to(linear_engine.local_rank)
                            loss = F.cross_entropy(ts_pred, ts_trg).to(
                                linear_engine.local_rank)
                            val_loss.append(loss.item())

                    print(
                        f'\tValidation Loss: AVG: {np.mean(val_loss)}, MEDIAN: {np.median(val_loss)}, STD: {np.std(val_loss)} '
                    )
                    log_file.write(
                        'Step: {}\tTraining Loss:{}\t Validation LOSS: AVG: {}| MEDIAN: {}| STD: {}\n'
                        .format(i, loss.item(), np.mean(val_loss),
                                np.median(val_loss), np.std(val_loss)))
                else:
                    log_file.write('Step: {}\tTraining Loss:{}\n'.format(
                        i, loss.item()))

                log_file.flush()

                if tr_step % save_every == 0:
                    print('\tSaving Checkpoint')
                    model_ckpt_id = str(model_ckp_max + tr_step + 1)
                    model_engine.save_checkpoint(
                        os.sep.join([SAVE_DIR, 'model']), model_ckpt_id)

        log_file.close()
        print('\tSaving Final Checkpoint')
        model_ckpt_id = str(model_ckp_max + tr_step + 1)
        model_engine.save_checkpoint(os.sep.join([SAVE_DIR, 'model']),
                                     model_ckpt_id)
    else:
        #model_optimizer = torch.optim.Adam(model.parameters()) # RangerLars(model.parameters())
        #linear_optimizer = torch.optim.Adam(linear_regressor.parameters())  # RangerLars(linear_regressor.parameters())

        model_optimizer = torch.optim.Adam(
            list(model.parameters()) + list(linear_regressor.parameters())
        )  #RangerLars(list(model.parameters())+list(linear_regressor.parameters())) #

        PATH = os.sep.join(
            [SAVE_DIR, 'model',
             str(model_ckp_max), 'sts_model.pt'])
        if os.path.exists(PATH):
            print('********** Found Checkpoint. Loading:', PATH)
            checkpoint = torch.load(PATH)

            model.load_state_dict(checkpoint['model_state_dict'])
            linear_regressor.load_state_dict(checkpoint['linear_state_dict'])
            model_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        trainloader = DataLoader(train_dataset,
                                 batch_size=train_batch_size,
                                 shuffle=False)
        testloader = DataLoader(test_dataset,
                                batch_size=train_batch_size,
                                shuffle=False)
        ######TO DO
        train_loss_list = []
        for eph in range(epochs):
            print('Starting Epoch: {}'.format(eph))
            for i, pair in enumerate(tqdm(trainloader)):
                tr_step = ((eph * len(trainloader)) + i) + 1

                src = pair[0]
                trg = pair[1]

                pickle.dump(src, open('src.pkl', 'wb'))
                pickle.dump(trg, open('trg.pkl', 'wb'))

                model.train()
                linear_regressor.train()
                #enc_dec.train()

                src = src.to(device)
                trg = trg.to(device)

                #print("Sample:", src)
                #print("Target:", trg)
                #print("Target Shape:", trg.shape)
                #print("len Samples:", len(src))

                ## Need to learn how to use masks correctly
                enc_input_mask = torch.tensor(
                    [[1 if idx != PAD_IDX else 0 for idx in smpl]
                     for smpl in src]).bool().to(device)

                # context_mask = torch.tensor([[1 for idx in smpl if idx != PAD_IDX] for smpl in trg]).bool().to(device)
                #################

                enc_keys = model(
                    src, return_loss=False, input_mask=enc_input_mask
                )  #enc_input_mask)#, context_mask=context_mask)
                #loss = enc_dec(src, trg, return_loss = True, enc_input_mask = None)#enc_input_mask)#, context_mask=context_mask)

                #print('enc_keys shape', enc_keys.shape)
                enc_keys_cls = enc_keys[:, 0, :].to(
                    device
                )  #torch.tensor([s[0] for s in enc_keys]).to(linear_engine.local_rank)
                #print('enc_keys_cls shape', enc_keys_cls.shape)
                preds = torch.softmax(linear_regressor(enc_keys_cls),
                                      dim=1).to(device)

                #print('preds shape', preds.shape)
                #preds = np.array([r[0] for r in results])
                #print('Pred:', preds.shape)
                loss = F.cross_entropy(preds, trg).to(device)
                loss.backward()

                model_optimizer.step()
                #linear_optimizer.step()

                train_loss_list.append(loss.item())
                #print('Training Loss:', loss.item())
                if tr_step % validate_every == 0:
                    val_loss = []
                    ACC_list = []
                    MCC_list = []
                    for pair in tqdm(
                            testloader
                    ):  #Can't use the testloader or I will mess up with the model assignment and it won't learn during training, need to use normal validation instead of parallel one
                        model.eval()
                        linear_regressor.eval()
                        with torch.no_grad():
                            ts_src = pair[0]
                            ts_trg = pair[1]

                            pickle.dump(ts_src, open('ts_src.pkl', 'wb'))
                            pickle.dump(ts_trg, open('ts_trg.pkl', 'wb'))

                            ts_src = ts_src.to(device)
                            ts_trg = ts_trg.to(device)

                            #ts_src = torch.tensor(np.array([pair[0].numpy()])).to(device)
                            #ts_trg = torch.tensor(np.array([pair[1].numpy()])).to(device)

                            ## Need to learn how to use masks correctly
                            ts_enc_input_mask = torch.tensor(
                                [[1 if idx != PAD_IDX else 0 for idx in smpl]
                                 for smpl in ts_src]).bool().to(device)
                            #ts_context_mask = torch.tensor([[1 for idx in smpl if idx != PAD_IDX] for smpl in ts_trg]).bool().to(device)

                            # loss = model_engine(
                            #     ts_src,
                            #     ts_trg,
                            #     return_loss=True,
                            #     enc_input_mask=ts_enc_input_mask
                            # )  #ts_enc_input_mask)#, context_mask=ts_context_mask)
                            # #loss = enc_dec(ts_src, ts_trg, return_loss = True, enc_input_mask = None)

                            ts_enc_keys = model(ts_src,
                                                return_loss=False,
                                                input_mask=ts_enc_input_mask)
                            ts_enc_keys_cls = ts_enc_keys[:, 0, :].to(device)

                            ts_pred = torch.softmax(
                                linear_regressor(ts_enc_keys_cls),
                                dim=1).to(device)

                            loss = F.cross_entropy(ts_pred, ts_trg).to(device)

                            ACC, MCC = compute_simple_metrics(ts_pred, ts_trg)
                            ACC_list.append(ACC)
                            MCC_list.append(MCC)

                            val_loss.append(loss.item())

                    print(
                        f'\tTrain Loss: LAST: {train_loss_list[-1]}, AVG: {np.mean(train_loss_list)}, MEDIAN: {np.median(train_loss_list)}, STD: {np.std(train_loss_list)} '
                    )
                    print(
                        f'\tValidation Loss: AVG: {np.mean(val_loss)}, MEDIAN: {np.median(val_loss)}, STD: {np.std(val_loss)} '
                    )
                    print(
                        f'\tValidation ACC: AVG: {np.mean(ACC_list)}, MEDIAN: {np.median(ACC_list)}, STD: {np.std(ACC_list)} '
                    )
                    print(
                        f'\tValidation MCC: AVG: {np.mean(MCC_list)}, MEDIAN: {np.median(MCC_list)}, STD: {np.std(MCC_list)} '
                    )
                    log_file.write(
                        'Step: {}\tTraining Loss:{}\t Validation LOSS: AVG: {}| MEDIAN: {}| STD: {}\n'
                        .format(i, loss.item(), np.mean(val_loss),
                                np.median(val_loss), np.std(val_loss)))
                else:
                    log_file.write('Step: {}\tTraining Loss:{}\n'.format(
                        i, loss.item()))

                log_file.flush()

                if tr_step % save_every == 0:
                    print('\tSaving Checkpoint')
                    model_ckpt_id = str(model_ckp_max + tr_step + 1)
                    #model_engine.save_checkpoint(os.sep.join([SAVE_DIR, 'model']),
                    #                            model_ckpt_id)
                    PATH = os.sep.join([
                        SAVE_DIR, 'model',
                        str(model_ckpt_id), 'sts_model.pt'
                    ])
                    os.makedirs(os.sep.join(PATH.split(os.sep)[:-1]),
                                exist_ok=True)
                    torch.save(
                        {
                            'step': tr_step,
                            'model_state_dict': model.state_dict(),
                            'linear_state_dict': linear_regressor.state_dict(),
                            'optimizer_state_dict':
                            model_optimizer.state_dict(),
                        }, PATH)

        log_file.close()
        print('\tSaving Final Checkpoint')
        model_ckpt_id = str(model_ckp_max + tr_step + 1)
        #model_engine.save_checkpoint(os.sep.join([SAVE_DIR, 'model']),
        #                            model_ckpt_id)
        PATH = os.sep.join(
            [SAVE_DIR, 'model',
             str(model_ckpt_id), 'sts_model.pt'])
        os.makedirs(os.sep.join(PATH.split(os.sep)[:-1]), exist_ok=True)
        torch.save(
            {
                'step': tr_step,
                'model_state_dict': model.state_dict(),
                'linear_state_dict': linear_regressor.state_dict(),
                'optimizer_state_dict': model_optimizer.state_dict(),
            }, PATH)
Example no. 9
# load the ALBERT model and its tokenizer
path = "model/albert_tiny/"
albert_model, full_tokenizer = load_albert(path)

# outputs = albert_model(batch_inputs)

model = ReformerLM(num_tokens=20000,
                   dim=1024,
                   depth=12,
                   max_seq_len=4096,
                   lsh_dropout=0.1,
                   causal=True,
                   full_attn_thres=1024)

# 0 is used for padding and no loss to be calculated on it
model = TrainingWrapper(model, ignore_index=0, pad_value=0)

# the wrapper can handle evenly packed sequences
x_train = randint(0, 20000, (3, 357))

# or if you have a list of uneven sequences, it will be padded for you
x_train = [
    randint(0, 20000, (120, )),
    randint(0, 20000, (253, )),
    randint(0, 20000, (846, ))
]

# when training, set return_loss equal to True
model.train()
loss = model(x_train, return_loss=True)
loss.backward()
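After training, the same wrapper exposes generate; the canonical continuation of this README example samples from a single start token with the default top-k filtering (the start token id of 0 and the eos token of 1 below mirror that example):

import torch

initial = torch.tensor([[0]]).long()  # assume 0 is the start token

sample = model.generate(initial, 100, temperature=1., filter_thres=0.9, eos_token=1)
print(sample.shape)  # (1, <=100) token ids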
Example no. 10
def test_encdec_v1(input_lang, target_lang, dim, bucket_size, depth, heads,
                   n_hashes, vir_seq_len, ff_chunks, attn_chunks, mol_seq_len,
                   cmd_args, train_dataset, test_dataset, output_folder,
                   train_batch_size, epochs, validate_every, save_every,
                   checkpoint_id, deepspeed_optimizer, use_full_attn,
                   gradient_accumulation_steps, filter_thres):
    results = {
        'generated_seq': [],
        'generated_mol': [],
        'target_mol': [],
        'input_genome': []
    }

    encoder = ReformerLM(
        num_tokens=input_lang.n_words,
        dim=dim,
        bucket_size=bucket_size,
        depth=depth,
        heads=heads,
        n_hashes=n_hashes,
        max_seq_len=vir_seq_len,
        ff_chunks=ff_chunks,
        attn_chunks=attn_chunks,
        weight_tie=True,
        weight_tie_embedding=True,
        axial_position_emb=True,
        axial_position_shape=compute_axial_position_shape(vir_seq_len),
        axial_position_dims=(dim // 2, dim // 2),
        return_embeddings=True,
        use_full_attn=use_full_attn).to(device)

    decoder = ReformerLM(
        num_tokens=target_lang.n_words,
        dim=dim,
        bucket_size=bucket_size,
        depth=depth,
        heads=heads,
        n_hashes=n_hashes,
        ff_chunks=ff_chunks,
        attn_chunks=attn_chunks,
        max_seq_len=mol_seq_len,
        axial_position_emb=True,
        axial_position_shape=compute_axial_position_shape(mol_seq_len),
        axial_position_dims=(dim // 2, dim // 2),
        weight_tie=True,
        weight_tie_embedding=True,
        causal=True,
        use_full_attn=use_full_attn).to(device)

    SAVE_DIR = os.sep.join([output_folder, 'saved_model'])

    if checkpoint_id:
        enc_ckp_max = checkpoint_id
        dec_ckp_max = checkpoint_id
    else:
        try:
            enc_ckp_max = np.max([
                int(ckp)
                for ckp in os.listdir(os.sep.join([SAVE_DIR, 'encoder']))
            ])
        except Exception as e:
            print('Exception:', e)
            enc_ckp_max = 0

        try:
            dec_ckp_max = np.max([
                int(ckp)
                for ckp in os.listdir(os.sep.join([SAVE_DIR, 'decoder']))
            ])
        except Exception:
            dec_ckp_max = 0

    encoder = TrainingWrapper(encoder, ignore_index=PAD_IDX,
                              pad_value=PAD_IDX).to(device)
    decoder = TrainingWrapper(decoder, ignore_index=PAD_IDX,
                              pad_value=PAD_IDX).to(device)
    '''
    encoder_params = filter(lambda p: p.requires_grad, encoder.parameters())
    decoder_params = filter(lambda p: p.requires_grad, decoder.parameters())

    if deepspeed_optimizer == False:
        print('No DeepSpeed optimizer found. Using RangerLars.')
        encoder_optimizer = RangerLars(encoder.parameters())
        decoder_optimizer = RangerLars(decoder.parameters())

        encoder_engine, encoder_optimizer, trainloader, _ = deepspeed.initialize(
            args=cmd_args,
            model=encoder,
            optimizer=encoder_optimizer,
            model_parameters=encoder_params,
            training_data=train_dataset,
            dist_init_required=True
            )

        decoder_engine, decoder_optimizer, testloader, _ = deepspeed.initialize(
            args=cmd_args,
            model=decoder,
            optimizer=decoder_optimizer,
            model_parameters=decoder_params,
            training_data=test_dataset,
            dist_init_required=False
            )
    else:
        print('Found optimizer in the DeepSpeed configurations. Using it.')
        encoder_engine, encoder_optimizer, trainloader, _ = deepspeed.initialize(args=cmd_args, model=encoder, model_parameters=encoder_params, training_data=train_dataset, dist_init_required=True)
        decoder_engine, decoder_optimizer, testloader, _ = deepspeed.initialize(args=cmd_args, model=decoder, model_parameters=decoder_params, training_data=test_dataset, dist_init_required=False)

    _, encoder_client_sd = encoder_engine.load_checkpoint(os.sep.join([SAVE_DIR,'encoder']), enc_ckp_max)
    _, decoder_client_sd = decoder_engine.load_checkpoint(os.sep.join([SAVE_DIR,'decoder']), dec_ckp_max)

    gpus_mini_batch = (train_batch_size// gradient_accumulation_steps) // torch.cuda.device_count()
    print('gpus_mini_batch:', gpus_mini_batch, 'with gradient_accumulation_steps:', gradient_accumulation_steps)

    for pair in tqdm(testloader):
        encoder_engine.eval()
        decoder_engine.eval()
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            ts_src = pair[0]
            ts_trg = pair[1]

            input_genome = [[input_lang.index2word[gen_idx.item()] for gen_idx in smpl] for smpl in pair[0]]
            target_mol = [[target_lang.index2word[mol_idx.item()] for mol_idx in smpl] for smpl in pair[1]]

            ts_src = ts_src.to(encoder_engine.local_rank) #ts_src.to(device) #
            ts_trg = ts_trg.to(decoder_engine.local_rank) #ts_trg.to(device) #

            print('ts_src.shape', ts_src.shape)
            print('ts_src.shape', ts_trg.shape)

            enc_keys = encoder(ts_src) #encoder_engine(ts_src)
            yi = torch.tensor([[SOS_token] for _ in range(gpus_mini_batch)]).long().to(decoder_engine.local_rank) #to(device) #

            #sample = decoder_engine.generate(yi, mol_seq_len, filter_logits_fn=top_p, filter_thres=0.95, keys=enc_keys, eos_token = EOS_token)
            sample = decoder.generate(yi, mol_seq_len, filter_logits_fn=top_p, filter_thres=0.95, keys=enc_keys, eos_token = EOS_token)
            actual_mol = []
            for mol_seq in sample.cpu().numpy():
                for mol_idx in mol_seq:
                    actual_mol.append(target_lang.index2word[mol_idx])
                print('Generated Seq:', sample)
                print('Generated Mol:', actual_mol)
                print('Real Mol:', target_mol[:target_mol.index(target_lang.index2word[EOS_token])])

                results['generated_seq'].append(sample)
                results['generated_mol'].append(actual_mol)
                results['target_mol'].append(target_mol)
                results['input_genome'].append(input_genome)

    print('Saving Test Results..')
    pickle.dump(results, open(os.sep.join([output_folder,'test_results.pkl']), 'wb'))
    '''

    encoder_checkpoint = os.sep.join([
        output_folder, 'saved_model', 'encoder', str(enc_ckp_max),
        'mp_rank_00_model_states.pt'
    ])
    decoder_checkpoint = os.sep.join([
        output_folder, 'saved_model', 'decoder', str(dec_ckp_max),
        'mp_rank_00_model_states.pt'
    ])

    encoder.load_state_dict(
        torch.load(encoder_checkpoint,
                   map_location=torch.device(device))['module'])
    decoder.load_state_dict(
        torch.load(decoder_checkpoint,
                   map_location=torch.device(device))['module'])

    real_batch_size = train_batch_size // gradient_accumulation_steps
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=real_batch_size,
                             shuffle=True)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        encoder = nn.DataParallel(encoder)
        decoder = nn.DataParallel(decoder)

    encoder.to(device)
    decoder.to(device)

    for pair in tqdm(test_loader):
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            ts_src = torch.tensor(np.array([pair[0].numpy()])).to(device)
            ts_trg = torch.tensor(np.array([pair[1].numpy()])).to(device)

            input_genome = [
                input_lang.index2word[gen_idx.item()] for gen_idx in pair[0]
            ]
            target_mol = [
                target_lang.index2word[mol_idx.item()] for mol_idx in pair[1]
            ]

            enc_keys = encoder(ts_src)
            yi = torch.tensor([[SOS_token]]).long().to(device)

            sample = decoder.generate(yi,
                                      mol_seq_len,
                                      filter_logits_fn=top_p,
                                      filter_thres=filter_thres,
                                      keys=enc_keys,
                                      eos_token=EOS_token)
            actual_mol = []
            for mol_seq in sample.cpu().numpy():
                for mol_idx in mol_seq:
                    actual_mol.append(target_lang.index2word[mol_idx])
                print('Generated Seq:', sample)
                print('Generated Mol:', actual_mol)
                print(
                    'Real Mol:',
                    target_mol[:target_mol.index(target_lang.
                                                 index2word[EOS_token])])

                results['generated_seq'].append(sample)
                results['generated_mol'].append(actual_mol)
                results['target_mol'].append(target_mol)
                results['input_genome'].append(input_genome)

    print('Saving Test Results..')
    pickle.dump(results,
                open(os.sep.join([output_folder, 'test_results.pkl']), 'wb'))
Example no. 11
def train(device='cpu',
          output_dir='model',
          epochs=5,
          save_step=5,
          batch_size=4):

    model = ReformerLM(num_tokens=13137,
                       dim=128,
                       depth=12,
                       max_seq_len=4096,
                       lsh_dropout=0.1,
                       causal=True,
                       full_attn_thres=128)
    model = TrainingWrapper(model, ignore_index=0, pad_value=0).to(device)
    # output_dir="model"
    model_cpu_path = os.path.join(output_dir, 'model_cpu.pt')
    try:
        model.load_state_dict(torch.load(model_cpu_path))
    except Exception:
        pass

    model.train()
    optimizer = AdamW(params=model.parameters())
    optimizer_path = os.path.join(output_dir, 'optimizer.pt')
    try:
        optimizer.load_state_dict(torch.load(optimizer_path))
    except Exception:
        pass
    print(optimizer)
    total_loss = 0.0
    # batch_size=4

    loss = []

    data = []
    for it in get_data("data/train.json", tokenizer):
        data.append(it)
    # data=data[:1000]
    loss_fn = nn.CrossEntropyLoss()  # -100 index = padding token
    for n in tqdm(range(epochs)):
        # print(n)
        random.shuffle(data)
        inputs = []
        labels = []
        for i, it in enumerate(data):
            # print("it",it)
            inputs.append(it['keywords'])
            labels.append(it['text'])
            if i % batch_size == 0 and i != 0:
                # print(it)

                inputs_batch = torch.tensor(inputs).long().to(device)

                labels_batch = torch.tensor(labels).long().to(device)
                # print(inputs_batch)
                inputs = []
                labels = []

                # inputs = torch.tensor(it['keywords']).long()
                # labels = torch.tensor(it['text']).long()
                # print("inputs",inputs)
                pred = model(inputs_batch)
                mlm_loss = loss_fn(pred.view(-1, tokenizer.vocab_size),
                                   labels_batch.view(-1))

                total_loss += mlm_loss.item()
                loss.append(mlm_loss.item())
                print('loss', mlm_loss.item())
                mlm_loss.backward()
                optimizer.step()
                model.zero_grad()
                # output_dir="model"
            if i % save_step == 0 and i != 0:
                model_cpu_path = os.path.join(output_dir, 'model_cpu.pt')
                optimizer_path = os.path.join(output_dir, 'optimizer.pt')
                torch.save(model.state_dict(), model_cpu_path)
                torch.save(optimizer.state_dict(), optimizer_path)
        model_cpu_path = os.path.join(output_dir, 'model_cpu.pt')
        optimizer_path = os.path.join(output_dir, 'optimizer.pt')
        torch.save(model.state_dict(), model_cpu_path)
        torch.save(optimizer.state_dict(), optimizer_path)
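A possible entry point for this training loop (the argument values are illustrative; the module-level tokenizer and get_data helpers are assumed by train itself):

if __name__ == '__main__':
    train(device='cuda' if torch.cuda.is_available() else 'cpu',
          output_dir='model',
          epochs=5,
          save_step=100,
          batch_size=4)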