Example #1
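The eight snippets below are variants of the same REINVENT-style SMILES generator training loop and lean on a common set of names. A plausible shared import header, as a sketch (the project-local module paths data_structs, model, and utils are assumptions inferred from usage, not confirmed by the snippets):

import torch
import torch.nn as nn                  # freeze variant in Example #3
from torch.utils.data import DataLoader
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import AllChem         # Morgan fingerprints in Example #3
import pandas as pd                    # sampling log in Example #3

# Assumed project-local modules (names inferred from how they are called):
from data_structs import MolData, Vocabulary
from model import RNN
from utils import decrease_learning_rate

# Individual examples additionally use SummaryWriter (torch.utils.tensorboard or
# tensorboardX), matplotlib.pyplot as plt, numpy as np, logging, typing.List, and
# further project-local helpers (cano_smi_file, mol_to_torchimage, Dataset,
# save_model) that are not shown here.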
def pretrain(restore_from=None):
    """Trains the Prior RNN"""

    # Read vocabulary from a file
    voc = Vocabulary(init_from_file="data/Voc")

    # Create a Dataset from a SMILES file
    moldata = MolData("data/mols_filtered.smi", voc)
    data = DataLoader(moldata,
                      batch_size=128,
                      shuffle=True,
                      drop_last=True,
                      collate_fn=MolData.collate_fn)

    Prior = RNN(voc)

    # Can restore from a saved RNN
    if restore_from:
        Prior.rnn.load_state_dict(torch.load(restore_from))

    optimizer = torch.optim.Adam(Prior.rnn.parameters(), lr=0.001)
    for epoch in range(1, 6):
        # When training on a few million compounds, this model converges
        # in a few epochs or even faster. If the model size is increased,
        # it's probably a good idea to check the loss against an external
        # set of validation SMILES to make sure we don't overfit too much.
        for step, batch in tqdm(enumerate(data), total=len(data)):

            # Sample from DataLoader
            seqs = batch.long()

            # Calculate loss
            log_p, _ = Prior.likelihood(seqs)
            loss = -log_p.mean()

            # Calculate gradients and take a step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Every 500 steps we decrease learning rate and print some information
            if step % 500 == 0 and step != 0:
                decrease_learning_rate(optimizer, decrease_by=0.03)
                tqdm.write("*" * 50)
                tqdm.write(
                    "Epoch {:3d}   step {:3d}    loss: {:5.2f}\n".format(
                        epoch, step, loss.item()))
                seqs, likelihood, _ = Prior.sample(128)
                valid = 0
                for i, seq in enumerate(seqs.cpu().numpy()):
                    smile = voc.decode(seq)
                    if Chem.MolFromSmiles(smile):
                        valid += 1
                    if i < 5:
                        tqdm.write(smile)
                tqdm.write("\n{:>4.1f}% valid SMILES".format(100 * valid /
                                                             len(seqs)))
                tqdm.write("*" * 50 + "\n")
Example #2
def train_model():
    """Do transfer learning for generating SMILES"""
    voc = Vocabulary(init_from_file='data/Voc')
    cano_smi_file('refined_smii.csv', 'refined_smii_cano.csv')
    moldata = MolData('refined_smii_cano.csv', voc)
    # Monomers 67 and 180 were removed because of the unseen [C-] in voc
    # DAs containing [se] [SiH2] [n] removed: 38 molecules
    data = DataLoader(moldata,
                      batch_size=64,
                      shuffle=True,
                      drop_last=False,
                      collate_fn=MolData.collate_fn)
    transfer_model = RNN(voc)

    if torch.cuda.is_available():
        transfer_model.rnn.load_state_dict(torch.load('data/Prior.ckpt'))
    else:
        transfer_model.rnn.load_state_dict(
            torch.load('data/Prior.ckpt',
                       map_location=lambda storage, loc: storage))
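    # Note: torch.load('data/Prior.ckpt', map_location='cpu') is an equivalent,
    # shorter way to force CPU loading in current PyTorch.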

    # for param in transfer_model.rnn.parameters():
    #     param.requires_grad = False
    optimizer = torch.optim.Adam(transfer_model.rnn.parameters(), lr=0.001)

    for epoch in range(1, 10):

        for step, batch in tqdm(enumerate(data), total=len(data)):
            seqs = batch.long()
            log_p, _ = transfer_model.likelihood(seqs)
            loss = -log_p.mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 5 == 0 and step != 0:
                decrease_learning_rate(optimizer, decrease_by=0.03)
                tqdm.write('*' * 50)
                tqdm.write(
                    "Epoch {:3d}   step {:3d}    loss: {:5.2f}\n".format(
                        epoch, step, loss.item()))
                seqs, likelihood, _ = transfer_model.sample(128)
                valid = 0
                for i, seq in enumerate(seqs.cpu().numpy()):
                    smile = voc.decode(seq)
                    if Chem.MolFromSmiles(smile):
                        valid += 1
                    if i < 5:
                        tqdm.write(smile)
                tqdm.write("\n{:>4.1f}% valid SMILES".format(100 * valid /
                                                             len(seqs)))
                tqdm.write("*" * 50 + '\n')
                torch.save(transfer_model.rnn.state_dict(),
                           "data/transfer_model2.ckpt")

        torch.save(transfer_model.rnn.state_dict(),
                   "data/transfer_modelw.ckpt")
Example #3
def train_model(voc_dir,
                smi_dir,
                prior_dir,
                tf_dir,
                tf_process_dir,
                freeze=False):
    """
    Transfer learning on target molecules using the SMILES structures
    Args:
        voc_dir: location of the vocabulary
        smi_dir: location of the SMILES file used for transfer learning
        prior_dir: location of prior trained model to initialize transfer learning
        tf_dir: location to save the transfer learning model
        tf_process_dir: location to save the SMILES sampled while doing transfer learning
        freeze: Bool. If true, all parameters in the RNN will be frozen except for the last linear layer during
        transfer learning.

    Returns: None

    """
    voc = Vocabulary(init_from_file=voc_dir)
    #cano_smi_file('all_smi_refined.csv', 'all_smi_refined_cano.csv')
    moldata = MolData(smi_dir, voc)
    # Monomers 67 and 180 were removed because of the unseen [C-] in voc
    # DAs containing [C] removed: 43 molecules in 5356; Ge removed: 154 in 5356; [c] removed 4 in 5356
    # [S] 1 molecule in 5356
    data = DataLoader(moldata,
                      batch_size=64,
                      shuffle=True,
                      drop_last=False,
                      collate_fn=MolData.collate_fn)
    transfer_model = RNN(voc)
    # if freeze=True, freeze all parameters except those in the linear layer
    if freeze:
        for param in transfer_model.rnn.parameters():
            param.requires_grad = False
        transfer_model.rnn.linear = nn.Linear(512, voc.vocab_size)
    if torch.cuda.is_available():
        transfer_model.rnn.load_state_dict(torch.load(prior_dir))
    else:
        transfer_model.rnn.load_state_dict(
            torch.load(prior_dir, map_location=lambda storage, loc: storage))
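
    # Suggested variant (not in the original): when freeze=True, hand the
    # optimizer only the trainable parameters, e.g.
    #   torch.optim.Adam((p for p in transfer_model.rnn.parameters()
    #                     if p.requires_grad), lr=0.0005)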

    optimizer = torch.optim.Adam(transfer_model.rnn.parameters(), lr=0.0005)

    smi_lst = []
    epoch_lst = []
    for epoch in range(1, 11):

        for step, batch in tqdm(enumerate(data), total=len(data)):
            seqs = batch.long()
            log_p, _ = transfer_model.likelihood(seqs)
            loss = -log_p.mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 80 == 0 and step != 0:
                decrease_learning_rate(optimizer, decrease_by=0.03)
                tqdm.write('*' * 50)
                tqdm.write(
                    "Epoch {:3d}   step {:3d}    loss: {:5.2f}\n".format(
                        epoch, step, loss.item()))
                seqs, likelihood, _ = transfer_model.sample(128)
                valid = 0
                for i, seq in enumerate(seqs.cpu().numpy()):
                    smile = voc.decode(seq)
                    if Chem.MolFromSmiles(smile):
                        valid += 1
                    if i < 5:
                        tqdm.write(smile)
                tqdm.write("\n{:>4.1f}% valid SMILES".format(100 * valid /
                                                             len(seqs)))
                tqdm.write("*" * 50 + '\n')
                torch.save(transfer_model.rnn.state_dict(), tf_dir)
        seqs, likelihood, _ = transfer_model.sample(1024)
        valid = 0
        for i, seq in enumerate(seqs.cpu().numpy()):
            smile = voc.decode(seq)
            mol = Chem.MolFromSmiles(smile)
            if mol:
                try:
                    # Keep only SMILES that also yield a Morgan fingerprint
                    # (radius 2, 1024 bits)
                    AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024)
                    valid += 1
                    smi_lst.append(smile)
                    epoch_lst.append(epoch)
                except Exception:
                    continue

        torch.save(transfer_model.rnn.state_dict(), tf_dir)

    transfer_process_df = pd.DataFrame(columns=['SMILES', 'Epoch'])
    transfer_process_df['SMILES'] = pd.Series(data=smi_lst)
    transfer_process_df['Epoch'] = pd.Series(data=epoch_lst)
    transfer_process_df.to_csv(tf_process_dir)
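The per-epoch loop above keeps only sampled SMILES that both parse and yield a Morgan fingerprint. As a standalone illustration of that check (the molecule is an arbitrary example):

from rdkit import Chem
from rdkit.Chem import AllChem

mol = Chem.MolFromSmiles("c1ccc2ccccc2c1")  # naphthalene, arbitrary test input
fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024)  # radius 2, 1024 bits
print(fp.GetNumOnBits())  # number of set bits in the fingerprint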
Example #4
def pretrain(runname='chembl', restore_from=None):
    """Trains the prior RNN"""

    writer = SummaryWriter('logs/%s' % runname)

    # Read vocabulary from a file
    voc = Vocabulary(init_from_file="data/Voc_%s" % runname)

    # Create a Dataset from a SMILES file
    moldata = MolData("data/mols_%s_filtered.smi" % runname, voc)
    data = DataLoader(moldata, batch_size=128, shuffle=True, drop_last=True, collate_fn=MolData.collate_fn)

    prior = RNN(voc)
    # writer.add_graph(prior.rnn, data.dataset[0])

    # Can restore from a saved RNN
    if restore_from:
        prior.rnn.load_state_dict(torch.load(restore_from))

    optimizer = torch.optim.Adam(prior.rnn.parameters(), lr=0.001)

    running_loss = 0.0
    for epoch in range(1, 6):
        # When training on a few million compounds, this model converges
        # in a few epochs or even faster. If the model size is increased,
        # it's probably a good idea to check the loss against an external
        # set of validation SMILES to make sure we don't overfit too much.
        for step, batch in tqdm(enumerate(data), total=len(data)):

            # Sample from DataLoader
            seqs = batch.long()

            # Calculate loss
            log_p, entropy = prior.likelihood(seqs)
            loss = -log_p.mean()
            running_loss += loss.item()

            # Calculate gradients and take a step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Every 250 steps we decrease learning rate and print some information
            if step % 250 == 249:
                decrease_learning_rate(optimizer, decrease_by=0.03)
                seqs, likelihood, _ = prior.sample(128)
                valid = 0
                smiles = list()
                for i, seq in enumerate(seqs.cpu().numpy()):
                    smile = voc.decode(seq)
                    if Chem.MolFromSmiles(smile):
                        valid += 1
                        if valid < 5:
                            tqdm.write(smile)
                            smiles.append(smile.strip())

                tqdm.write("*" * 50)
                tqdm.write("Epoch {:3d}   step {:3d}    loss: {:5.2f}\n".format(epoch, step, running_loss / step))
                tqdm.write("\n{:>4.1f}% valid SMILES".format(100 * valid / len(seqs)))
                tqdm.write("*" * 50 + "\n")
                torch.save(prior.rnn.state_dict(), "data/prior_%s.ckpt" % runname)
                writer.add_scalar('training loss', running_loss / 250, epoch * len(data) + step)
                writer.add_scalar('valid_smiles', 100 * valid / len(seqs), epoch * len(data) + step)
                writer.add_image('sampled_mols', mol_to_torchimage(smiles))
                running_loss = 0.0

        # Save the prior after each epoch
        torch.save(prior.rnn.state_dict(), "data/prior_%s.ckpt" % runname)

    # Close the writer once training is done
    writer.close()
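With the standard TensorBoard CLI, the scalars and images logged above can be inspected via "tensorboard --logdir logs" (assuming the default working directory).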
Example #5
def main(restore_from=None, visualize=False):
    # read vocabulary from a file
    voc = Vocabulary(init_from_file="data/voc")

    # create a dataset from a smiles file
    moldata = MolData("data/mols_filtered.smi", voc)
    data = DataLoader(moldata,
                      batch_size=10,
                      shuffle=True,
                      drop_last=True,
                      collate_fn=MolData.collate_fn)

    agent = RNN(voc)

    # can restore from a saved RNN
    if restore_from:
        agent.rnn.load_state_dict(
            torch.load(restore_from, map_location=torch.device('cpu')))

    optimizer = torch.optim.Adam(agent.rnn.parameters(), lr=0.001)
    torch.autograd.set_detect_anomaly(True)
    valid_ratios = list()
    for epoch in range(1, 2):
        for step, batch in tqdm(enumerate(data), total=len(data)):
            # sample from DataLoader
            seqs = batch.long()

            # calculate loss
            log_p, _ = agent.likelihood(seqs)
            loss = -log_p.mean()
            # print(loss)

            # calculate gradients and take a step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # every n steps we decrease learning rate and print out some information, n can be customized
            if step % 5 == 0 and step != 0:
                decrease_learning_rate(optimizer, decrease_by=0.03)
                tqdm.write("#" * 50)
                tqdm.write("Epoch {:3d} step {:3d} loss: {:5.2f}\n".format(
                    epoch, step, loss.item()))
                seqs, likelihood, _ = agent.sample(128)
                valid = 0
                for i, seq in enumerate(seqs.cpu().numpy()):
                    smile = voc.decode(seq)
                    if Chem.MolFromSmiles(smile):
                        valid += 1
                    if i < 5:
                        tqdm.write(smile)
                valid_ratio = 100 * valid / len(seqs)
                valid_ratios.append(valid_ratio)
                tqdm.write("\n{:>4.1f}% valid SMILES".format(valid_ratio))
                tqdm.write("#" * 50 + "\n")
                torch.save(agent.rnn.state_dict(), "data/Prior.ckpt")
        torch.save(agent.rnn.state_dict(), "data/Prior.ckpt")
    if visualize:
        plt.plot(range(len(valid_ratios)),
                 valid_ratios,
                 color='red',
                 linewidth=5)
        plt.savefig('/Users/ruiminma/Desktop/validratio.png',
                    bbox_inches='tight',
                    dpi=400)
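The validity bookkeeping repeated in each snippet can be factored into a small helper; a minimal sketch (the function name is ours, not from the source):

from rdkit import Chem

def fraction_valid(smiles_list):
    """Fraction of SMILES strings that RDKit can parse into molecules."""
    if not smiles_list:
        return 0.0
    return sum(Chem.MolFromSmiles(s) is not None for s in smiles_list) / len(smiles_list)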
Example #6
def pretrain(restore_from=None,
             save_to="data/Prior.ckpt",
             data="data/mols_filtered.smi",
             voc_file="data/Voc",
             batch_size=128,
             learning_rate=0.001,
             n_epochs=5,
             store_loss_dir=None,
             embedding_size=32):
    """Trains the Prior RNN"""

    # Read vocabulary from a file
    voc = Vocabulary(init_from_file=voc_file)

    # Create a Dataset from a SMILES file
    moldata = MolData(data, voc)
    data = DataLoader(moldata,
                      batch_size=batch_size,
                      shuffle=True,
                      drop_last=True,
                      collate_fn=MolData.collate_fn)

    Prior = RNN(voc, embedding_size)

    # Adding a file to log loss info
    if store_loss_dir is None:
        out_f = open("loss.csv", "w")
    else:
        out_f = open("{}/loss.csv".format(store_loss_dir.rstrip("/")), "w")

    out_f.write("Step,Loss\n")

    # Can restore from a saved RNN
    if restore_from:
        Prior.rnn.load_state_dict(torch.load(restore_from))

    # For later plotting the loss
    training_step_counter = 0
    n_logging = 100

    optimizer = torch.optim.Adam(Prior.rnn.parameters(), lr=learning_rate)
    for epoch in range(1, n_epochs + 1):
        # When training on a few million compounds, this model converges
        # in a few epochs or even faster. If the model size is increased,
        # it's probably a good idea to check the loss against an external
        # set of validation SMILES to make sure we don't overfit too much.
        for step, batch in tqdm(enumerate(data), total=len(data)):

            # Sample from DataLoader
            seqs = batch.long()

            # Calculate loss
            log_p, _ = Prior.likelihood(seqs)
            loss = -log_p.mean()

            # Calculate gradients and take a step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Logging the loss to a file
            if training_step_counter % n_logging == 0:
                out_f.write("{},{}\n".format(training_step_counter, loss.item()))
            training_step_counter += 1

            # Every 500 steps we decrease learning rate and print some information
            if step % 500 == 0 and step != 0:
                decrease_learning_rate(optimizer, decrease_by=0.03)
                tqdm.write("*" * 50)
                tqdm.write(
                    "Epoch {:3d}   step {:3d}    loss: {:5.2f}\n".format(
                        epoch, step, loss.item()))
                seqs, likelihood, _ = Prior.sample(128)
                valid = 0
                for i, seq in enumerate(seqs.cpu().numpy()):
                    smile = voc.decode(seq)
                    if Chem.MolFromSmiles(smile):
                        valid += 1
                    if i < 5:
                        tqdm.write(smile)
                tqdm.write("\n{:>4.1f}% valid SMILES".format(100 * valid /
                                                             len(seqs)))
                tqdm.write("*" * 50 + "\n")
                torch.save(Prior.rnn.state_dict(), save_to)

        # Save the Prior
        torch.save(Prior.rnn.state_dict(), save_to)

    out_f.close()
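The loss.csv written above (Step,Loss pairs) is straightforward to plot afterwards; a minimal sketch using pandas and matplotlib (output file name is arbitrary):

import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv("loss.csv")
plt.plot(df["Step"], df["Loss"])
plt.xlabel("Training step")
plt.ylabel("Negative log-likelihood")
plt.savefig("loss_curve.png", dpi=200, bbox_inches="tight")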
Example #7
def fit(voc_path, mol_path, restore_path, max_save_path, last_save_path,
        epoch_num, step_num, decay_step_num, smile_num, lr, weight_decay):

    restore_from = restore_path  # pass None to train from scratch
    # Read vocabulary from a file
    voc = Vocabulary(init_from_file=voc_path)

    # Create a Dataset from a SMILES file
    moldata = MolData(mol_path, voc)
    data = DataLoader(moldata,
                      batch_size=128,
                      shuffle=True,
                      drop_last=True,
                      collate_fn=MolData.collate_fn)

    Prior = RNN(voc)

    # Can restore from a saved RNN
    if restore_from:
        Prior.rnn.load_state_dict(torch.load(restore_from))

    total_loss = []
    total_valid = []
    max_valid_pro = 0

    optimizer = torch.optim.Adam(Prior.rnn.parameters(), lr=lr)

    for epoch in range(1, epoch_num):
        for step, batch in tqdm(enumerate(data), total=len(data)):

            # Sample from DataLoader
            seqs = batch.long()

            # Calculate loss
            log_p, _ = Prior.likelihood(seqs)
            loss = -log_p.mean()

            # Calculate gradients and take a step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Every decay_step_num steps we decrease the learning rate,
            # and every step_num steps we print some information
            if step != 0 and step % decay_step_num == 0:
                decrease_learning_rate(optimizer, decrease_by=weight_decay)
            if step % step_num == 0:
                tqdm.write("*" * 50)
                tqdm.write(
                    "Epoch {:3d}   step {:3d}    loss: {:5.2f}\n".format(
                        epoch, step, loss.item()))
                total_loss.append(float(loss))
                seqs, likelihood, _ = Prior.sample(128)
                valid = 0
                for i, seq in enumerate(seqs.cpu().numpy()):
                    smile = voc.decode(seq)
                    if Chem.MolFromSmiles(smile):
                        valid += 1
                    if i < smile_num:
                        tqdm.write(smile)
                vali_pro = valid / len(seqs)
                total_valid.append(float(vali_pro))
                tqdm.write("\n{:>4.1f}% valid SMILES".format(100 * valid /
                                                             len(seqs)))
                tqdm.write("*" * 50 + "\n")

                # Keep the checkpoint with the best validity seen so far
                if vali_pro > max_valid_pro:
                    max_valid_pro = vali_pro
                    torch.save(Prior.rnn.state_dict(), max_save_path)

        # Save the Prior after each epoch
        torch.save(Prior.rnn.state_dict(), last_save_path)

    print("total loss:", total_loss)
    print("total valid:", total_valid)
    return total_loss, total_valid
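A hypothetical call, with every path and hyperparameter a placeholder chosen to match the defaults seen in the other examples:

total_loss, total_valid = fit(voc_path="data/Voc",
                              mol_path="data/mols_filtered.smi",
                              restore_path=None,
                              max_save_path="data/Prior_best.ckpt",
                              last_save_path="data/Prior_last.ckpt",
                              epoch_num=6,
                              step_num=300,
                              decay_step_num=300,
                              smile_num=5,
                              lr=0.001,
                              weight_decay=0.03)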
Example #8
def train(model: models.reinvent.Model,
          smiles_list: List[str],
          model_path: str,
          epochs=10,
          lr=0.001,
          patience=30000,
          batch_size=128,
          steps_to_change_lr=500,
          lr_change=0.01,
          save_each_epoch=False,
          temperature=1.0):
    """
    Trains a model
    :param model: the model to train
    :param smiles_list: a list of SMILES to train on
    :param model_path: path where to save the model
    :param epochs: number of epochs to train
    :param lr: Learning rate for the optimizer
    :param patience: number of steps without loss improvement before early stopping interrupts the training
    :param batch_size: Batch size of the model
    :param steps_to_change_lr: number of steps between learning-rate decreases
    :param lr_change: factor by which the learning rate is decreased
    :param save_each_epoch: if True, save a separate checkpoint after each epoch
    :param temperature: Factor by which the logits are divided. Small values make the model more confident at each
                        position, but also more conservative. Large values result in more random predictions at each step.
    :return:
    """
    # Create a Dataset from a SMILES file
    moldata = Dataset.for_model(smiles_list, model)

    print("batch size: {}\n".format(batch_size))
    data = DataLoader(moldata,
                      batch_size=batch_size,
                      shuffle=False,
                      drop_last=True,
                      collate_fn=Dataset.collate_fn)

    # we stop early if the loss does not improve significantly anymore
    lowest_loss = float("inf")
    eps = 0.01
    overall_patience = patience
    patience = overall_patience

    optimizer = torch.optim.Adam(model.rnn.parameters(), lr=lr)
    for epoch in range(epochs):
        logging.info("Start Epoch {}".format(epoch))
        for step, batch in tqdm(enumerate(data), total=len(data)):

            # Sample from DataLoader
            seqs = batch.long()

            # Calculate loss
            log_p, _ = model.likelihood(seqs, temperature=temperature)
            loss = -log_p.mean()
            if loss.item() + eps < lowest_loss:
                patience = overall_patience
                lowest_loss = loss.item()
            else:
                patience -= 1

            if patience == 0:
                tqdm.write(
                    "*************Epoch {:2d}****************".format(epoch))
                tqdm.write(
                    "*** NO LOSS IMPROVEMENT AT STEP {:3d} ***".format(step))
                tqdm.write("*************EARLY  STOP****************")
                # Save the Prior
                save_model(model, model_path, epoch, save_each_epoch)
                return

            # Calculate gradients and take a step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Every 500 steps we decrease learning rate and print some information
            if step % (steps_to_change_lr //
                       max(1, torch.cuda.device_count())) == 0 and step != 0:
                decrease_learning_rate(optimizer, decrease_by=lr_change)
                tqdm.write(("Epoch {:3d}   step {:3d}    loss: {:5.2f}    "
                            "patience: {}    lr: {}").format(
                                epoch, step, loss.item(), patience,
                                optimizer.param_groups[0]["lr"]))
                seqs, likelihood, _ = model.sample(128,
                                                   temperature=temperature)
                valid = 0
                tqdm.write(
                    "\n\n*************Epoch {:2d}****************".format(
                        epoch))
                smiles = model.sequence_to_smiles(seqs)
                for i, smile in enumerate(smiles):
                    if Chem.MolFromSmiles(smile):
                        valid += 1
                    if i < 5:
                        tqdm.write(smile)
                tqdm.write("\n{:>4.1f}% valid SMILES".format(100 * valid /
                                                             len(seqs)))
                tqdm.write("****************************************\n")

        # Save the model after each epoch
        save_model(model, model_path, epoch, save_each_epoch)
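Example #8 threads a temperature through likelihood and sample. The snippet does not show the model internals, but the standard technique its docstring describes is to divide the logits before the softmax; a generic sketch:

import torch

def sample_with_temperature(logits, temperature=1.0):
    """Sample one token id per row after temperature-scaling the logits."""
    probs = torch.softmax(logits / temperature, dim=-1)  # temperature < 1 sharpens, > 1 flattens
    return torch.multinomial(probs, num_samples=1).squeeze(-1)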