Example #1
# imports assumed for this example (reformer-pytorch and its TrainingWrapper)
import os
import torch
from reformer_pytorch import ReformerLM
from reformer_pytorch.generative_tools import TrainingWrapper

def gen(text):
    model = ReformerLM(num_tokens=13137,
                       dim=128,
                       depth=12,
                       max_seq_len=4096,
                       lsh_dropout=0.1,
                       causal=True,
                       full_attn_thres=128)
    model = TrainingWrapper(model, ignore_index=0, pad_value=0).cpu()
    output_dir = "model"
    model_cpu_path = os.path.join(output_dir, 'model_cpu.pt')
    model.load_state_dict(torch.load(model_cpu_path, map_location='cpu'))
    # auto_encode and tokenizer are assumed to be defined elsewhere
    # (e.g. helpers around a pretrained tokenizer vocabulary)
    initial = auto_encode(text)
    sample = model.generate(
        initial, 10, temperature=1., filter_thres=0.9, eos_token=1
    )  # assume the end token is 1; omit eos_token to always sample 10 tokens
    # sample has shape (1, <=10) of token ids
    text = tokenizer.convert_ids_to_tokens(sample.tolist()[0])
    print(text)
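
# A minimal sketch (not from the original) of how 'model_cpu.pt' above could
# have been produced: after training, move the wrapped model to CPU and save
# its state dict to the same path that gen() loads from.
os.makedirs("model", exist_ok=True)
model.cpu()
torch.save(model.state_dict(), os.path.join("model", 'model_cpu.pt'))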
Example #2
# imports assumed for this fragment; model is a TrainingWrapper-wrapped
# ReformerLM, and cmd_args comes from a DeepSpeed argument parser (see below)
import random
import torch
import deepspeed

# the head of this call was truncated; reconstructed here in the standard
# deepspeed.initialize form
model_engine, optimizer, trainloader, _ = deepspeed.initialize(
    args=cmd_args,
    model=model,
    model_parameters=model.parameters(),
    training_data=train_dataset)

# training

for i, data in enumerate(trainloader):
    model_engine.train()
    data = data.to(model_engine.local_rank)
    loss = model_engine(data, return_loss=True)
    model_engine.backward(loss)
    model_engine.step()
    print(loss.item())

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            inp = random.choice(val_dataset)[:-1]
            loss = model(inp[None, :].cuda(), return_loss=True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print('%s \n\n %s' % (prime, '*' * 100))

        sample = model.generate(inp.cuda(), GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)
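
# A minimal sketch (assumed, not from the original) of how `cmd_args` is
# typically built for deepspeed.initialize: DeepSpeed attaches its config
# arguments (e.g. --deepspeed_config) to a standard argparse parser.
import argparse
import deepspeed

parser = argparse.ArgumentParser(description='reformer training')
parser.add_argument('--local_rank', type=int, default=-1,
                    help='local rank passed in by the distributed launcher')
parser = deepspeed.add_config_arguments(parser)
cmd_args = parser.parse_args()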
Example #3
# imports assumed for this fragment; model, optim, train_loader, val_loader,
# val_dataset, decode_tokens and the *_EVERY constants are defined elsewhere
import random
import torch
import tqdm

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss=True)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss=True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print('%s \n\n %s' % (prime, '*' * 100))

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)
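
# A common variant (not in the original): scale each micro-batch loss by the
# accumulation count so the summed gradient matches one large-batch step.
for __ in range(GRADIENT_ACCUMULATE_EVERY):
    loss = model(next(train_loader), return_loss=True)
    (loss / GRADIENT_ACCUMULATE_EVERY).backward()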
Example #4
import torch
from torch import randint
from reformer_pytorch import ReformerLM
from reformer_pytorch.generative_tools import TrainingWrapper

# the head of this constructor was truncated; num_tokens=20000 is assumed,
# matching the randint(0, 20000, ...) training data below
model = ReformerLM(num_tokens=20000,
                   dim=1024,
                   depth=12,
                   max_seq_len=4096,
                   lsh_dropout=0.1,
                   causal=True,
                   full_attn_thres=1024)

# 0 is used for padding and no loss to be calculated on it
model = TrainingWrapper(model, ignore_index=0, pad_value=0)

# the wrapper can handle evenly packed sequences
x_train = randint(0, 20000, (3, 357))

# or if you have a list of uneven sequences, it will be padded for you
x_train = [
    randint(0, 20000, (120, )),
    randint(0, 20000, (253, )),
    randint(0, 20000, (846, ))
]

# when training, set return_loss equal to True
model.train()
loss = model(x_train, return_loss=True)
loss.backward()
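
# a minimal optimizer step (assumed, not part of the original snippet);
# in practice the optimizer would be created once, before the training loop
optim = torch.optim.Adam(model.parameters(), lr=1e-4)
optim.step()
optim.zero_grad()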

# when evaluating, just use the generate function, which will default to top_k sampling with temperature of 1.
initial = torch.tensor([[0]]).long()  # assume 0 is start token
sample = model.generate(
    initial, 100, temperature=1., filter_thres=0.9,
    eos_token=1)  # assume end token is 1, or omit and it will sample up to 100
print(sample.shape)  # (1, <=100) token ids
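
# generate also accepts a filter_logits_fn (see Example #5 below); a minimal
# sketch of nucleus (top-p) sampling, assuming top_p is exposed by
# reformer_pytorch.generative_tools as the later example suggests
from reformer_pytorch.generative_tools import top_p

sample = model.generate(initial, 100, temperature=1.,
                        filter_logits_fn=top_p, filter_thres=0.95,
                        eos_token=1)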
Example #5
# imports assumed for this example; device, PAD_IDX, SOS_token, EOS_token,
# RangerLars and compute_axial_position_shape are defined elsewhere, and
# top_p is assumed to live in reformer_pytorch.generative_tools
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import deepspeed
from reformer_pytorch import ReformerLM
from reformer_pytorch.generative_tools import TrainingWrapper, top_p
def test_encdec_v1(input_lang, target_lang, dim, bucket_size, depth, heads,
                   n_hashes, vir_seq_len, ff_chunks, attn_chunks, mol_seq_len,
                   cmd_args, train_dataset, test_dataset, output_folder,
                   train_batch_size, epochs, validate_every, save_every,
                   checkpoint_id, deepspeed_optimizer, use_full_attn,
                   gradient_accumulation_steps, filter_thres):
    results = {
        'generated_seq': [],
        'generated_mol': [],
        'target_mol': [],
        'input_genome': []
    }

    encoder = ReformerLM(
        num_tokens=input_lang.n_words,
        dim=dim,
        bucket_size=bucket_size,
        depth=depth,
        heads=heads,
        n_hashes=n_hashes,
        max_seq_len=vir_seq_len,
        ff_chunks=ff_chunks,
        attn_chunks=attn_chunks,
        weight_tie=True,
        weight_tie_embedding=True,
        axial_position_emb=True,
        axial_position_shape=compute_axial_position_shape(vir_seq_len),
        axial_position_dims=(dim // 2, dim // 2),
        return_embeddings=True,
        use_full_attn=use_full_attn).to(device)

    decoder = ReformerLM(
        num_tokens=target_lang.n_words,
        dim=dim,
        bucket_size=bucket_size,
        depth=depth,
        heads=heads,
        n_hashes=n_hashes,
        ff_chunks=ff_chunks,
        attn_chunks=attn_chunks,
        max_seq_len=mol_seq_len,
        axial_position_emb=True,
        axial_position_shape=compute_axial_position_shape(mol_seq_len),
        axial_position_dims=(dim // 2, dim // 2),
        weight_tie=True,
        weight_tie_embedding=True,
        causal=True,
        use_full_attn=use_full_attn).to(device)

    SAVE_DIR = os.sep.join([output_folder, 'saved_model'])

    if checkpoint_id:
        enc_ckp_max = checkpoint_id
        dec_ckp_max = checkpoint_id
    else:
        try:
            enc_ckp_max = np.max([
                int(ckp)
                for ckp in os.listdir(os.sep.join([SAVE_DIR, 'encoder']))
            ])
        except Exception as e:
            print('Exception:', e)
            enc_ckp_max = 0

        try:
            dec_ckp_max = np.max([
                int(ckp)
                for ckp in os.listdir(os.sep.join([SAVE_DIR, 'decoder']))
            ])
        except Exception as e:
            print('Exception:', e)
            dec_ckp_max = 0

    encoder = TrainingWrapper(encoder, ignore_index=PAD_IDX,
                              pad_value=PAD_IDX).to(device)
    decoder = TrainingWrapper(decoder, ignore_index=PAD_IDX,
                              pad_value=PAD_IDX).to(device)
    # NOTE: the DeepSpeed evaluation path below is disabled (left inside a
    # string literal); the plain PyTorch path after it is the live code.
    '''
    encoder_params = filter(lambda p: p.requires_grad, encoder.parameters())
    decoder_params = filter(lambda p: p.requires_grad, decoder.parameters())

    if deepspeed_optimizer == False:
        print('No DeepSpeed optimizer found. Using RangerLars.')
        encoder_optimizer = RangerLars(encoder.parameters())
        decoder_optimizer = RangerLars(decoder.parameters())

        encoder_engine, encoder_optimizer, trainloader, _ = deepspeed.initialize(
            args=cmd_args,
            model=encoder,
            optimizer=encoder_optimizer,
            model_parameters=encoder_params,
            training_data=train_dataset,
            dist_init_required=True
            )

        decoder_engine, decoder_optimizer, testloader, _ = deepspeed.initialize(
            args=cmd_args,
            model=decoder,
            optimizer=decoder_optimizer,
            model_parameters=decoder_params,
            training_data=test_dataset,
            dist_init_required=False
            )
    else:
        print('Found optimizer in the DeepSpeed configurations. Using it.')
        encoder_engine, encoder_optimizer, trainloader, _ = deepspeed.initialize(args=cmd_args, model=encoder, model_parameters=encoder_params, training_data=train_dataset, dist_init_required=True)
        decoder_engine, decoder_optimizer, testloader, _ = deepspeed.initialize(args=cmd_args, model=decoder, model_parameters=decoder_params, training_data=test_dataset, dist_init_required=False)

    _, encoder_client_sd = encoder_engine.load_checkpoint(os.sep.join([SAVE_DIR,'encoder']), enc_ckp_max)
    _, decoder_client_sd = decoder_engine.load_checkpoint(os.sep.join([SAVE_DIR,'decoder']), dec_ckp_max)

    gpus_mini_batch = (train_batch_size // gradient_accumulation_steps) // torch.cuda.device_count()
    print('gpus_mini_batch:', gpus_mini_batch, 'with gradient_accumulation_steps:', gradient_accumulation_steps)

    for pair in tqdm(testloader):
        encoder_engine.eval()
        decoder_engine.eval()
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            ts_src = pair[0]
            ts_trg = pair[1]

            input_genome = [[input_lang.index2word[gen_idx.item()] for gen_idx in smpl] for smpl in pair[0]]
            target_mol = [[target_lang.index2word[mol_idx.item()] for mol_idx in smpl] for smpl in pair[1]]

            ts_src = ts_src.to(encoder_engine.local_rank) #ts_src.to(device) #
            ts_trg = ts_trg.to(decoder_engine.local_rank) #ts_trg.to(device) #

            print('ts_src.shape', ts_src.shape)
            print('ts_trg.shape', ts_trg.shape)

            enc_keys = encoder(ts_src) #encoder_engine(ts_src)
            yi = torch.tensor([[SOS_token] for _ in range(gpus_mini_batch)]).long().to(decoder_engine.local_rank) #to(device) #

            #sample = decoder_engine.generate(yi, mol_seq_len, filter_logits_fn=top_p, filter_thres=0.95, keys=enc_keys, eos_token = EOS_token)
            sample = decoder.generate(yi, mol_seq_len, filter_logits_fn=top_p, filter_thres=0.95, keys=enc_keys, eos_token = EOS_token)
            actual_mol = []
            for mol_seq in sample.cpu().numpy():
                for mol_idx in mol_seq:
                    actual_mol.append(target_lang.index2word[mol_idx])
                print('Generated Seq:', sample)
                print('Generated Mol:', actual_mol)
                print('Real Mol:', target_mol[:target_mol.index(target_lang.index2word[EOS_token])])

                results['generated_seq'].append(sample)
                results['generated_mol'].append(actual_mol)
                results['target_mol'].append(target_mol)
                results['input_genome'].append(input_genome)

    print('Saving Test Results..')
    pickle.dump(results, open(os.sep.join([output_folder,'test_results.pkl']), 'wb'))
    '''

    # checkpoint ids are ints, so cast to str before joining the path
    encoder_checkpoint = os.sep.join([
        SAVE_DIR, 'encoder', str(enc_ckp_max), 'mp_rank_00_model_states.pt'
    ])
    decoder_checkpoint = os.sep.join([
        SAVE_DIR, 'decoder', str(dec_ckp_max), 'mp_rank_00_model_states.pt'
    ])

    # DeepSpeed checkpoints keep the raw model weights under the 'module' key
    encoder.load_state_dict(
        torch.load(encoder_checkpoint,
                   map_location=torch.device(device))['module'])
    decoder.load_state_dict(
        torch.load(decoder_checkpoint,
                   map_location=torch.device(device))['module'])

    real_batch_size = train_batch_size // gradient_accumulation_steps
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=real_batch_size,
                             shuffle=True)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        encoder = nn.DataParallel(encoder)
        decoder = nn.DataParallel(decoder)

    encoder.to(device)
    decoder.to(device)

    for pair in tqdm(test_loader):
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            ts_src = pair[0].unsqueeze(0).to(device)
            ts_trg = pair[1].unsqueeze(0).to(device)

            input_genome = [
                input_lang.index2word[gen_idx.item()] for gen_idx in pair[0]
            ]
            target_mol = [
                target_lang.index2word[mol_idx.item()] for mol_idx in pair[1]
            ]

            enc_keys = encoder(ts_src)
            yi = torch.tensor([[SOS_token]]).long().to(device)

            # nn.DataParallel does not expose generate, so unwrap when needed
            decoder_module = decoder.module if isinstance(
                decoder, nn.DataParallel) else decoder
            sample = decoder_module.generate(yi,
                                             mol_seq_len,
                                             filter_logits_fn=top_p,
                                             filter_thres=filter_thres,
                                             keys=enc_keys,
                                             eos_token=EOS_token)
            actual_mol = []
            for mol_seq in sample.cpu().numpy():
                for mol_idx in mol_seq:
                    actual_mol.append(target_lang.index2word[mol_idx])
                print('Generated Seq:', sample)
                print('Generated Mol:', actual_mol)
                print(
                    'Real Mol:',
                    target_mol[:target_mol.index(target_lang.
                                                 index2word[EOS_token])])

                results['generated_seq'].append(sample)
                results['generated_mol'].append(actual_mol)
                results['target_mol'].append(target_mol)
                results['input_genome'].append(input_genome)

    print('Saving test results...')
    with open(os.sep.join([output_folder, 'test_results.pkl']), 'wb') as f:
        pickle.dump(results, f)
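
# compute_axial_position_shape is referenced above but not shown; a plausible
# (hypothetical) helper that factors a sequence length into two axes whose
# product equals max_seq_len, as axial position embeddings require
def compute_axial_position_shape(seq_len, axis=64):
    assert seq_len % axis == 0, 'seq_len must be divisible by the axis size'
    return (seq_len // axis, axis)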